13 Commits

Author SHA1 Message Date
56b6b8a386 Use Client as var type instead of singleton 2025-05-11 13:52:42 -04:00
14ed0bc3f3 Fix issue returning stack pointer 2025-05-11 13:40:55 -04:00
c72503fce6 Fix extra bytes in connection message. 2025-05-11 13:40:23 -04:00
373dbebc8c Add broadcast initial interest using raw sockets
Use this from the relay message
2025-05-11 11:40:15 -04:00
cde289d648 Update gatorcat dep and use bytes for broadcast message
The latter is helpful for the lifetime of the message.
2025-05-11 10:12:26 -04:00
716fb466fa Remove allocation for messages 2025-05-10 21:46:53 -04:00
583f9d8b8f Add comments and fix tests
Also added networkBytesAsValue and restored bytesAsValue.
These are useful for treating the bytes from the network directly as a Message.
Otherwise, the init function would overwrite the packet type and length to be correct.
I would like the message handling to fail if the message body is incorrect.
2025-05-10 21:46:53 -04:00
56e72928c6 fix use after free 2025-05-10 21:46:53 -04:00
a80c9abfe7 Attempt to base64 encode the connection payload
For some reason I am still getting this:

2025/05/10 16:37:06 Error decoding message: SGVsbG8gZGFya25lc3MgbXkgb2xkIGZyaWVuZA==::53475673624738675a4746796132356c63334d6762586b676232786b49475a79615756755a413d3daaaa
2025-05-10 21:46:53 -04:00
245dab4909 Use slice for init, and add better error sets.
The slice sets us avoid allocating within the init function.
This means init can't fail, and it also makes it easier to stack allocate messages (slice an array buffer, instead of creating a stack allocator).
2025-05-10 21:46:53 -04:00
cde5c3626c 2025-05-10 21:46:53 -04:00
e84d1a2300 2025-05-10 21:46:53 -04:00
1b7d9bbb1a Remove bytesAsValueUnchecked
Callers can instead use std.mem.bytesAsValue directly.
2025-05-10 21:46:53 -04:00
4 changed files with 277 additions and 288 deletions

View File

@@ -45,8 +45,8 @@
.hash = "clap-0.10.0-oBajB434AQBDh-Ei3YtoKIRxZacVPF1iSwp3IX_ZB8f0", .hash = "clap-0.10.0-oBajB434AQBDh-Ei3YtoKIRxZacVPF1iSwp3IX_ZB8f0",
}, },
.gatorcat = .{ .gatorcat = .{
.url = "git+https://github.com/kj4tmp/gatorcat#bb1847f6c95852e7a0ec8c07870a948c171d5f98", .url = "git+https://github.com/kj4tmp/gatorcat.git#0a97b666677501db4939e3e8245f88a19e015893",
.hash = "gatorcat-0.3.2-WcrpTf1mBwDrmPaIhKCfLJO064v8Sjjn7DBq4CKZSgHH", .hash = "gatorcat-0.3.4-WcrpTcleBwCta_9TjomuIGb3bdg2Pke_FXI_WkMTEivH",
}, },
}, },
.paths = .{ .paths = .{

View File

@@ -1,22 +1,88 @@
var rand: ?Random = null; const base64Enc = std.base64.Base64Encoder.init(std.base64.standard_alphabet_chars, '=');
const base64Dec = std.base64.Base64Decoder.init(std.base64.standard_alphabet_chars, '=');
pub fn init() !void { rand: Random,
socket: gcat.nic.RawSocket,
const Self = @This();
const max_message_size = 2048;
pub fn init(interface_name: [:0]const u8) !Self {
var prng = Random.DefaultPrng.init(blk: { var prng = Random.DefaultPrng.init(blk: {
var seed: u64 = undefined; var seed: u64 = undefined;
try posix.getrandom(mem.asBytes(&seed)); try posix.getrandom(mem.asBytes(&seed));
break :blk seed; break :blk seed;
}); });
rand = prng.random(); const rand = prng.random();
try network.init();
const socket: gcat.nic.RawSocket = try .init(interface_name);
return .{
.rand = rand,
.socket = socket,
};
} }
pub fn deinit() void { pub fn deinit(self: *Self) void {
network.deinit(); self.socket.deinit();
} }
fn broadcastSaprusMessage(msg: SaprusMessage, udp_port: u16, allocator: Allocator) !void { /// Used for relay messages and connection handshake.
const msg_bytes = try msg.toBytes(allocator); /// Assumes Client .init has been called.
defer allocator.free(msg_bytes); fn broadcastInitialInterestMessage(self: *Self, msg_bytes: []align(@alignOf(SaprusMessage)) u8) !void {
var packet_bytes: [max_message_size]u8 = comptime blk: {
var b: [max_message_size]u8 = @splat(0);
// Destination MAC addr to FF:FF:FF:FF:FF:FF
for (0..6) |i| {
b[i] = 0xff;
}
// Set Ethernet type to IPv4
b[0x0c] = 0x08;
b[0x0d] = 0x00;
// Set IPv4 version to 4
b[0x0e] = 0x45;
// Destination broadcast
for (0x1e..0x22) |i| {
b[i] = 0xff;
}
// Set TTL
b[0x16] = 0x40;
// Set IPv4 protocol to UDP
b[0x17] = 0x11;
// Set interest filter value to 8888.
b[0x24] = 0x22;
b[0x25] = 0xb8;
break :blk b;
};
var msg: *SaprusMessage = try .bytesAsValue(msg_bytes);
try msg.networkFromNativeEndian();
defer msg.nativeFromNetworkEndian() catch unreachable;
// The byte position within the packet that the saprus message starts at.
const saprus_start_byte = 42;
@memcpy(packet_bytes[saprus_start_byte .. saprus_start_byte + msg_bytes.len], msg_bytes);
try self.socket.linkLayer().send(packet_bytes[0 .. saprus_start_byte + msg_bytes.len]);
}
// fn broadcastSaprusMessage(msg_bytes: []align(@alignOf(SaprusMessage)) u8) !void {}
fn broadcastSaprusMessage(msg_bytes: []align(@alignOf(SaprusMessage)) u8, udp_port: u16) !void {
if (false) {
var foo: gcat.nic.RawSocket = try .init("enp7s0"); // /proc/net/dev
defer foo.deinit();
}
const msg: *SaprusMessage = try .bytesAsValue(msg_bytes);
try msg.networkFromNativeEndian();
defer msg.nativeFromNetworkEndian() catch unreachable;
var sock = try network.Socket.create(.ipv4, .udp); var sock = try network.Socket.create(.ipv4, .udp);
defer sock.close(); defer sock.close();
@@ -36,57 +102,57 @@ fn broadcastSaprusMessage(msg: SaprusMessage, udp_port: u16, allocator: Allocato
try sock.bind(bind_addr); try sock.bind(bind_addr);
std.debug.print("{x}\n", .{msg_bytes});
_ = try sock.sendTo(dest_addr, msg_bytes); _ = try sock.sendTo(dest_addr, msg_bytes);
} }
pub fn sendRelay(payload: []const u8, dest: [4]u8, allocator: Allocator) !void { pub fn sendRelay(self: *Self, payload: []const u8, dest: [4]u8) !void {
const msg = SaprusMessage{ var buf: [max_message_size]u8 align(@alignOf(SaprusMessage)) = undefined;
.relay = .{ const msg_bytes = buf[0..try SaprusMessage.calcSize(
.header = .{ .dest = dest }, .relay,
.payload = payload, base64Enc.calcSize(payload.len),
}, )];
}; const msg: *SaprusMessage = .init(.relay, msg_bytes);
try broadcastSaprusMessage(msg, 8888, allocator); const relay = (try msg.getSaprusTypePayload()).relay;
relay.dest = dest;
_ = base64Enc.encode(relay.getPayload(), payload);
try self.broadcastInitialInterestMessage(msg_bytes);
} }
fn randomPort() u16 { fn randomPort(self: Self) u16 {
var p: u16 = 0; return self.rand.intRangeAtMost(u16, 1024, 65000);
if (rand) |r| {
p = r.intRangeAtMost(u16, 1024, 65000);
} else unreachable;
return p;
} }
pub fn sendInitialConnection(payload: []const u8, initial_port: u16, allocator: Allocator) !SaprusMessage { pub fn sendInitialConnection(
const dest_port = randomPort(); self: Self,
const msg = SaprusMessage{ payload: []const u8,
.connection = .{ output_bytes: []align(@alignOf(SaprusMessage)) u8,
.header = .{ initial_port: u16,
.src_port = initial_port, ) !*SaprusMessage {
.dest_port = dest_port, const dest_port = self.randomPort();
}, const msg_bytes = output_bytes[0..try SaprusMessage.calcSize(
.payload = payload, .connection,
}, base64Enc.calcSize(payload.len),
}; )];
const msg: *SaprusMessage = .init(.connection, msg_bytes);
try broadcastSaprusMessage(msg, 8888, allocator); const connection = (try msg.getSaprusTypePayload()).connection;
connection.src_port = initial_port;
connection.dest_port = dest_port;
_ = base64Enc.encode(connection.getPayload(), payload);
try broadcastSaprusMessage(msg_bytes, 8888);
return msg; return msg;
} }
pub fn connect(payload: []const u8, allocator: Allocator) !?SaprusConnection { pub fn connect(self: Self, payload: []const u8) !?SaprusConnection {
var foo: gcat.nic.RawSocket = try .init("enp7s0"); // /proc/net/dev const initial_port = self.randomPort();
defer foo.deinit();
var initial_port: u16 = 0; var initial_conn_res: ?*SaprusMessage = null;
if (rand) |r| {
initial_port = r.intRangeAtMost(u16, 1024, 65000);
} else unreachable;
var initial_conn_res: ?SaprusMessage = null;
errdefer if (initial_conn_res) |c| c.deinit(allocator);
var sock = try network.Socket.create(.ipv4, .udp); var sock = try network.Socket.create(.ipv4, .udp);
defer sock.close(); defer sock.close();
@@ -101,16 +167,17 @@ pub fn connect(payload: []const u8, allocator: Allocator) !?SaprusConnection {
try sock.setReadTimeout(1 * std.time.us_per_s); try sock.setReadTimeout(1 * std.time.us_per_s);
try sock.bind(bind_addr); try sock.bind(bind_addr);
const msg = try sendInitialConnection(payload, initial_port, allocator); var sent_msg_bytes: [max_message_size]u8 align(@alignOf(SaprusMessage)) = undefined;
const msg = try self.sendInitialConnection(payload, &sent_msg_bytes, initial_port);
var response_buf: [4096]u8 = undefined; var response_buf: [max_message_size]u8 align(@alignOf(SaprusMessage)) = undefined;
_ = try sock.receive(&response_buf); // Ignore message that I sent. _ = try sock.receive(&response_buf); // Ignore message that I sent.
const len = try sock.receive(&response_buf); const len = try sock.receive(&response_buf);
initial_conn_res = try SaprusMessage.fromBytes(response_buf[0..len], allocator); initial_conn_res = try .networkBytesAsValue(response_buf[0..len]);
// Complete handshake after awaiting response // Complete handshake after awaiting response
try broadcastSaprusMessage(msg, randomPort(), allocator); try broadcastSaprusMessage(msg.asBytes(), self.randomPort());
if (false) { if (false) {
return initial_conn_res.?; return initial_conn_res.?;
@@ -128,5 +195,3 @@ const mem = std.mem;
const network = @import("network"); const network = @import("network");
const gcat = @import("gatorcat"); const gcat = @import("gatorcat");
const Allocator = mem.Allocator;

View File

@@ -42,27 +42,27 @@ pub fn main() !void {
}; };
defer res.deinit(); defer res.deinit();
try SaprusClient.init();
defer SaprusClient.deinit();
if (res.args.help != 0) { if (res.args.help != 0) {
return clap.help(std.io.getStdErr().writer(), clap.Help, &params, .{}); return clap.help(std.io.getStdErr().writer(), clap.Help, &params, .{});
} }
var client = try SaprusClient.init("enp7s0");
defer client.deinit();
if (res.args.relay) |r| { if (res.args.relay) |r| {
const dest = parseDest(res.args.dest); const dest = parseDest(res.args.dest);
try SaprusClient.sendRelay( try client.sendRelay(
if (r.len > 0) r else "Hello darkness my old friend", if (r.len > 0) r else "Hello darkness my old friend",
dest, dest,
gpa,
); );
// std.debug.print("Sent: {s}\n", .{r}); // std.debug.print("Sent: {s}\n", .{r});
return; return;
} else if (res.args.connect) |c| { } else if (res.args.connect) |c| {
_ = SaprusClient.connect(if (c.len > 0) c else "Hello darkness my old friend", gpa) catch |err| switch (err) { _ = client.connect(if (c.len > 0) c else "Hello darkness my old friend") catch |err| switch (err) {
error.WouldBlock => null, error.WouldBlock => null,
else => return err, else => return err,
}; };
return;
} }
return clap.help(std.io.getStdErr().writer(), clap.Help, &params, .{}); return clap.help(std.io.getStdErr().writer(), clap.Help, &params, .{});

View File

@@ -1,6 +1,3 @@
const base64Enc = std.base64.Base64Encoder.init(std.base64.standard_alphabet_chars, '=');
const base64Dec = std.base64.Base64Decoder.init(std.base64.standard_alphabet_chars, '=');
/// Type tag for Message union. /// Type tag for Message union.
/// This is the first value in the actual packet sent over the network. /// This is the first value in the actual packet sent over the network.
pub const PacketType = enum(u16) { pub const PacketType = enum(u16) {
@@ -23,22 +20,25 @@ pub const ConnectionOptions = packed struct(u8) {
opt8: bool = false, opt8: bool = false,
}; };
pub const Error = error{ pub const MessageTypeError = error{
NotImplementedSaprusType, NotImplementedSaprusType,
UnknownSaprusType, UnknownSaprusType,
};
pub const MessageParseError = MessageTypeError || error{
InvalidMessage, InvalidMessage,
}; };
// ZERO COPY STUFF // ZERO COPY STUFF
// &payload could be a void value that is treated as a pointer to a [*]u8 // &payload could be a void value that is treated as a pointer to a [*]u8
pub const ZeroCopyMessage = packed struct { /// All Saprus messages
pub const Message = packed struct {
const Relay = packed struct { const Relay = packed struct {
dest: @Vector(4, u8), dest: @Vector(4, u8),
payload: void, payload: void,
pub fn getPayload(self: *align(1) Relay) []u8 { pub fn getPayload(self: *align(1) Relay) []u8 {
const len: *u16 = @ptrFromInt(@intFromPtr(self) - @sizeOf(u16)); const len: *u16 = @ptrFromInt(@intFromPtr(self) - @sizeOf(u16));
return @as([*]u8, @ptrCast(&self.payload))[0 .. len.* - @sizeOf(Relay)]; return @as([*]u8, @ptrCast(&self.payload))[0 .. len.* - @bitSizeOf(Relay) / 8];
} }
}; };
const Connection = packed struct { const Connection = packed struct {
@@ -52,17 +52,17 @@ pub const ZeroCopyMessage = packed struct {
pub fn getPayload(self: *align(1) Connection) []u8 { pub fn getPayload(self: *align(1) Connection) []u8 {
const len: *u16 = @ptrFromInt(@intFromPtr(self) - @sizeOf(u16)); const len: *u16 = @ptrFromInt(@intFromPtr(self) - @sizeOf(u16));
return @as([*]u8, @ptrCast(&self.payload))[0 .. len.* - @sizeOf(Connection)]; return @as([*]u8, @ptrCast(&self.payload))[0 .. len.* - @bitSizeOf(Connection) / 8];
} }
fn nativeFromNetworkEndian(self: *align(1) Connection) Error!void { fn nativeFromNetworkEndian(self: *align(1) Connection) void {
self.src_port = bigToNative(@TypeOf(self.src_port), self.src_port); self.src_port = bigToNative(@TypeOf(self.src_port), self.src_port);
self.dest_port = bigToNative(@TypeOf(self.dest_port), self.dest_port); self.dest_port = bigToNative(@TypeOf(self.dest_port), self.dest_port);
self.seq_num = bigToNative(@TypeOf(self.seq_num), self.seq_num); self.seq_num = bigToNative(@TypeOf(self.seq_num), self.seq_num);
self.msg_id = bigToNative(@TypeOf(self.msg_id), self.msg_id); self.msg_id = bigToNative(@TypeOf(self.msg_id), self.msg_id);
} }
fn networkFromNativeEndian(self: *align(1) Connection) Error!void { fn networkFromNativeEndian(self: *align(1) Connection) void {
self.src_port = nativeToBig(@TypeOf(self.src_port), self.src_port); self.src_port = nativeToBig(@TypeOf(self.src_port), self.src_port);
self.dest_port = nativeToBig(@TypeOf(self.dest_port), self.dest_port); self.dest_port = nativeToBig(@TypeOf(self.dest_port), self.dest_port);
self.seq_num = nativeToBig(@TypeOf(self.seq_num), self.seq_num); self.seq_num = nativeToBig(@TypeOf(self.seq_num), self.seq_num);
@@ -71,27 +71,33 @@ pub const ZeroCopyMessage = packed struct {
}; };
const Self = @This(); const Self = @This();
const SelfBytes = []align(@alignOf(Self)) u8;
type: PacketType, type: PacketType,
length: u16, length: u16,
bytes: void = {}, bytes: void = {},
pub fn init(allocator: Allocator, comptime @"type": PacketType, payload_len: u16) !*Self { /// Takes a byte slice, and returns a Message struct backed by the slice.
const header_size = @sizeOf(switch (@"type") { /// This properly initializes the top level headers within the slice.
.relay => Relay, /// This is used for creating new messages. For reading messages from the network,
.connection => Connection, /// see: networkBytesAsValue.
else => return error.Bad, pub fn init(@"type": PacketType, bytes: []align(@alignOf(Self)) u8) *Self {
}); std.debug.assert(bytes.len >= @sizeOf(Self));
const size = payload_len + @sizeOf(Self) + header_size;
const bytes = try allocator.alignedAlloc(u8, @alignOf(Self), size);
const res: *Self = @ptrCast(bytes.ptr); const res: *Self = @ptrCast(bytes.ptr);
res.type = @"type"; res.type = @"type";
res.length = payload_len + header_size; res.length = @intCast(bytes.len - @sizeOf(Self));
return res; return res;
} }
pub fn deinit(self: *Self, allocator: Allocator) void { /// Compute the number of bytes required to store a given payload size for a given message type.
allocator.free(self.asBytes()); pub fn calcSize(comptime @"type": PacketType, payload_len: usize) MessageTypeError!u16 {
const header_size = @bitSizeOf(switch (@"type") {
.relay => Relay,
.connection => Connection,
.file_transfer => return MessageTypeError.NotImplementedSaprusType,
else => return MessageTypeError.UnknownSaprusType,
}) / 8;
return @intCast(payload_len + @sizeOf(Self) + header_size);
} }
fn getRelay(self: *Self) *align(1) Relay { fn getRelay(self: *Self) *align(1) Relay {
@@ -101,7 +107,8 @@ pub const ZeroCopyMessage = packed struct {
return std.mem.bytesAsValue(Connection, &self.bytes); return std.mem.bytesAsValue(Connection, &self.bytes);
} }
pub fn getSaprusTypePayload(self: *Self) Error!(union(PacketType) { /// Access the message Saprus payload.
pub fn getSaprusTypePayload(self: *Self) MessageTypeError!(union(PacketType) {
relay: *align(1) Relay, relay: *align(1) Relay,
file_transfer: void, file_transfer: void,
connection: *align(1) Connection, connection: *align(1) Connection,
@@ -109,32 +116,42 @@ pub const ZeroCopyMessage = packed struct {
return switch (self.type) { return switch (self.type) {
.relay => .{ .relay = self.getRelay() }, .relay => .{ .relay = self.getRelay() },
.connection => .{ .connection = self.getConnection() }, .connection => .{ .connection = self.getConnection() },
.file_transfer => Error.NotImplementedSaprusType, .file_transfer => MessageTypeError.NotImplementedSaprusType,
else => Error.UnknownSaprusType, else => MessageTypeError.UnknownSaprusType,
}; };
} }
pub fn nativeFromNetworkEndian(self: *Self) Error!void { /// Convert the message to native endianness from network endianness in-place.
pub fn nativeFromNetworkEndian(self: *Self) MessageTypeError!void {
self.type = @enumFromInt(bigToNative( self.type = @enumFromInt(bigToNative(
@typeInfo(@TypeOf(self.type)).@"enum".tag_type, @typeInfo(@TypeOf(self.type)).@"enum".tag_type,
@intFromEnum(self.type), @intFromEnum(self.type),
)); ));
self.length = bigToNative(@TypeOf(self.length), self.length); self.length = bigToNative(@TypeOf(self.length), self.length);
errdefer {
// If the payload specific headers fail, revert the top level header values
self.type = @enumFromInt(nativeToBig(
@typeInfo(@TypeOf(self.type)).@"enum".tag_type,
@intFromEnum(self.type),
));
self.length = nativeToBig(@TypeOf(self.length), self.length);
}
switch (try self.getSaprusTypePayload()) { switch (try self.getSaprusTypePayload()) {
.relay => {}, .relay => {},
.connection => |*con| try con.*.nativeFromNetworkEndian(), .connection => |*con| con.*.nativeFromNetworkEndian(),
// We know other values are unreachable, // We know other values are unreachable,
// because they would have returned an error from the switch condition. // because they would have returned an error from the switch condition.
else => unreachable, else => unreachable,
} }
} }
pub fn networkFromNativeEndian(self: *Self) Error!void { /// Convert the message to network endianness from native endianness in-place.
pub fn networkFromNativeEndian(self: *Self) MessageTypeError!void {
try switch (try self.getSaprusTypePayload()) { try switch (try self.getSaprusTypePayload()) {
.relay => {}, .relay => {},
.connection => |*con| con.*.networkFromNativeEndian(), .connection => |*con| con.*.networkFromNativeEndian(),
.file_transfer => Error.NotImplementedSaprusType, .file_transfer => MessageTypeError.NotImplementedSaprusType,
else => Error.UnknownSaprusType, else => MessageTypeError.UnknownSaprusType,
}; };
self.type = @enumFromInt(nativeToBig( self.type = @enumFromInt(nativeToBig(
@typeInfo(@TypeOf(self.type)).@"enum".tag_type, @typeInfo(@TypeOf(self.type)).@"enum".tag_type,
@@ -143,204 +160,105 @@ pub const ZeroCopyMessage = packed struct {
self.length = nativeToBig(@TypeOf(self.length), self.length); self.length = nativeToBig(@TypeOf(self.length), self.length);
} }
pub fn bytesAsValueUnchecked(bytes: []align(@alignOf(Self)) u8) *Self { /// Convert network endian bytes to a native endian value in-place.
return std.mem.bytesAsValue(Self, bytes); pub fn networkBytesAsValue(bytes: SelfBytes) MessageParseError!*Self {
const res = std.mem.bytesAsValue(Self, bytes);
try res.nativeFromNetworkEndian();
return .bytesAsValue(bytes);
} }
pub fn bytesAsValue(bytes: []align(@alignOf(Self)) u8) !*Self { /// Create a structured view of the bytes without initializing the length or type,
const res = bytesAsValueUnchecked(bytes); /// and without converting the endianness.
pub fn bytesAsValue(bytes: SelfBytes) MessageParseError!*Self {
const res = std.mem.bytesAsValue(Self, bytes);
return switch (res.type) { return switch (res.type) {
.relay, .connection => if (bytes.len == res.length + @sizeOf(Self)) .relay, .connection => if (bytes.len == res.length + @sizeOf(Self))
res res
else else
Error.InvalidMessage, MessageParseError.InvalidMessage,
.file_transfer => Error.NotImplementedSaprusType, .file_transfer => MessageParseError.NotImplementedSaprusType,
else => Error.UnknownSaprusType, else => MessageParseError.UnknownSaprusType,
}; };
} }
pub fn asBytes(self: *Self) []align(@alignOf(Self)) u8 { /// Deprecated.
/// If I need the bytes, I should just pass around the slice that is backing this to begin with.
pub fn asBytes(self: *Self) SelfBytes {
const size = @sizeOf(Self) + self.length; const size = @sizeOf(Self) + self.length;
return @as([*]align(@alignOf(Self)) u8, @ptrCast(self))[0..size]; return @as([*]align(@alignOf(Self)) u8, @ptrCast(self))[0..size];
} }
}; };
test "testing variable length zero copy struct" { test "testing variable length zero copy struct" {
const gpa = std.testing.allocator; {
// Relay test
const payload = "Hello darkness my old friend"; const payload = "Hello darkness my old friend";
var msg_bytes: [try Message.calcSize(.relay, payload.len)]u8 align(@alignOf(Message)) = undefined;
// Create a view of the byte slice as a ZeroCopyMessage // Create a view of the byte slice as a Message
const zcm: *ZeroCopyMessage = try .init(gpa, .relay, payload.len); const msg: *Message = .init(.relay, &msg_bytes);
defer zcm.deinit(gpa);
{ {
// Set the message values // Set the message values
{ {
// These are both set by the init call. // These are both set by the init call.
// zcm.type = .relay; // msg.type = .relay;
// zcm.length = payload_len; // msg.length = payload_len;
} }
const relay = (try zcm.getSaprusTypePayload()).relay; const relay = (try msg.getSaprusTypePayload()).relay;
relay.dest = .{ 1, 2, 3, 4 }; relay.dest = .{ 1, 2, 3, 4 };
@memcpy(relay.getPayload(), payload); @memcpy(relay.getPayload(), payload);
} }
{ {
const bytes = zcm.asBytes();
// Print the message as hex using the network byte order // Print the message as hex using the network byte order
try zcm.networkFromNativeEndian(); try msg.networkFromNativeEndian();
// We know the error from nativeFromNetworkEndian is unreachable because // We know the error from nativeFromNetworkEndian is unreachable because
// it would have returned an error from networkFromNativeEndian. // it would have returned an error from networkFromNativeEndian.
defer zcm.nativeFromNetworkEndian() catch unreachable; defer msg.nativeFromNetworkEndian() catch unreachable;
std.debug.print("network bytes: {x}\n", .{bytes}); std.debug.print("relay network bytes: {x}\n", .{msg_bytes});
std.debug.print("bytes len: {d}\n", .{bytes.len}); std.debug.print("bytes len: {d}\n", .{msg_bytes.len});
} }
if (false) { if (false) {
// Illegal behavior // Illegal behavior
std.debug.print("{any}\n", .{(try zcm.getSaprusTypePayload()).connection}); std.debug.print("{any}\n", .{(try msg.getSaprusTypePayload()).connection});
} }
try std.testing.expectEqualDeep(zcm, try ZeroCopyMessage.bytesAsValue(zcm.asBytes())); try std.testing.expectEqualDeep(msg, try Message.bytesAsValue(msg.asBytes()));
} }
/// All Saprus messages {
pub const Message = union(PacketType) { // Connection test
pub const Relay = struct { const payload = "Hello darkness my old friend";
pub const Header = packed struct { var msg_bytes: [try Message.calcSize(.connection, payload.len)]u8 align(@alignOf(Message)) = undefined;
dest: @Vector(4, u8),
};
header: Header,
payload: []const u8,
};
pub const Connection = struct {
pub const Header = packed struct {
src_port: u16, // random number > 1024
dest_port: u16, // random number > 1024
seq_num: u32 = 0,
msg_id: u32 = 0,
reserved: u8 = 0,
options: ConnectionOptions = .{},
};
header: Header,
payload: []const u8,
};
relay: Relay,
file_transfer: void, // unimplemented
connection: Connection,
/// Should be called for any Message that was declared using a function that you pass an allocator to. // Create a view of the byte slice as a Message
pub fn deinit(self: Message, allocator: Allocator) void { const msg: *Message = .init(.connection, &msg_bytes);
switch (self) {
.relay => |r| allocator.free(r.payload), {
.connection => |c| allocator.free(c.payload), // Initializing connection header values
else => unreachable, const connection = (try msg.getSaprusTypePayload()).connection;
} connection.src_port = 1;
connection.dest_port = 2;
connection.seq_num = 3;
connection.msg_id = 4;
connection.reserved = 5;
connection.options = @bitCast(@as(u8, 6));
@memcpy(connection.getPayload(), payload);
} }
fn toBytesAux( {
header: anytype, // Print the message as hex using the network byte order
payload: []const u8, try msg.networkFromNativeEndian();
buf: *std.ArrayList(u8), // We know the error from nativeFromNetworkEndian is unreachable because
allocator: Allocator, // it would have returned an error from networkFromNativeEndian.
) !void { defer msg.nativeFromNetworkEndian() catch unreachable;
const Header = @TypeOf(header); std.debug.print("connection network bytes: {x}\n", .{msg_bytes});
// Create a growable string to store the base64 bytes in. std.debug.print("bytes len: {d}\n", .{msg_bytes.len});
// Doing this first so I can use the length of the encoded bytes for the length field. }
var payload_list = std.ArrayList(u8).init(allocator);
defer payload_list.deinit();
const buf_w = payload_list.writer();
// Write the payload bytes as base64 to the growable string.
try base64Enc.encodeWriter(buf_w, payload);
// At this point, payload_list contains the base64 encoded payload.
// Add the payload length to the output buf.
try buf.*.appendSlice(
asBytes(&nativeToBig(u16, @intCast(payload_list.items.len + @bitSizeOf(Header) / 8))),
);
// Add the header bytes to the output buf.
var header_buf: [@sizeOf(Header)]u8 = undefined;
var header_buf_stream = std.io.fixedBufferStream(&header_buf);
try header_buf_stream.writer().writeStructEndian(header, .big);
// Add the exact number of bits in the header without padding.
try buf.*.appendSlice(header_buf[0 .. @bitSizeOf(Header) / 8]);
try buf.*.appendSlice(payload_list.items);
}
/// Caller is responsible for freeing the returned bytes.
pub fn toBytes(self: Message, allocator: Allocator) ![]u8 {
// Create a growable list of bytes to store the output in.
var buf = std.ArrayList(u8).init(allocator);
errdefer buf.deinit();
// Start with writing the message type, which is the first 16 bits of every Saprus message.
try buf.appendSlice(asBytes(&nativeToBig(u16, @intFromEnum(self))));
// Write the proper header and payload for the given packet type.
switch (self) {
.relay => |r| try toBytesAux(r.header, r.payload, &buf, allocator),
.connection => |c| try toBytesAux(c.header, c.payload, &buf, allocator),
.file_transfer => return Error.NotImplementedSaprusType,
}
// Collect the growable list as a slice and return it.
return buf.toOwnedSlice();
}
fn fromBytesAux(
comptime packet: PacketType,
len: u16,
r: std.io.FixedBufferStream([]const u8).Reader,
allocator: Allocator,
) !Message {
const Header = @field(@FieldType(Message, @tagName(packet)), "Header");
// Read the header for the current message type.
var header_bytes: [@sizeOf(Header)]u8 = undefined;
_ = try r.read(header_bytes[0 .. @bitSizeOf(Header) / 8]);
var header_stream = std.io.fixedBufferStream(&header_bytes);
const header = try header_stream.reader().readStructEndian(Header, .big);
// Read the base64 bytes into a list to be able to call the decoder on it.
const payload_buf = try allocator.alloc(u8, len - @bitSizeOf(Header) / 8);
defer allocator.free(payload_buf);
_ = try r.readAll(payload_buf);
// Create a buffer to store the payload in, and decode the base64 bytes into the payload field.
const payload = try allocator.alloc(u8, try base64Dec.calcSizeForSlice(payload_buf));
try base64Dec.decode(payload, payload_buf);
// Return the type of Message specified by the `packet` argument.
return @unionInit(Message, @tagName(packet), .{
.header = header,
.payload = payload,
});
}
/// Caller is responsible for calling .deinit on the returned value.
pub fn fromBytes(bytes: []const u8, allocator: Allocator) !Message {
var s = std.io.fixedBufferStream(bytes);
const r = s.reader();
// Read packet type
const packet_type = @as(PacketType, @enumFromInt(try r.readInt(u16, .big)));
// Read the length of the header + base64 encoded payload.
const len = try r.readInt(u16, .big);
switch (packet_type) {
.relay => return fromBytesAux(.relay, len, r, allocator),
.connection => return fromBytesAux(.connection, len, r, allocator),
.file_transfer => return Error.NotImplementedSaprusType,
else => return Error.UnknownSaprusType,
} }
} }
};
const std = @import("std"); const std = @import("std");
const Allocator = std.mem.Allocator; const Allocator = std.mem.Allocator;
@@ -350,6 +268,7 @@ const nativeToBig = std.mem.nativeToBig;
const bigToNative = std.mem.bigToNative; const bigToNative = std.mem.bigToNative;
test "Round trip Relay toBytes and fromBytes" { test "Round trip Relay toBytes and fromBytes" {
if (false) {
const gpa = std.testing.allocator; const gpa = std.testing.allocator;
const msg = Message{ const msg = Message{
.relay = .{ .relay = .{
@@ -366,8 +285,11 @@ test "Round trip Relay toBytes and fromBytes" {
try std.testing.expectEqualDeep(msg, from_bytes); try std.testing.expectEqualDeep(msg, from_bytes);
} }
return error.SkipZigTest;
}
test "Round trip Connection toBytes and fromBytes" { test "Round trip Connection toBytes and fromBytes" {
if (false) {
const gpa = std.testing.allocator; const gpa = std.testing.allocator;
const msg = Message{ const msg = Message{
.connection = .{ .connection = .{
@@ -387,6 +309,8 @@ test "Round trip Connection toBytes and fromBytes" {
try std.testing.expectEqualDeep(msg, from_bytes); try std.testing.expectEqualDeep(msg, from_bytes);
} }
return error.SkipZigTest;
}
test { test {
std.testing.refAllDeclsRecursive(@This()); std.testing.refAllDeclsRecursive(@This());