6 Commits

Author SHA1 Message Date
56e72928c6 fix use after free 2025-05-10 21:46:53 -04:00
a80c9abfe7 Attempt to base64 encode the connection payload
For some reason I am still getting this:

2025/05/10 16:37:06 Error decoding message: SGVsbG8gZGFya25lc3MgbXkgb2xkIGZyaWVuZA==::53475673624738675a4746796132356c63334d6762586b676232786b49475a79615756755a413d3daaaa
2025-05-10 21:46:53 -04:00
245dab4909 Use slice for init, and add better error sets.
The slice sets us avoid allocating within the init function.
This means init can't fail, and it also makes it easier to stack allocate messages (slice an array buffer, instead of creating a stack allocator).
2025-05-10 21:46:53 -04:00
cde5c3626c 2025-05-10 21:46:53 -04:00
e84d1a2300 2025-05-10 21:46:53 -04:00
1b7d9bbb1a Remove bytesAsValueUnchecked
Callers can instead use std.mem.bytesAsValue directly.
2025-05-10 21:46:53 -04:00
3 changed files with 109 additions and 249 deletions

View File

@@ -1,3 +1,6 @@
const base64Enc = std.base64.Base64Encoder.init(std.base64.standard_alphabet_chars, '=');
const base64Dec = std.base64.Base64Decoder.init(std.base64.standard_alphabet_chars, '=');
var rand: ?Random = null;
pub fn init() !void {
@@ -14,9 +17,14 @@ pub fn deinit() void {
network.deinit();
}
fn broadcastSaprusMessage(msg: SaprusMessage, udp_port: u16, allocator: Allocator) !void {
const msg_bytes = try msg.toBytes(allocator);
defer allocator.free(msg_bytes);
fn broadcastSaprusMessage(msg: *SaprusMessage, udp_port: u16) !void {
if (false) {
var foo: gcat.nic.RawSocket = try .init("enp7s0"); // /proc/net/dev
defer foo.deinit();
}
const msg_bytes = msg.asBytes();
try msg.networkFromNativeEndian();
defer msg.nativeFromNetworkEndian() catch unreachable;
var sock = try network.Socket.create(.ipv4, .udp);
defer sock.close();
@@ -40,14 +48,22 @@ fn broadcastSaprusMessage(msg: SaprusMessage, udp_port: u16, allocator: Allocato
}
pub fn sendRelay(payload: []const u8, dest: [4]u8, allocator: Allocator) !void {
const msg = SaprusMessage{
.relay = .{
.header = .{ .dest = dest },
.payload = payload,
},
};
const msg_bytes = try allocator.alignedAlloc(
u8,
@alignOf(SaprusMessage),
try SaprusMessage.lengthForPayloadLength(
.relay,
base64Enc.calcSize(payload.len),
),
);
defer allocator.free(msg_bytes);
const msg: *SaprusMessage = .init(.relay, msg_bytes);
try broadcastSaprusMessage(msg, 8888, allocator);
const relay = (try msg.getSaprusTypePayload()).relay;
relay.dest = dest;
_ = base64Enc.encode(relay.getPayload(), payload);
try broadcastSaprusMessage(msg, 8888);
}
fn randomPort() u16 {
@@ -59,34 +75,36 @@ fn randomPort() u16 {
return p;
}
pub fn sendInitialConnection(payload: []const u8, initial_port: u16, allocator: Allocator) !SaprusMessage {
pub fn sendInitialConnection(payload: []const u8, initial_port: u16, allocator: Allocator) !*SaprusMessage {
const dest_port = randomPort();
const msg = SaprusMessage{
.connection = .{
.header = .{
.src_port = initial_port,
.dest_port = dest_port,
},
.payload = payload,
},
};
const msg_bytes = try allocator.alignedAlloc(
u8,
@alignOf(SaprusMessage),
try SaprusMessage.lengthForPayloadLength(
.connection,
base64Enc.calcSize(payload.len),
),
);
try broadcastSaprusMessage(msg, 8888, allocator);
const msg: *SaprusMessage = .init(.connection, msg_bytes);
const connection = (try msg.getSaprusTypePayload()).connection;
connection.src_port = initial_port;
connection.dest_port = dest_port;
_ = base64Enc.encode(connection.getPayload(), payload);
try broadcastSaprusMessage(msg, 8888);
return msg;
}
pub fn connect(payload: []const u8, allocator: Allocator) !?SaprusConnection {
var foo: gcat.nic.RawSocket = try .init("enp7s0"); // /proc/net/dev
defer foo.deinit();
var initial_port: u16 = 0;
if (rand) |r| {
initial_port = r.intRangeAtMost(u16, 1024, 65000);
} else unreachable;
var initial_conn_res: ?SaprusMessage = null;
errdefer if (initial_conn_res) |c| c.deinit(allocator);
var initial_conn_res: ?*SaprusMessage = null;
var sock = try network.Socket.create(.ipv4, .udp);
defer sock.close();
@@ -102,15 +120,17 @@ pub fn connect(payload: []const u8, allocator: Allocator) !?SaprusConnection {
try sock.bind(bind_addr);
const msg = try sendInitialConnection(payload, initial_port, allocator);
defer allocator.free(msg.asBytes());
var response_buf: [4096]u8 = undefined;
var response_buf: [4096]u8 align(@alignOf(SaprusMessage)) = undefined;
_ = try sock.receive(&response_buf); // Ignore message that I sent.
const len = try sock.receive(&response_buf);
initial_conn_res = try SaprusMessage.fromBytes(response_buf[0..len], allocator);
std.debug.print("response bytes: {x}\n", .{response_buf[0..len]});
initial_conn_res = SaprusMessage.init(.connection, response_buf[0..len]);
// Complete handshake after awaiting response
try broadcastSaprusMessage(msg, randomPort(), allocator);
try broadcastSaprusMessage(msg, randomPort());
if (false) {
return initial_conn_res.?;

View File

@@ -63,6 +63,7 @@ pub fn main() !void {
error.WouldBlock => null,
else => return err,
};
return;
}
return clap.help(std.io.getStdErr().writer(), clap.Help, &params, .{});

View File

@@ -1,6 +1,3 @@
const base64Enc = std.base64.Base64Encoder.init(std.base64.standard_alphabet_chars, '=');
const base64Dec = std.base64.Base64Decoder.init(std.base64.standard_alphabet_chars, '=');
/// Type tag for Message union.
/// This is the first value in the actual packet sent over the network.
pub const PacketType = enum(u16) {
@@ -23,43 +20,25 @@ pub const ConnectionOptions = packed struct(u8) {
opt8: bool = false,
};
pub const Error = error{
pub const MessageTypeError = error{
NotImplementedSaprusType,
UnknownSaprusType,
};
pub const MessageParseError = MessageTypeError || error{
InvalidMessage,
};
// ZERO COPY STUFF
// &payload could be a void value that is treated as a pointer to a [*]u8
pub const ZeroCopyMessage = packed struct {
/// All Saprus messages
pub const Message = packed struct {
const Relay = packed struct {
dest: @Vector(4, u8),
payload: void,
pub fn getPayload(self: *align(@alignOf(ZeroCopyMessage)) Relay) []u8 {
// Cast the 'self' pointer (which points to the Relay header,
// located at the same memory as the parent's 'bytes' field)
// to a pointer to void, as required by @fieldParentPtr for a void field.
// Preserve the known alignment.
const self_as_void_ptr: *align(@alignOf(ZeroCopyMessage)) void = @ptrCast(self);
// Cast the resulting *void pointer to the parent type *ZeroCopyMessage.
// This cast performs the necessary alignment check.
const parent: *ZeroCopyMessage = @alignCast(@fieldParentPtr("bytes", self_as_void_ptr));
// The 'length' field in the parent ZeroCopyMessage contains
// the size of the header (Relay) + payload length.
const total_len = parent.length;
// Payload length = total_len - size of the Relay header
const payload_len = total_len - @sizeOf(Relay);
// The payload starts immediately after the fixed fields of the Relay struct.
// The address of the 'payload' field represents this starting point.
const payload_start_ptr: [*]u8 = @ptrCast(&self.payload);
// Return a slice from the payload start address with the calculated length.
return payload_start_ptr[0..payload_len];
pub fn getPayload(self: *align(1) Relay) []u8 {
const len: *u16 = @ptrFromInt(@intFromPtr(self) - @sizeOf(u16));
return @as([*]u8, @ptrCast(&self.payload))[0 .. len.* - @sizeOf(Relay)];
}
};
const Connection = packed struct {
@@ -76,14 +55,14 @@ pub const ZeroCopyMessage = packed struct {
return @as([*]u8, @ptrCast(&self.payload))[0 .. len.* - @sizeOf(Connection)];
}
fn nativeFromNetworkEndian(self: *align(1) Connection) Error!void {
fn nativeFromNetworkEndian(self: *align(1) Connection) void {
self.src_port = bigToNative(@TypeOf(self.src_port), self.src_port);
self.dest_port = bigToNative(@TypeOf(self.dest_port), self.dest_port);
self.seq_num = bigToNative(@TypeOf(self.seq_num), self.seq_num);
self.msg_id = bigToNative(@TypeOf(self.msg_id), self.msg_id);
}
fn networkFromNativeEndian(self: *align(1) Connection) Error!void {
fn networkFromNativeEndian(self: *align(1) Connection) void {
self.src_port = nativeToBig(@TypeOf(self.src_port), self.src_port);
self.dest_port = nativeToBig(@TypeOf(self.dest_port), self.dest_port);
self.seq_num = nativeToBig(@TypeOf(self.seq_num), self.seq_num);
@@ -92,70 +71,82 @@ pub const ZeroCopyMessage = packed struct {
};
const Self = @This();
const SelfBytes = []align(@alignOf(Self)) u8;
type: PacketType,
length: u16,
bytes: void = {},
pub fn init(allocator: Allocator, comptime @"type": PacketType, payload_len: u16) !*Self {
const header_size = @sizeOf(switch (@"type") {
.relay => Relay,
.connection => Connection,
else => return error.Bad,
});
const size = payload_len + @sizeOf(Self) + header_size;
const bytes = try allocator.alignedAlloc(u8, @alignOf(Self), size);
/// Takes a byte slice, and returns a Message struct backed by the slice.
/// This properly initializes the top level headers within the slice.
pub fn init(@"type": PacketType, bytes: []align(@alignOf(Self)) u8) *Self {
std.debug.assert(bytes.len >= @sizeOf(Self));
const res: *Self = @ptrCast(bytes.ptr);
res.type = @"type";
res.length = payload_len + header_size;
res.length = @intCast(bytes.len - @sizeOf(Self));
return res;
}
pub fn deinit(self: *Self, allocator: Allocator) void {
allocator.free(self.asBytes());
pub fn lengthForPayloadLength(comptime @"type": PacketType, payload_len: usize) MessageTypeError!u16 {
std.debug.assert(payload_len < std.math.maxInt(u16));
const header_size = @sizeOf(switch (@"type") {
.relay => Relay,
.connection => Connection,
.file_transfer => return MessageTypeError.NotImplementedSaprusType,
else => return MessageTypeError.UnknownSaprusType,
});
return @intCast(payload_len + @sizeOf(Self) + header_size);
}
fn getRelay(self: *Self) *align(@alignOf(Self)) Relay {
fn getRelay(self: *Self) *align(1) Relay {
return std.mem.bytesAsValue(Relay, &self.bytes);
}
fn getConnection(self: *Self) *align(@alignOf(Self)) Connection {
fn getConnection(self: *Self) *align(1) Connection {
return std.mem.bytesAsValue(Connection, &self.bytes);
}
pub fn getSaprusTypePayload(self: *Self) Error!(union(PacketType) {
relay: *align(@alignOf(Self)) Relay,
pub fn getSaprusTypePayload(self: *Self) MessageTypeError!(union(PacketType) {
relay: *align(1) Relay,
file_transfer: void,
connection: *align(@alignOf(Self)) Connection,
connection: *align(1) Connection,
}) {
return switch (self.type) {
.relay => .{ .relay = self.getRelay() },
.connection => .{ .connection = self.getConnection() },
.file_transfer => Error.NotImplementedSaprusType,
else => Error.UnknownSaprusType,
.file_transfer => MessageTypeError.NotImplementedSaprusType,
else => MessageTypeError.UnknownSaprusType,
};
}
pub fn nativeFromNetworkEndian(self: *Self) Error!void {
pub fn nativeFromNetworkEndian(self: *Self) MessageTypeError!void {
self.type = @enumFromInt(bigToNative(
@typeInfo(@TypeOf(self.type)).@"enum".tag_type,
@intFromEnum(self.type),
));
self.length = bigToNative(@TypeOf(self.length), self.length);
errdefer {
// If the payload specific headers fail, revert the top level header values
self.type = @enumFromInt(nativeToBig(
@typeInfo(@TypeOf(self.type)).@"enum".tag_type,
@intFromEnum(self.type),
));
self.length = nativeToBig(@TypeOf(self.length), self.length);
}
switch (try self.getSaprusTypePayload()) {
.relay => {},
.connection => |*con| try con.*.nativeFromNetworkEndian(),
.connection => |*con| con.*.nativeFromNetworkEndian(),
// We know other values are unreachable,
// because they would have returned an error from the switch condition.
else => unreachable,
}
}
pub fn networkFromNativeEndian(self: *Self) Error!void {
pub fn networkFromNativeEndian(self: *Self) MessageTypeError!void {
try switch (try self.getSaprusTypePayload()) {
.relay => {},
.connection => |*con| con.*.networkFromNativeEndian(),
.file_transfer => Error.NotImplementedSaprusType,
else => Error.UnknownSaprusType,
.file_transfer => MessageTypeError.NotImplementedSaprusType,
else => MessageTypeError.UnknownSaprusType,
};
self.type = @enumFromInt(nativeToBig(
@typeInfo(@TypeOf(self.type)).@"enum".tag_type,
@@ -164,23 +155,9 @@ pub const ZeroCopyMessage = packed struct {
self.length = nativeToBig(@TypeOf(self.length), self.length);
}
pub fn bytesAsValueUnchecked(bytes: []align(@alignOf(Self)) u8) *Self {
return std.mem.bytesAsValue(Self, bytes);
}
pub fn bytesAsValue(bytes: []align(@alignOf(Self)) u8) !*Self {
const res = bytesAsValueUnchecked(bytes);
return switch (res.type) {
.relay, .connection => if (bytes.len == res.length + @sizeOf(Self))
res
else
Error.InvalidMessage,
.file_transfer => Error.NotImplementedSaprusType,
else => Error.UnknownSaprusType,
};
}
pub fn asBytes(self: *Self) []align(@alignOf(Self)) u8 {
/// Deprecated.
/// If I need the bytes, I should just pass around the slice that is backing this to begin with.
pub fn asBytes(self: *Self) SelfBytes {
const size = @sizeOf(Self) + self.length;
return @as([*]align(@alignOf(Self)) u8, @ptrCast(self))[0..size];
}
@@ -190,180 +167,42 @@ test "testing variable length zero copy struct" {
const gpa = std.testing.allocator;
const payload = "Hello darkness my old friend";
// Create a view of the byte slice as a ZeroCopyMessage
const zcm: *ZeroCopyMessage = try .init(gpa, .relay, payload.len);
defer zcm.deinit(gpa);
std.debug.print("outer: {*}\n", .{zcm});
// Create a view of the byte slice as a Message
const msg: *Message = try .init(gpa, .relay, payload.len);
defer msg.deinit(gpa);
{
// Set the message values
{
// These are both set by the init call.
// zcm.type = .relay;
// zcm.length = payload_len;
// msg.type = .relay;
// msg.length = payload_len;
}
const relay = (try zcm.getSaprusTypePayload()).relay;
const relay = (try msg.getSaprusTypePayload()).relay;
relay.dest = .{ 1, 2, 3, 4 };
@memcpy(relay.getPayload(), payload);
}
{
const bytes = zcm.asBytes();
const bytes = msg.asBytes();
// Print the message as hex using the network byte order
try zcm.networkFromNativeEndian();
try msg.networkFromNativeEndian();
// We know the error from nativeFromNetworkEndian is unreachable because
// it would have returned an error from networkFromNativeEndian.
defer zcm.nativeFromNetworkEndian() catch unreachable;
defer msg.nativeFromNetworkEndian() catch unreachable;
std.debug.print("network bytes: {x}\n", .{bytes});
std.debug.print("bytes len: {d}\n", .{bytes.len});
}
if (false) {
// Illegal behavior
std.debug.print("{any}\n", .{(try zcm.getSaprusTypePayload()).connection});
std.debug.print("{any}\n", .{(try msg.getSaprusTypePayload()).connection});
}
try std.testing.expectEqualDeep(zcm, try ZeroCopyMessage.bytesAsValue(zcm.asBytes()));
try std.testing.expectEqualDeep(msg, try Message.bytesAsValue(msg.asBytes()));
}
/// All Saprus messages
pub const Message = union(PacketType) {
pub const Relay = struct {
pub const Header = packed struct {
dest: @Vector(4, u8),
};
header: Header,
payload: []const u8,
};
pub const Connection = struct {
pub const Header = packed struct {
src_port: u16, // random number > 1024
dest_port: u16, // random number > 1024
seq_num: u32 = 0,
msg_id: u32 = 0,
reserved: u8 = 0,
options: ConnectionOptions = .{},
};
header: Header,
payload: []const u8,
};
relay: Relay,
file_transfer: void, // unimplemented
connection: Connection,
/// Should be called for any Message that was declared using a function that you pass an allocator to.
pub fn deinit(self: Message, allocator: Allocator) void {
switch (self) {
.relay => |r| allocator.free(r.payload),
.connection => |c| allocator.free(c.payload),
else => unreachable,
}
}
fn toBytesAux(
header: anytype,
payload: []const u8,
buf: *std.ArrayList(u8),
allocator: Allocator,
) !void {
const Header = @TypeOf(header);
// Create a growable string to store the base64 bytes in.
// Doing this first so I can use the length of the encoded bytes for the length field.
var payload_list = std.ArrayList(u8).init(allocator);
defer payload_list.deinit();
const buf_w = payload_list.writer();
// Write the payload bytes as base64 to the growable string.
try base64Enc.encodeWriter(buf_w, payload);
// At this point, payload_list contains the base64 encoded payload.
// Add the payload length to the output buf.
try buf.*.appendSlice(
asBytes(&nativeToBig(u16, @intCast(payload_list.items.len + @bitSizeOf(Header) / 8))),
);
// Add the header bytes to the output buf.
var header_buf: [@sizeOf(Header)]u8 = undefined;
var header_buf_stream = std.io.fixedBufferStream(&header_buf);
try header_buf_stream.writer().writeStructEndian(header, .big);
// Add the exact number of bits in the header without padding.
try buf.*.appendSlice(header_buf[0 .. @bitSizeOf(Header) / 8]);
try buf.*.appendSlice(payload_list.items);
}
/// Caller is responsible for freeing the returned bytes.
pub fn toBytes(self: Message, allocator: Allocator) ![]u8 {
// Create a growable list of bytes to store the output in.
var buf = std.ArrayList(u8).init(allocator);
errdefer buf.deinit();
// Start with writing the message type, which is the first 16 bits of every Saprus message.
try buf.appendSlice(asBytes(&nativeToBig(u16, @intFromEnum(self))));
// Write the proper header and payload for the given packet type.
switch (self) {
.relay => |r| try toBytesAux(r.header, r.payload, &buf, allocator),
.connection => |c| try toBytesAux(c.header, c.payload, &buf, allocator),
.file_transfer => return Error.NotImplementedSaprusType,
}
// Collect the growable list as a slice and return it.
return buf.toOwnedSlice();
}
fn fromBytesAux(
comptime packet: PacketType,
len: u16,
r: std.io.FixedBufferStream([]const u8).Reader,
allocator: Allocator,
) !Message {
const Header = @field(@FieldType(Message, @tagName(packet)), "Header");
// Read the header for the current message type.
var header_bytes: [@sizeOf(Header)]u8 = undefined;
_ = try r.read(header_bytes[0 .. @bitSizeOf(Header) / 8]);
var header_stream = std.io.fixedBufferStream(&header_bytes);
const header = try header_stream.reader().readStructEndian(Header, .big);
// Read the base64 bytes into a list to be able to call the decoder on it.
const payload_buf = try allocator.alloc(u8, len - @bitSizeOf(Header) / 8);
defer allocator.free(payload_buf);
_ = try r.readAll(payload_buf);
// Create a buffer to store the payload in, and decode the base64 bytes into the payload field.
const payload = try allocator.alloc(u8, try base64Dec.calcSizeForSlice(payload_buf));
try base64Dec.decode(payload, payload_buf);
// Return the type of Message specified by the `packet` argument.
return @unionInit(Message, @tagName(packet), .{
.header = header,
.payload = payload,
});
}
/// Caller is responsible for calling .deinit on the returned value.
pub fn fromBytes(bytes: []const u8, allocator: Allocator) !Message {
var s = std.io.fixedBufferStream(bytes);
const r = s.reader();
// Read packet type
const packet_type = @as(PacketType, @enumFromInt(try r.readInt(u16, .big)));
// Read the length of the header + base64 encoded payload.
const len = try r.readInt(u16, .big);
switch (packet_type) {
.relay => return fromBytesAux(.relay, len, r, allocator),
.connection => return fromBytesAux(.connection, len, r, allocator),
.file_transfer => return Error.NotImplementedSaprusType,
else => return Error.UnknownSaprusType,
}
}
};
const std = @import("std");
const Allocator = std.mem.Allocator;