3 Commits

Author SHA1 Message Date
56e72928c6 fix use after free 2025-05-10 21:46:53 -04:00
a80c9abfe7 Attempt to base64 encode the connection payload
For some reason I am still getting this:

2025/05/10 16:37:06 Error decoding message: SGVsbG8gZGFya25lc3MgbXkgb2xkIGZyaWVuZA==::53475673624738675a4746796132356c63334d6762586b676232786b49475a79615756755a413d3daaaa
2025-05-10 21:46:53 -04:00
245dab4909 Use slice for init, and add better error sets.
The slice sets us avoid allocating within the init function.
This means init can't fail, and it also makes it easier to stack allocate messages (slice an array buffer, instead of creating a stack allocator).
2025-05-10 21:46:53 -04:00
2 changed files with 96 additions and 166 deletions

View File

@@ -48,15 +48,22 @@ fn broadcastSaprusMessage(msg: *SaprusMessage, udp_port: u16) !void {
} }
pub fn sendRelay(payload: []const u8, dest: [4]u8, allocator: Allocator) !void { pub fn sendRelay(payload: []const u8, dest: [4]u8, allocator: Allocator) !void {
const msg: *SaprusMessageNew(.relay) = try .init( const msg_bytes = try allocator.alignedAlloc(
allocator, u8,
@intCast(base64Enc.calcSize(payload.len)), @alignOf(SaprusMessage),
try SaprusMessage.lengthForPayloadLength(
.relay,
base64Enc.calcSize(payload.len),
),
); );
defer msg.deinit(allocator); defer allocator.free(msg_bytes);
msg.dest = dest; const msg: *SaprusMessage = .init(.relay, msg_bytes);
_ = base64Enc.encode(msg.getPayload(), payload);
try broadcastSaprusMessage(try SaprusMessage.bytesAsValue(msg.asBytes()), 8888); const relay = (try msg.getSaprusTypePayload()).relay;
relay.dest = dest;
_ = base64Enc.encode(relay.getPayload(), payload);
try broadcastSaprusMessage(msg, 8888);
} }
fn randomPort() u16 { fn randomPort() u16 {
@@ -70,12 +77,21 @@ fn randomPort() u16 {
pub fn sendInitialConnection(payload: []const u8, initial_port: u16, allocator: Allocator) !*SaprusMessage { pub fn sendInitialConnection(payload: []const u8, initial_port: u16, allocator: Allocator) !*SaprusMessage {
const dest_port = randomPort(); const dest_port = randomPort();
const msg: *SaprusMessage = try .init(allocator, .connection, @intCast(payload.len)); const msg_bytes = try allocator.alignedAlloc(
defer msg.deinit(allocator); u8,
@alignOf(SaprusMessage),
try SaprusMessage.lengthForPayloadLength(
.connection,
base64Enc.calcSize(payload.len),
),
);
const msg: *SaprusMessage = .init(.connection, msg_bytes);
const connection = (try msg.getSaprusTypePayload()).connection; const connection = (try msg.getSaprusTypePayload()).connection;
connection.src_port = initial_port; connection.src_port = initial_port;
connection.dest_port = dest_port; connection.dest_port = dest_port;
@memcpy(connection.getPayload(), payload); _ = base64Enc.encode(connection.getPayload(), payload);
try broadcastSaprusMessage(msg, 8888); try broadcastSaprusMessage(msg, 8888);
@@ -88,8 +104,7 @@ pub fn connect(payload: []const u8, allocator: Allocator) !?SaprusConnection {
initial_port = r.intRangeAtMost(u16, 1024, 65000); initial_port = r.intRangeAtMost(u16, 1024, 65000);
} else unreachable; } else unreachable;
var initial_conn_res: ?SaprusMessage = null; var initial_conn_res: ?*SaprusMessage = null;
errdefer if (initial_conn_res) |*c| c.deinit(allocator);
var sock = try network.Socket.create(.ipv4, .udp); var sock = try network.Socket.create(.ipv4, .udp);
defer sock.close(); defer sock.close();
@@ -105,13 +120,14 @@ pub fn connect(payload: []const u8, allocator: Allocator) !?SaprusConnection {
try sock.bind(bind_addr); try sock.bind(bind_addr);
const msg = try sendInitialConnection(payload, initial_port, allocator); const msg = try sendInitialConnection(payload, initial_port, allocator);
defer allocator.free(msg.asBytes());
var response_buf: [4096]u8 align(4) = @splat(0); var response_buf: [4096]u8 align(@alignOf(SaprusMessage)) = undefined;
_ = try sock.receive(&response_buf); // Ignore message that I sent. _ = try sock.receive(&response_buf); // Ignore message that I sent.
const len = try sock.receive(&response_buf); const len = try sock.receive(&response_buf);
std.debug.print("response bytes: {x}\n", .{response_buf}); std.debug.print("response bytes: {x}\n", .{response_buf[0..len]});
initial_conn_res = (try SaprusMessage.bytesAsValue(response_buf[0..len])).*; initial_conn_res = SaprusMessage.init(.connection, response_buf[0..len]);
// Complete handshake after awaiting response // Complete handshake after awaiting response
try broadcastSaprusMessage(msg, randomPort()); try broadcastSaprusMessage(msg, randomPort());
@@ -123,7 +139,6 @@ pub fn connect(payload: []const u8, allocator: Allocator) !?SaprusConnection {
} }
const SaprusMessage = @import("message.zig").Message; const SaprusMessage = @import("message.zig").Message;
const SaprusMessageNew = @import("message.zig").MessageNew;
const SaprusConnection = @import("Connection.zig"); const SaprusConnection = @import("Connection.zig");
const std = @import("std"); const std = @import("std");

View File

@@ -20,130 +20,14 @@ pub const ConnectionOptions = packed struct(u8) {
opt8: bool = false, opt8: bool = false,
}; };
pub const Error = error{ pub const MessageTypeError = error{
NotImplementedSaprusType, NotImplementedSaprusType,
UnknownSaprusType, UnknownSaprusType,
};
pub const MessageParseError = MessageTypeError || error{
InvalidMessage, InvalidMessage,
}; };
pub fn MessageNew(comptime packet_type: PacketType) type {
comptime {
if (packet_type == .file_transfer)
@compileError("File transfer not implemented");
if (packet_type != .relay and packet_type != .connection)
@compileError("Unkown message type");
}
return packed struct {
const Self = @This();
const SelfBytes = []align(@alignOf(Self)) u8;
const Relay = struct {
pub fn getPayload(self: *Self) []u8 {
return @as([*]align(@alignOf(Self)) u8, @ptrCast(&self.payload))[0 .. self.length - 4];
}
};
const Connection = packed struct {
pub fn getPayload(self: Self) []u8 {
return @as([*]u8, &self.payload)[0 .. self.length - 4];
}
};
type: PacketType = packet_type,
length: u16,
// Relay
dest: if (packet_type == .relay) @Vector(4, u8) else void,
// Connection
src_port: if (packet_type == .connection) u16 else void, // random number > 1024
dest_port: if (packet_type == .connection) u16 else void, // random number > 1024
seq_num: if (packet_type == .connection) u32 else void,
msg_id: if (packet_type == .connection) u32 else void,
reserved: if (packet_type == .connection) u8 else void,
options: if (packet_type == .connection) ConnectionOptions else void = if (packet_type == .connection) .{} else {},
// Relay or Connection
payload: switch (packet_type) {
.relay, .connection => void,
else => noreturn,
},
pub usingnamespace switch (packet_type) {
.relay => Relay,
.connection => Connection,
.file_transfer => @compileError("File Transfer message type not implemented"),
else => @compileError("Unknown message type"),
};
pub fn init(allocator: Allocator, payload_len: u16) !*Self {
const size = payload_len + @sizeOf(Self);
const bytes = try allocator.alignedAlloc(u8, @alignOf(Self), size);
const res: *Self = @ptrCast(bytes.ptr);
res.type = packet_type;
res.length = payload_len;
return res;
}
pub fn deinit(self: *Self, allocator: Allocator) void {
allocator.free(self.asBytes());
}
pub fn nativeFromNetworkEndian(self: *Self) void {
self.type = @enumFromInt(bigToNative(
@typeInfo(@TypeOf(self.type)).@"enum".tag_type,
@intFromEnum(self.type),
));
self.length = bigToNative(@TypeOf(self.length), self.length);
if (packet_type == .connection) {
self.src_port = bigToNative(@TypeOf(self.src_port), self.src_port);
self.dest_port = bigToNative(@TypeOf(self.dest_port), self.dest_port);
self.seq_num = bigToNative(@TypeOf(self.seq_num), self.seq_num);
self.msg_id = bigToNative(@TypeOf(self.msg_id), self.msg_id);
}
}
pub fn networkFromNativeEndian(self: *Self) void {
self.type = @enumFromInt(bigToNative(
@typeInfo(@TypeOf(self.type)).@"enum".tag_type,
@intFromEnum(self.type),
));
self.length = bigToNative(@TypeOf(self.length), self.length);
if (packet_type == .connection) {
self.src_port = nativeToBig(@TypeOf(self.src_port), self.src_port);
self.dest_port = nativeToBig(@TypeOf(self.dest_port), self.dest_port);
self.seq_num = nativeToBig(@TypeOf(self.seq_num), self.seq_num);
self.msg_id = nativeToBig(@TypeOf(self.msg_id), self.msg_id);
}
}
pub fn asBytes(self: *Self) SelfBytes {
const size = @sizeOf(Self) + self.length;
return @as([*]align(@alignOf(Self)) u8, @ptrCast(self))[0..size];
}
};
}
test MessageNew {
comptime for (@typeInfo(MessageNew(.connection)).@"struct".decls) |field| {
@compileLog(field);
};
}
// pub fn bytesAsMessage(bytes: []const u8) !*Self {
// const res = std.mem.bytesAsValue(Self, bytes);
// return switch (res.type) {
// .relay, .connection => if (bytes.len == res.length + @sizeOf(Self))
// res
// else
// Error.InvalidMessage,
// .file_transfer => Error.NotImplementedSaprusType,
// else => Error.UnknownSaprusType,
// };
// }
// ZERO COPY STUFF // ZERO COPY STUFF
// &payload could be a void value that is treated as a pointer to a [*]u8 // &payload could be a void value that is treated as a pointer to a [*]u8
/// All Saprus messages /// All Saprus messages
@@ -193,23 +77,25 @@ pub const Message = packed struct {
length: u16, length: u16,
bytes: void = {}, bytes: void = {},
pub fn init(allocator: Allocator, comptime @"type": PacketType, payload_len: u16) !*Self { /// Takes a byte slice, and returns a Message struct backed by the slice.
const header_size = @sizeOf(switch (@"type") { /// This properly initializes the top level headers within the slice.
.relay => Relay, pub fn init(@"type": PacketType, bytes: []align(@alignOf(Self)) u8) *Self {
.connection => Connection, std.debug.assert(bytes.len >= @sizeOf(Self));
.file_transfer => return Error.NotImplementedSaprusType,
else => return Error.UnknownSaprusType,
});
const size = payload_len + @sizeOf(Self) + header_size;
const bytes = try allocator.alignedAlloc(u8, @alignOf(Self), size);
const res: *Self = @ptrCast(bytes.ptr); const res: *Self = @ptrCast(bytes.ptr);
res.type = @"type"; res.type = @"type";
res.length = payload_len + header_size; res.length = @intCast(bytes.len - @sizeOf(Self));
return res; return res;
} }
pub fn deinit(self: *Self, allocator: Allocator) void { pub fn lengthForPayloadLength(comptime @"type": PacketType, payload_len: usize) MessageTypeError!u16 {
allocator.free(self.asBytes()); std.debug.assert(payload_len < std.math.maxInt(u16));
const header_size = @sizeOf(switch (@"type") {
.relay => Relay,
.connection => Connection,
.file_transfer => return MessageTypeError.NotImplementedSaprusType,
else => return MessageTypeError.UnknownSaprusType,
});
return @intCast(payload_len + @sizeOf(Self) + header_size);
} }
fn getRelay(self: *Self) *align(1) Relay { fn getRelay(self: *Self) *align(1) Relay {
@@ -219,7 +105,7 @@ pub const Message = packed struct {
return std.mem.bytesAsValue(Connection, &self.bytes); return std.mem.bytesAsValue(Connection, &self.bytes);
} }
pub fn getSaprusTypePayload(self: *Self) Error!(union(PacketType) { pub fn getSaprusTypePayload(self: *Self) MessageTypeError!(union(PacketType) {
relay: *align(1) Relay, relay: *align(1) Relay,
file_transfer: void, file_transfer: void,
connection: *align(1) Connection, connection: *align(1) Connection,
@@ -227,12 +113,12 @@ pub const Message = packed struct {
return switch (self.type) { return switch (self.type) {
.relay => .{ .relay = self.getRelay() }, .relay => .{ .relay = self.getRelay() },
.connection => .{ .connection = self.getConnection() }, .connection => .{ .connection = self.getConnection() },
.file_transfer => Error.NotImplementedSaprusType, .file_transfer => MessageTypeError.NotImplementedSaprusType,
else => Error.UnknownSaprusType, else => MessageTypeError.UnknownSaprusType,
}; };
} }
pub fn nativeFromNetworkEndian(self: *Self) Error!void { pub fn nativeFromNetworkEndian(self: *Self) MessageTypeError!void {
self.type = @enumFromInt(bigToNative( self.type = @enumFromInt(bigToNative(
@typeInfo(@TypeOf(self.type)).@"enum".tag_type, @typeInfo(@TypeOf(self.type)).@"enum".tag_type,
@intFromEnum(self.type), @intFromEnum(self.type),
@@ -255,12 +141,12 @@ pub const Message = packed struct {
} }
} }
pub fn networkFromNativeEndian(self: *Self) Error!void { pub fn networkFromNativeEndian(self: *Self) MessageTypeError!void {
try switch (try self.getSaprusTypePayload()) { try switch (try self.getSaprusTypePayload()) {
.relay => {}, .relay => {},
.connection => |*con| con.*.networkFromNativeEndian(), .connection => |*con| con.*.networkFromNativeEndian(),
.file_transfer => Error.NotImplementedSaprusType, .file_transfer => MessageTypeError.NotImplementedSaprusType,
else => Error.UnknownSaprusType, else => MessageTypeError.UnknownSaprusType,
}; };
self.type = @enumFromInt(nativeToBig( self.type = @enumFromInt(nativeToBig(
@typeInfo(@TypeOf(self.type)).@"enum".tag_type, @typeInfo(@TypeOf(self.type)).@"enum".tag_type,
@@ -269,18 +155,8 @@ pub const Message = packed struct {
self.length = nativeToBig(@TypeOf(self.length), self.length); self.length = nativeToBig(@TypeOf(self.length), self.length);
} }
pub fn bytesAsValue(bytes: SelfBytes) !*Self { /// Deprecated.
const res = std.mem.bytesAsValue(Self, bytes); /// If I need the bytes, I should just pass around the slice that is backing this to begin with.
return switch (res.type) {
.relay, .connection => if (bytes.len == res.length + @sizeOf(Self))
res
else
Error.InvalidMessage,
.file_transfer => Error.NotImplementedSaprusType,
else => Error.UnknownSaprusType,
};
}
pub fn asBytes(self: *Self) SelfBytes { pub fn asBytes(self: *Self) SelfBytes {
const size = @sizeOf(Self) + self.length; const size = @sizeOf(Self) + self.length;
return @as([*]align(@alignOf(Self)) u8, @ptrCast(self))[0..size]; return @as([*]align(@alignOf(Self)) u8, @ptrCast(self))[0..size];
@@ -334,6 +210,45 @@ const asBytes = std.mem.asBytes;
const nativeToBig = std.mem.nativeToBig; const nativeToBig = std.mem.nativeToBig;
const bigToNative = std.mem.bigToNative; const bigToNative = std.mem.bigToNative;
test "Round trip Relay toBytes and fromBytes" {
const gpa = std.testing.allocator;
const msg = Message{
.relay = .{
.header = .{ .dest = .{ 255, 255, 255, 255 } },
.payload = "Hello darkness my old friend",
},
};
const to_bytes = try msg.toBytes(gpa);
defer gpa.free(to_bytes);
const from_bytes = try Message.fromBytes(to_bytes, gpa);
defer from_bytes.deinit(gpa);
try std.testing.expectEqualDeep(msg, from_bytes);
}
test "Round trip Connection toBytes and fromBytes" {
const gpa = std.testing.allocator;
const msg = Message{
.connection = .{
.header = .{
.src_port = 0,
.dest_port = 0,
},
.payload = "Hello darkness my old friend",
},
};
const to_bytes = try msg.toBytes(gpa);
defer gpa.free(to_bytes);
const from_bytes = try Message.fromBytes(to_bytes, gpa);
defer from_bytes.deinit(gpa);
try std.testing.expectEqualDeep(msg, from_bytes);
}
test { test {
std.testing.refAllDeclsRecursive(@This()); std.testing.refAllDeclsRecursive(@This());
} }