Fix extra bytes in connection message.

This commit is contained in:
2025-05-11 13:39:53 -04:00
parent 373dbebc8c
commit c72503fce6

View File

@@ -38,7 +38,7 @@ pub const Message = packed struct {
pub fn getPayload(self: *align(1) Relay) []u8 { pub fn getPayload(self: *align(1) Relay) []u8 {
const len: *u16 = @ptrFromInt(@intFromPtr(self) - @sizeOf(u16)); const len: *u16 = @ptrFromInt(@intFromPtr(self) - @sizeOf(u16));
return @as([*]u8, @ptrCast(&self.payload))[0 .. len.* - @sizeOf(Relay)]; return @as([*]u8, @ptrCast(&self.payload))[0 .. len.* - @bitSizeOf(Relay) / 8];
} }
}; };
const Connection = packed struct { const Connection = packed struct {
@@ -52,7 +52,7 @@ pub const Message = packed struct {
pub fn getPayload(self: *align(1) Connection) []u8 { pub fn getPayload(self: *align(1) Connection) []u8 {
const len: *u16 = @ptrFromInt(@intFromPtr(self) - @sizeOf(u16)); const len: *u16 = @ptrFromInt(@intFromPtr(self) - @sizeOf(u16));
return @as([*]u8, @ptrCast(&self.payload))[0 .. len.* - @sizeOf(Connection)]; return @as([*]u8, @ptrCast(&self.payload))[0 .. len.* - @bitSizeOf(Connection) / 8];
} }
fn nativeFromNetworkEndian(self: *align(1) Connection) void { fn nativeFromNetworkEndian(self: *align(1) Connection) void {
@@ -91,13 +91,12 @@ pub const Message = packed struct {
/// Compute the number of bytes required to store a given payload size for a given message type. /// Compute the number of bytes required to store a given payload size for a given message type.
pub fn calcSize(comptime @"type": PacketType, payload_len: usize) MessageTypeError!u16 { pub fn calcSize(comptime @"type": PacketType, payload_len: usize) MessageTypeError!u16 {
std.debug.assert(payload_len < std.math.maxInt(u16)); const header_size = @bitSizeOf(switch (@"type") {
const header_size = @sizeOf(switch (@"type") {
.relay => Relay, .relay => Relay,
.connection => Connection, .connection => Connection,
.file_transfer => return MessageTypeError.NotImplementedSaprusType, .file_transfer => return MessageTypeError.NotImplementedSaprusType,
else => return MessageTypeError.UnknownSaprusType, else => return MessageTypeError.UnknownSaprusType,
}); }) / 8;
return @intCast(payload_len + @sizeOf(Self) + header_size); return @intCast(payload_len + @sizeOf(Self) + header_size);
} }
@@ -191,6 +190,8 @@ pub const Message = packed struct {
}; };
test "testing variable length zero copy struct" { test "testing variable length zero copy struct" {
{
// Relay test
const payload = "Hello darkness my old friend"; const payload = "Hello darkness my old friend";
var msg_bytes: [try Message.calcSize(.relay, payload.len)]u8 align(@alignOf(Message)) = undefined; var msg_bytes: [try Message.calcSize(.relay, payload.len)]u8 align(@alignOf(Message)) = undefined;
@@ -210,15 +211,13 @@ test "testing variable length zero copy struct" {
} }
{ {
const bytes = msg.asBytes();
// Print the message as hex using the network byte order // Print the message as hex using the network byte order
try msg.networkFromNativeEndian(); try msg.networkFromNativeEndian();
// We know the error from nativeFromNetworkEndian is unreachable because // We know the error from nativeFromNetworkEndian is unreachable because
// it would have returned an error from networkFromNativeEndian. // it would have returned an error from networkFromNativeEndian.
defer msg.nativeFromNetworkEndian() catch unreachable; defer msg.nativeFromNetworkEndian() catch unreachable;
std.debug.print("network bytes: {x}\n", .{bytes}); std.debug.print("relay network bytes: {x}\n", .{msg_bytes});
std.debug.print("bytes len: {d}\n", .{bytes.len}); std.debug.print("bytes len: {d}\n", .{msg_bytes.len});
} }
if (false) { if (false) {
@@ -227,6 +226,38 @@ test "testing variable length zero copy struct" {
} }
try std.testing.expectEqualDeep(msg, try Message.bytesAsValue(msg.asBytes())); try std.testing.expectEqualDeep(msg, try Message.bytesAsValue(msg.asBytes()));
}
{
// Connection test
const payload = "Hello darkness my old friend";
var msg_bytes: [try Message.calcSize(.connection, payload.len)]u8 align(@alignOf(Message)) = undefined;
// Create a view of the byte slice as a Message
const msg: *Message = .init(.connection, &msg_bytes);
{
// Initializing connection header values
const connection = (try msg.getSaprusTypePayload()).connection;
connection.src_port = 1;
connection.dest_port = 2;
connection.seq_num = 3;
connection.msg_id = 4;
connection.reserved = 5;
connection.options = @bitCast(@as(u8, 6));
@memcpy(connection.getPayload(), payload);
}
{
// Print the message as hex using the network byte order
try msg.networkFromNativeEndian();
// We know the error from nativeFromNetworkEndian is unreachable because
// it would have returned an error from networkFromNativeEndian.
defer msg.nativeFromNetworkEndian() catch unreachable;
std.debug.print("connection network bytes: {x}\n", .{msg_bytes});
std.debug.print("bytes len: {d}\n", .{msg_bytes.len});
}
}
} }
const std = @import("std"); const std = @import("std");