18 Commits

Author SHA1 Message Date
56e72928c6 fix use after free 2025-05-10 21:46:53 -04:00
a80c9abfe7 Attempt to base64 encode the connection payload
For some reason I am still getting this:

2025/05/10 16:37:06 Error decoding message: SGVsbG8gZGFya25lc3MgbXkgb2xkIGZyaWVuZA==::53475673624738675a4746796132356c63334d6762586b676232786b49475a79615756755a413d3daaaa
2025-05-10 21:46:53 -04:00
245dab4909 Use slice for init, and add better error sets.
The slice sets us avoid allocating within the init function.
This means init can't fail, and it also makes it easier to stack allocate messages (slice an array buffer, instead of creating a stack allocator).
2025-05-10 21:46:53 -04:00
cde5c3626c 2025-05-10 21:46:53 -04:00
e84d1a2300 2025-05-10 21:46:53 -04:00
1b7d9bbb1a Remove bytesAsValueUnchecked
Callers can instead use std.mem.bytesAsValue directly.
2025-05-10 21:46:53 -04:00
1512ec1a86 Cleanup asBytes and test it 2025-05-10 21:46:53 -04:00
f1dce257be Simplify init interface 2025-05-10 21:46:53 -04:00
bcab1e4d00 2025-05-10 21:46:53 -04:00
0e8f016978 Align the bytes instead of the struct 2025-05-10 21:46:53 -04:00
fc53e87389 2025-05-10 21:46:53 -04:00
cbf554e853 2025-05-10 21:46:53 -04:00
775212013f 2025-05-10 21:46:53 -04:00
339ac5cfe5 2025-05-10 21:46:53 -04:00
eacfaffb6b 2025-05-10 21:46:53 -04:00
1731b2e643 2025-05-10 21:46:53 -04:00
dae66a0039 Starting real connections 2025-05-10 21:46:53 -04:00
683a2015b0 Use FAIL as the default dest if unable to parse 2025-04-27 18:03:06 -04:00
9 changed files with 243 additions and 271 deletions

View File

@@ -1,5 +1,4 @@
const std = @import("std"); const std = @import("std");
const Step = std.Build.Step;
// Although this function looks imperative, note that its job is to // Although this function looks imperative, note that its job is to
// declaratively construct a build graph that will be executed by an external // declaratively construct a build graph that will be executed by an external
@@ -34,34 +33,18 @@ pub fn build(b: *std.Build) void {
}); });
lib_mod.addImport("network", b.dependency("network", .{}).module("network")); lib_mod.addImport("network", b.dependency("network", .{}).module("network"));
lib_mod.addImport("gatorcat", b.dependency("gatorcat", .{}).module("gatorcat"));
exe_mod.addImport("zaprus", lib_mod); exe_mod.addImport("zaprus", lib_mod);
exe_mod.addImport("clap", b.dependency("clap", .{}).module("clap")); exe_mod.addImport("clap", b.dependency("clap", .{}).module("clap"));
const static_lib = b.addLibrary(.{ const lib = b.addLibrary(.{
.linkage = .static, .linkage = .static,
.name = "zaprus", .name = "zaprus",
.root_module = lib_mod, .root_module = lib_mod,
}); });
b.installArtifact(static_lib); b.installArtifact(lib);
const dynamic_lib = b.addLibrary(.{
.linkage = .dynamic,
.name = "zaprus",
.root_module = lib_mod,
});
b.installArtifact(dynamic_lib);
// C Headers
const c_header = b.addInstallFileWithDir(
b.path("include/zaprus.h"),
.header,
"zaprus.h",
);
b.getInstallStep().dependOn(&c_header.step);
// This creates another `std.Build.Step.Compile`, but this one builds an executable // This creates another `std.Build.Step.Compile`, but this one builds an executable
// rather than a static library. // rather than a static library.

View File

@@ -44,6 +44,10 @@
.url = "git+https://github.com/Hejsil/zig-clap?ref=0.10.0#e47028deaefc2fb396d3d9e9f7bd776ae0b2a43a", .url = "git+https://github.com/Hejsil/zig-clap?ref=0.10.0#e47028deaefc2fb396d3d9e9f7bd776ae0b2a43a",
.hash = "clap-0.10.0-oBajB434AQBDh-Ei3YtoKIRxZacVPF1iSwp3IX_ZB8f0", .hash = "clap-0.10.0-oBajB434AQBDh-Ei3YtoKIRxZacVPF1iSwp3IX_ZB8f0",
}, },
.gatorcat = .{
.url = "git+https://github.com/kj4tmp/gatorcat#bb1847f6c95852e7a0ec8c07870a948c171d5f98",
.hash = "gatorcat-0.3.2-WcrpTf1mBwDrmPaIhKCfLJO064v8Sjjn7DBq4CKZSgHH",
},
}, },
.paths = .{ .paths = .{
"build.zig", "build.zig",

View File

@@ -1,24 +0,0 @@
// client
int zaprus_init(void);
int zaprus_deinit(void);
int zaprus_send_relay(const char* payload, usize len, char[4] dest);
int zaprus_send_initial_connection(const char* payload, usize len, uint16_t initial_port);
struct SaprusMessage* zaprus_connect(const char* payload, usize len);
// message
struct SaprusMessage {
};
// ptr should be freed by the caller.
int zaprus_message_to_bytes(struct SaprusMessage msg, char** ptr, usize* len);
// Return value should be destroyed with zaprus_message_deinit.
struct SaprusMessage* zaprus_message_from_bytes(const char* bytes, usize len);
void zaprus_message_deinit(struct SaprusMessage* msg);

View File

@@ -1,3 +1,6 @@
const base64Enc = std.base64.Base64Encoder.init(std.base64.standard_alphabet_chars, '=');
const base64Dec = std.base64.Base64Decoder.init(std.base64.standard_alphabet_chars, '=');
var rand: ?Random = null; var rand: ?Random = null;
pub fn init() !void { pub fn init() !void {
@@ -14,9 +17,14 @@ pub fn deinit() void {
network.deinit(); network.deinit();
} }
fn broadcastSaprusMessage(msg: SaprusMessage, udp_port: u16, allocator: Allocator) !void { fn broadcastSaprusMessage(msg: *SaprusMessage, udp_port: u16) !void {
const msg_bytes = try msg.toBytes(allocator); if (false) {
defer allocator.free(msg_bytes); var foo: gcat.nic.RawSocket = try .init("enp7s0"); // /proc/net/dev
defer foo.deinit();
}
const msg_bytes = msg.asBytes();
try msg.networkFromNativeEndian();
defer msg.nativeFromNetworkEndian() catch unreachable;
var sock = try network.Socket.create(.ipv4, .udp); var sock = try network.Socket.create(.ipv4, .udp);
defer sock.close(); defer sock.close();
@@ -40,14 +48,22 @@ fn broadcastSaprusMessage(msg: SaprusMessage, udp_port: u16, allocator: Allocato
} }
pub fn sendRelay(payload: []const u8, dest: [4]u8, allocator: Allocator) !void { pub fn sendRelay(payload: []const u8, dest: [4]u8, allocator: Allocator) !void {
const msg = SaprusMessage{ const msg_bytes = try allocator.alignedAlloc(
.relay = .{ u8,
.header = .{ .dest = dest }, @alignOf(SaprusMessage),
.payload = payload, try SaprusMessage.lengthForPayloadLength(
}, .relay,
}; base64Enc.calcSize(payload.len),
),
);
defer allocator.free(msg_bytes);
const msg: *SaprusMessage = .init(.relay, msg_bytes);
try broadcastSaprusMessage(msg, 8888, allocator); const relay = (try msg.getSaprusTypePayload()).relay;
relay.dest = dest;
_ = base64Enc.encode(relay.getPayload(), payload);
try broadcastSaprusMessage(msg, 8888);
} }
fn randomPort() u16 { fn randomPort() u16 {
@@ -59,31 +75,36 @@ fn randomPort() u16 {
return p; return p;
} }
pub fn sendInitialConnection(payload: []const u8, initial_port: u16, allocator: Allocator) !SaprusMessage { pub fn sendInitialConnection(payload: []const u8, initial_port: u16, allocator: Allocator) !*SaprusMessage {
const dest_port = randomPort(); const dest_port = randomPort();
const msg = SaprusMessage{ const msg_bytes = try allocator.alignedAlloc(
.connection = .{ u8,
.header = .{ @alignOf(SaprusMessage),
.src_port = initial_port, try SaprusMessage.lengthForPayloadLength(
.dest_port = dest_port, .connection,
}, base64Enc.calcSize(payload.len),
.payload = payload, ),
}, );
};
try broadcastSaprusMessage(msg, 8888, allocator); const msg: *SaprusMessage = .init(.connection, msg_bytes);
const connection = (try msg.getSaprusTypePayload()).connection;
connection.src_port = initial_port;
connection.dest_port = dest_port;
_ = base64Enc.encode(connection.getPayload(), payload);
try broadcastSaprusMessage(msg, 8888);
return msg; return msg;
} }
pub fn connect(payload: []const u8, allocator: Allocator) !?SaprusMessage { pub fn connect(payload: []const u8, allocator: Allocator) !?SaprusConnection {
var initial_port: u16 = 0; var initial_port: u16 = 0;
if (rand) |r| { if (rand) |r| {
initial_port = r.intRangeAtMost(u16, 1024, 65000); initial_port = r.intRangeAtMost(u16, 1024, 65000);
} else unreachable; } else unreachable;
var initial_conn_res: ?SaprusMessage = null; var initial_conn_res: ?*SaprusMessage = null;
errdefer if (initial_conn_res) |c| c.deinit(allocator);
var sock = try network.Socket.create(.ipv4, .udp); var sock = try network.Socket.create(.ipv4, .udp);
defer sock.close(); defer sock.close();
@@ -99,20 +120,26 @@ pub fn connect(payload: []const u8, allocator: Allocator) !?SaprusMessage {
try sock.bind(bind_addr); try sock.bind(bind_addr);
const msg = try sendInitialConnection(payload, initial_port, allocator); const msg = try sendInitialConnection(payload, initial_port, allocator);
defer allocator.free(msg.asBytes());
var response_buf: [4096]u8 = undefined; var response_buf: [4096]u8 align(@alignOf(SaprusMessage)) = undefined;
_ = try sock.receive(&response_buf); // Ignore message that I sent. _ = try sock.receive(&response_buf); // Ignore message that I sent.
const len = try sock.receive(&response_buf); const len = try sock.receive(&response_buf);
initial_conn_res = try SaprusMessage.fromBytes(response_buf[0..len], allocator); std.debug.print("response bytes: {x}\n", .{response_buf[0..len]});
initial_conn_res = SaprusMessage.init(.connection, response_buf[0..len]);
// Complete handshake after awaiting response // Complete handshake after awaiting response
try broadcastSaprusMessage(msg, randomPort(), allocator); try broadcastSaprusMessage(msg, randomPort());
return initial_conn_res; if (false) {
return initial_conn_res.?;
}
return null;
} }
const SaprusMessage = @import("message.zig").Message; const SaprusMessage = @import("message.zig").Message;
const SaprusConnection = @import("Connection.zig");
const std = @import("std"); const std = @import("std");
const Random = std.Random; const Random = std.Random;
@@ -120,5 +147,6 @@ const posix = std.posix;
const mem = std.mem; const mem = std.mem;
const network = @import("network"); const network = @import("network");
const gcat = @import("gatorcat");
const Allocator = mem.Allocator; const Allocator = mem.Allocator;

0
src/Connection.zig Normal file
View File

View File

@@ -1,56 +0,0 @@
// client
export fn zaprus_init() c_int {
SaprusClient.init() catch return 1;
return 0;
}
export fn zaprus_deinit() c_int {
SaprusClient.deinit();
return 0;
}
export fn zaprus_send_relay(payload: [*]const u8, len: usize, dest: [4]u8) c_int {
SaprusClient.sendRelay(payload[0..len], dest, allocator) catch return 1;
return 0;
}
export fn zaprus_send_initial_connection(payload: [*]const u8, len: usize, initial_port: u16) c_int {
SaprusClient.sendInitialConnection(payload[0..len], initial_port, allocator) catch return 1;
return 0;
}
export fn zaprus_connect(payload: [*]const u8, len: usize) ?*SaprusMessage {
return SaprusClient.connect(payload[0..len], allocator) catch null;
}
// message
/// ptr should be freed by the caller.
export fn zaprus_message_to_bytes(msg: SaprusMessage, ptr: *[*]u8, len: *usize) c_int {
const bytes = msg.toBytes(allocator) catch return 1;
ptr.* = bytes[0..].*;
len.* = bytes.len;
return 0;
}
/// Return value should be destroyed with zaprus_message_deinit.
export fn zaprus_message_from_bytes(bytes: [*]const u8, len: usize) ?*SaprusMessage {
return SaprusMessage.fromBytes(bytes[0..len], allocator) catch null;
}
export fn zaprus_message_deinit(msg: *SaprusMessage) void {
msg.deinit(allocator);
}
const std = @import("std");
const zaprus = @import("./root.zig");
const SaprusClient = zaprus.Client;
const SaprusMessage = zaprus.Message;
const allocator = std.heap.c_allocator;
test {
std.testing.refAllDeclsRecursively(@This());
}

View File

@@ -50,7 +50,7 @@ pub fn main() !void {
} }
if (res.args.relay) |r| { if (res.args.relay) |r| {
const dest = parseDest(res.args.dest) catch .{ 70, 70, 70, 70 }; const dest = parseDest(res.args.dest);
try SaprusClient.sendRelay( try SaprusClient.sendRelay(
if (r.len > 0) r else "Hello darkness my old friend", if (r.len > 0) r else "Hello darkness my old friend",
dest, dest,
@@ -59,23 +59,17 @@ pub fn main() !void {
// std.debug.print("Sent: {s}\n", .{r}); // std.debug.print("Sent: {s}\n", .{r});
return; return;
} else if (res.args.connect) |c| { } else if (res.args.connect) |c| {
const conn_res: ?SaprusMessage = SaprusClient.connect(if (c.len > 0) c else "Hello darkness my old friend", gpa) catch |err| switch (err) { _ = SaprusClient.connect(if (c.len > 0) c else "Hello darkness my old friend", gpa) catch |err| switch (err) {
error.WouldBlock => null, error.WouldBlock => null,
else => return err, else => return err,
}; };
defer if (conn_res) |r| r.deinit(gpa);
if (conn_res) |r| {
std.debug.print("{s}\n", .{r.connection.payload});
} else {
std.debug.print("No response from connection request\n", .{});
}
return; return;
} }
return clap.help(std.io.getStdErr().writer(), clap.Help, &params, .{}); return clap.help(std.io.getStdErr().writer(), clap.Help, &params, .{});
} }
fn parseDest(in: ?[]const u8) ![4]u8 { fn parseDest(in: ?[]const u8) [4]u8 {
if (in) |dest| { if (in) |dest| {
if (dest.len <= 4) { if (dest.len <= 4) {
var res: [4]u8 = @splat(0); var res: [4]u8 = @splat(0);
@@ -83,10 +77,10 @@ fn parseDest(in: ?[]const u8) ![4]u8 {
return res; return res;
} }
const addr = try std.net.Ip4Address.parse(dest, 0); const addr = std.net.Ip4Address.parse(dest, 0) catch return "FAIL".*;
return @bitCast(addr.sa.addr); return @bitCast(addr.sa.addr);
} }
return .{ 70, 70, 70, 70 }; return "zap\x00".*;
} }
const builtin = @import("builtin"); const builtin = @import("builtin");

View File

@@ -1,6 +1,3 @@
const base64Enc = std.base64.Base64Encoder.init(std.base64.standard_alphabet_chars, '=');
const base64Dec = std.base64.Base64Decoder.init(std.base64.standard_alphabet_chars, '=');
/// Type tag for Message union. /// Type tag for Message union.
/// This is the first value in the actual packet sent over the network. /// This is the first value in the actual packet sent over the network.
pub const PacketType = enum(u16) { pub const PacketType = enum(u16) {
@@ -23,153 +20,195 @@ pub const ConnectionOptions = packed struct(u8) {
opt8: bool = false, opt8: bool = false,
}; };
pub const Error = error{ pub const MessageTypeError = error{
NotImplementedSaprusType, NotImplementedSaprusType,
UnknownSaprusType, UnknownSaprusType,
}; };
pub const MessageParseError = MessageTypeError || error{
InvalidMessage,
};
// ZERO COPY STUFF
// &payload could be a void value that is treated as a pointer to a [*]u8
/// All Saprus messages /// All Saprus messages
pub const Message = union(PacketType) { pub const Message = packed struct {
pub const Relay = struct { const Relay = packed struct {
pub const Header = packed struct { dest: @Vector(4, u8),
dest: @Vector(4, u8), payload: void,
};
header: Header,
payload: []const u8,
};
pub const Connection = struct {
pub const Header = packed struct {
src_port: u16, // random number > 1024
dest_port: u16, // random number > 1024
seq_num: u32 = 0,
msg_id: u32 = 0,
reserved: u8 = 0,
options: ConnectionOptions = .{},
};
header: Header,
payload: []const u8,
};
relay: Relay,
file_transfer: void, // unimplemented
connection: Connection,
/// Should be called for any Message that was declared using a function that you pass an allocator to. pub fn getPayload(self: *align(1) Relay) []u8 {
pub fn deinit(self: Message, allocator: Allocator) void { const len: *u16 = @ptrFromInt(@intFromPtr(self) - @sizeOf(u16));
switch (self) { return @as([*]u8, @ptrCast(&self.payload))[0 .. len.* - @sizeOf(Relay)];
.relay => |r| allocator.free(r.payload), }
.connection => |c| allocator.free(c.payload), };
const Connection = packed struct {
src_port: u16, // random number > 1024
dest_port: u16, // random number > 1024
seq_num: u32 = 0,
msg_id: u32 = 0,
reserved: u8 = 0,
options: ConnectionOptions = .{},
payload: void,
pub fn getPayload(self: *align(1) Connection) []u8 {
const len: *u16 = @ptrFromInt(@intFromPtr(self) - @sizeOf(u16));
return @as([*]u8, @ptrCast(&self.payload))[0 .. len.* - @sizeOf(Connection)];
}
fn nativeFromNetworkEndian(self: *align(1) Connection) void {
self.src_port = bigToNative(@TypeOf(self.src_port), self.src_port);
self.dest_port = bigToNative(@TypeOf(self.dest_port), self.dest_port);
self.seq_num = bigToNative(@TypeOf(self.seq_num), self.seq_num);
self.msg_id = bigToNative(@TypeOf(self.msg_id), self.msg_id);
}
fn networkFromNativeEndian(self: *align(1) Connection) void {
self.src_port = nativeToBig(@TypeOf(self.src_port), self.src_port);
self.dest_port = nativeToBig(@TypeOf(self.dest_port), self.dest_port);
self.seq_num = nativeToBig(@TypeOf(self.seq_num), self.seq_num);
self.msg_id = nativeToBig(@TypeOf(self.msg_id), self.msg_id);
}
};
const Self = @This();
const SelfBytes = []align(@alignOf(Self)) u8;
type: PacketType,
length: u16,
bytes: void = {},
/// Takes a byte slice, and returns a Message struct backed by the slice.
/// This properly initializes the top level headers within the slice.
pub fn init(@"type": PacketType, bytes: []align(@alignOf(Self)) u8) *Self {
std.debug.assert(bytes.len >= @sizeOf(Self));
const res: *Self = @ptrCast(bytes.ptr);
res.type = @"type";
res.length = @intCast(bytes.len - @sizeOf(Self));
return res;
}
pub fn lengthForPayloadLength(comptime @"type": PacketType, payload_len: usize) MessageTypeError!u16 {
std.debug.assert(payload_len < std.math.maxInt(u16));
const header_size = @sizeOf(switch (@"type") {
.relay => Relay,
.connection => Connection,
.file_transfer => return MessageTypeError.NotImplementedSaprusType,
else => return MessageTypeError.UnknownSaprusType,
});
return @intCast(payload_len + @sizeOf(Self) + header_size);
}
fn getRelay(self: *Self) *align(1) Relay {
return std.mem.bytesAsValue(Relay, &self.bytes);
}
fn getConnection(self: *Self) *align(1) Connection {
return std.mem.bytesAsValue(Connection, &self.bytes);
}
pub fn getSaprusTypePayload(self: *Self) MessageTypeError!(union(PacketType) {
relay: *align(1) Relay,
file_transfer: void,
connection: *align(1) Connection,
}) {
return switch (self.type) {
.relay => .{ .relay = self.getRelay() },
.connection => .{ .connection = self.getConnection() },
.file_transfer => MessageTypeError.NotImplementedSaprusType,
else => MessageTypeError.UnknownSaprusType,
};
}
pub fn nativeFromNetworkEndian(self: *Self) MessageTypeError!void {
self.type = @enumFromInt(bigToNative(
@typeInfo(@TypeOf(self.type)).@"enum".tag_type,
@intFromEnum(self.type),
));
self.length = bigToNative(@TypeOf(self.length), self.length);
errdefer {
// If the payload specific headers fail, revert the top level header values
self.type = @enumFromInt(nativeToBig(
@typeInfo(@TypeOf(self.type)).@"enum".tag_type,
@intFromEnum(self.type),
));
self.length = nativeToBig(@TypeOf(self.length), self.length);
}
switch (try self.getSaprusTypePayload()) {
.relay => {},
.connection => |*con| con.*.nativeFromNetworkEndian(),
// We know other values are unreachable,
// because they would have returned an error from the switch condition.
else => unreachable, else => unreachable,
} }
} }
fn toBytesAux( pub fn networkFromNativeEndian(self: *Self) MessageTypeError!void {
header: anytype, try switch (try self.getSaprusTypePayload()) {
payload: []const u8, .relay => {},
buf: *std.ArrayList(u8), .connection => |*con| con.*.networkFromNativeEndian(),
allocator: Allocator, .file_transfer => MessageTypeError.NotImplementedSaprusType,
) !void { else => MessageTypeError.UnknownSaprusType,
const Header = @TypeOf(header); };
// Create a growable string to store the base64 bytes in. self.type = @enumFromInt(nativeToBig(
// Doing this first so I can use the length of the encoded bytes for the length field. @typeInfo(@TypeOf(self.type)).@"enum".tag_type,
var payload_list = std.ArrayList(u8).init(allocator); @intFromEnum(self.type),
defer payload_list.deinit(); ));
const buf_w = payload_list.writer(); self.length = nativeToBig(@TypeOf(self.length), self.length);
// Write the payload bytes as base64 to the growable string.
try base64Enc.encodeWriter(buf_w, payload);
// At this point, payload_list contains the base64 encoded payload.
// Add the payload length to the output buf.
try buf.*.appendSlice(
asBytes(&nativeToBig(u16, @intCast(payload_list.items.len + @bitSizeOf(Header) / 8))),
);
// Add the header bytes to the output buf.
var header_buf: [@sizeOf(Header)]u8 = undefined;
var header_buf_stream = std.io.fixedBufferStream(&header_buf);
try header_buf_stream.writer().writeStructEndian(header, .big);
// Add the exact number of bits in the header without padding.
try buf.*.appendSlice(header_buf[0 .. @bitSizeOf(Header) / 8]);
try buf.*.appendSlice(payload_list.items);
} }
/// Caller is responsible for freeing the returned bytes. /// Deprecated.
pub fn toBytes(self: Message, allocator: Allocator) ![]u8 { /// If I need the bytes, I should just pass around the slice that is backing this to begin with.
// Create a growable list of bytes to store the output in. pub fn asBytes(self: *Self) SelfBytes {
var buf = std.ArrayList(u8).init(allocator); const size = @sizeOf(Self) + self.length;
errdefer buf.deinit(); return @as([*]align(@alignOf(Self)) u8, @ptrCast(self))[0..size];
// Start with writing the message type, which is the first 16 bits of every Saprus message.
try buf.appendSlice(asBytes(&nativeToBig(u16, @intFromEnum(self))));
// Write the proper header and payload for the given packet type.
switch (self) {
.relay => |r| try toBytesAux(r.header, r.payload, &buf, allocator),
.connection => |c| try toBytesAux(c.header, c.payload, &buf, allocator),
.file_transfer => return Error.NotImplementedSaprusType,
}
// Collect the growable list as a slice and return it.
return buf.toOwnedSlice();
}
fn fromBytesAux(
comptime packet: PacketType,
len: u16,
r: std.io.FixedBufferStream([]const u8).Reader,
allocator: Allocator,
) !Message {
const Header = @field(@FieldType(Message, @tagName(packet)), "Header");
// Read the header for the current message type.
var header_bytes: [@sizeOf(Header)]u8 = undefined;
_ = try r.read(header_bytes[0 .. @bitSizeOf(Header) / 8]);
var header_stream = std.io.fixedBufferStream(&header_bytes);
const header = try header_stream.reader().readStructEndian(Header, .big);
// Read the base64 bytes into a list to be able to call the decoder on it.
const payload_buf = try allocator.alloc(u8, len - @bitSizeOf(Header) / 8);
defer allocator.free(payload_buf);
_ = try r.readAll(payload_buf);
// Create a buffer to store the payload in, and decode the base64 bytes into the payload field.
const payload = try allocator.alloc(u8, try base64Dec.calcSizeForSlice(payload_buf));
try base64Dec.decode(payload, payload_buf);
// Return the type of Message specified by the `packet` argument.
return @unionInit(Message, @tagName(packet), .{
.header = header,
.payload = payload,
});
}
/// Caller is responsible for calling .deinit on the returned value.
pub fn fromBytes(bytes: []const u8, allocator: Allocator) !Message {
var s = std.io.fixedBufferStream(bytes);
const r = s.reader();
// Read packet type
const packet_type = @as(PacketType, @enumFromInt(try r.readInt(u16, .big)));
// Read the length of the header + base64 encoded payload.
const len = try r.readInt(u16, .big);
switch (packet_type) {
.relay => return fromBytesAux(.relay, len, r, allocator),
.connection => return fromBytesAux(.connection, len, r, allocator),
.file_transfer => return Error.NotImplementedSaprusType,
else => return Error.UnknownSaprusType,
}
} }
}; };
test "testing variable length zero copy struct" {
const gpa = std.testing.allocator;
const payload = "Hello darkness my old friend";
// Create a view of the byte slice as a Message
const msg: *Message = try .init(gpa, .relay, payload.len);
defer msg.deinit(gpa);
{
// Set the message values
{
// These are both set by the init call.
// msg.type = .relay;
// msg.length = payload_len;
}
const relay = (try msg.getSaprusTypePayload()).relay;
relay.dest = .{ 1, 2, 3, 4 };
@memcpy(relay.getPayload(), payload);
}
{
const bytes = msg.asBytes();
// Print the message as hex using the network byte order
try msg.networkFromNativeEndian();
// We know the error from nativeFromNetworkEndian is unreachable because
// it would have returned an error from networkFromNativeEndian.
defer msg.nativeFromNetworkEndian() catch unreachable;
std.debug.print("network bytes: {x}\n", .{bytes});
std.debug.print("bytes len: {d}\n", .{bytes.len});
}
if (false) {
// Illegal behavior
std.debug.print("{any}\n", .{(try msg.getSaprusTypePayload()).connection});
}
try std.testing.expectEqualDeep(msg, try Message.bytesAsValue(msg.asBytes()));
}
const std = @import("std"); const std = @import("std");
const Allocator = std.mem.Allocator; const Allocator = std.mem.Allocator;
const asBytes = std.mem.asBytes; const asBytes = std.mem.asBytes;
const nativeToBig = std.mem.nativeToBig; const nativeToBig = std.mem.nativeToBig;
const bigToNative = std.mem.bigToNative;
test "Round trip Relay toBytes and fromBytes" { test "Round trip Relay toBytes and fromBytes" {
const gpa = std.testing.allocator; const gpa = std.testing.allocator;
@@ -209,3 +248,7 @@ test "Round trip Connection toBytes and fromBytes" {
try std.testing.expectEqualDeep(msg, from_bytes); try std.testing.expectEqualDeep(msg, from_bytes);
} }
test {
std.testing.refAllDeclsRecursive(@This());
}

View File

@@ -1,4 +1,4 @@
pub const Client = @import("Client.zig"); pub const Client = @import("Client.zig");
pub usingnamespace @import("message.zig"); pub const Connection = @import("Connection.zig");
pub usingnamespace @import("c_api.zig"); pub usingnamespace @import("message.zig");