Files
zaprus/src/message.zig
Robby Zambito 558f40213b Update to Saprus 0.2.1
Handle management messages instead of letting them bubble up through the
connection to the consumer.
Right now, this just means handling ping messages by sending a pong.

Also updated to follow the new handshake flow.
The sentinel will mirror the ports instead of matching them.

Now filters on the full source and dest ports, which are less likely to
have erroneous matches.
2026-02-01 19:16:22 -05:00

250 lines
8.6 KiB
Zig

// Copyright 2026 Robby Zambito
//
// This file is part of zaprus.
//
// Zaprus is free software: you can redistribute it and/or modify it under the
// terms of the GNU General Public License as published by the Free Software
// Foundation, either version 3 of the License, or (at your option) any later
// version.
//
// Zaprus is distributed in the hope that it will be useful, but WITHOUT ANY
// WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
// A PARTICULAR PURPOSE. See the GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License along with
// Zaprus. If not, see <https://www.gnu.org/licenses/>.
/// Errors about the type of a Saprus message.
pub const MessageTypeError = error{
/// `parse` read a type tag that it has no decoder for.
NotImplementedSaprusType,
/// `parse` maps an invalid enum tag to this error; `Connection.management`
/// also returns it for an unrecognized management byte.
UnknownSaprusType,
};
/// Every error `parse` and `Connection.management` can return.
pub const MessageParseError = MessageTypeError || error{
/// A field could not be read from the input, or a management payload failed
/// to base64-decode.
InvalidMessage,
};
// Alias for this file's own namespace. Declarations nested inside `Message`
// use it to reach the file-level `Relay`, `Connection` and `parse`.
const message = @This();
/// A Saprus message, tagged by its 16-bit wire type.
///
/// The trailing `_` leaves the tag non-exhaustive; only `relay` and
/// `connection` carry a payload type here.
pub const Message = union(enum(u16)) {
    relay: Message.Relay = 0x003C,
    connection: Message.Connection = 0x00E9,
    _,

    pub const Relay = message.Relay;
    pub const Connection = message.Connection;
    pub const parse = message.parse;

    /// Serializes the active variant into `buf` by delegating to that
    /// variant's own `toBytes`, and returns its result.
    /// Must only be called on a `relay` or `connection` message; any other
    /// tag reaches `unreachable`.
    pub fn toBytes(self: message.Message, buf: []u8) []u8 {
        switch (self) {
            .relay => |relay_msg| return relay_msg.toBytes(buf),
            .connection => |connection_msg| return connection_msg.toBytes(buf),
            else => unreachable,
        }
    }
};
/// Length in bytes of the destination field of a relay message.
pub const relay_dest_len = 4;
/// Decodes a single Saprus message from `bytes`.
///
/// All multi-byte header fields are read big-endian. Whatever remains after
/// the header becomes the message payload; it is taken from the reader's
/// buffer rather than copied, so it presumably aliases `bytes` (confirm
/// against `std.Io.Reader.fixed`).
///
/// Errors:
/// - `error.UnknownSaprusType` when the type tag is rejected as an invalid enum tag.
/// - `error.NotImplementedSaprusType` when the tag is neither relay nor connection.
/// - `error.InvalidMessage` when any header field cannot be read.
pub fn parse(bytes: []const u8) MessageParseError!Message {
    var reader: Reader = .fixed(bytes);

    const tag = reader.takeEnum(std.meta.Tag(Message), .big) catch |err| switch (err) {
        error.InvalidEnumTag => return error.UnknownSaprusType,
        else => return error.InvalidMessage,
    };

    // The two bytes after the type are consumed for every message kind, but
    // only relay messages keep them.
    const checksum_bytes = reader.takeArray(2) catch return error.InvalidMessage;

    switch (tag) {
        .connection => {
            const source = reader.takeInt(u16, .big) catch return error.InvalidMessage;
            const destination = reader.takeInt(u16, .big) catch return error.InvalidMessage;
            const sequence = reader.takeInt(u32, .big) catch return error.InvalidMessage;
            const identifier = reader.takeInt(u32, .big) catch return error.InvalidMessage;
            const reserved_byte = reader.takeByte() catch return error.InvalidMessage;
            const option_bits = reader.takeStruct(Connection.Options, .big) catch return error.InvalidMessage;

            const connection_msg: Connection = .{
                .src = source,
                .dest = destination,
                .seq = sequence,
                .id = identifier,
                .reserved = reserved_byte,
                .options = option_bits,
                .payload = reader.buffered(),
            };
            return .{ .connection = connection_msg };
        },
        .relay => {
            const dest_bytes = reader.takeArray(relay_dest_len) catch return error.InvalidMessage;

            const relay_msg: Relay = .{
                .dest = .fromBytes(dest_bytes),
                .checksum = checksum_bytes.*,
                .payload = reader.buffered(),
            };
            return .{ .relay = relay_msg };
        },
        else => return error.NotImplementedSaprusType,
    }
}
test parse {
    // A relay message only has to decode without error.
    const relay_wire = [_]u8{
        0x00, 0x3c, 0x00, 0x17, 0xac, 0x12, 0x01, 0x1e, 0x72,
        0x65, 0x6d, 0x6f, 0x76, 0x65, 0x20, 0x65, 0x76, 0x65,
        0x6e, 0x74, 0x20, 0x6c, 0x6f, 0x67, 0x67, 0x65, 0x64,
    };
    _ = try parse(&relay_wire);

    // A connection message has to decode to exactly these fields.
    const connection_wire = [_]u8{
        0x00, 0xe9, 0x00, 0x18, 0x30, 0x80, 0xf0, 0xf0, 0x00, 0x00,
        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x64, 0x69, 0x61,
        0x6d, 0x64, 0x65, 0x66, 0x61, 0x75, 0x6c, 0x74,
    };
    const want: Message = .{
        .connection = .{
            .src = 12416,
            .dest = 61680,
            .seq = 0,
            .id = 0,
            .reserved = 0,
            .options = @bitCast(@as(u8, 100)),
            .payload = &[_]u8{ 0x69, 0x61, 0x6d, 0x64, 0x65, 0x66, 0x61, 0x75, 0x6c, 0x74 },
        },
    };
    const got = try parse(&connection_wire);
    try std.testing.expectEqualDeep(want, got);
}
/// A relay message: a fixed-size destination followed by a payload.
const Relay = struct {
    dest: Dest,
    /// Not written by `toBytes`; left undefined unless set by the caller or by `parse`.
    checksum: [2]u8 = undefined,
    payload: []const u8,

    /// Destination of a relay message, stored as exactly `relay_dest_len` bytes.
    pub const Dest = struct {
        bytes: [relay_dest_len]u8,

        /// Builds a destination from `bytes`, zero-filling any trailing bytes
        /// that are not supplied.
        /// Asserts that `bytes` is at most `relay_dest_len` bytes long.
        pub fn fromBytes(bytes: []const u8) Dest {
            var result: Dest = .{ .bytes = @splat(0) };
            std.debug.assert(bytes.len <= result.bytes.len);
            @memcpy(result.bytes[0..bytes.len], bytes);
            return result;
        }
    };

    /// Creates a relay message for `dest` carrying `payload`.
    /// `checksum` keeps its default value.
    pub fn init(dest: Dest, payload: []const u8) Relay {
        return .{ .payload = payload, .dest = dest };
    }

    /// Serializes the message into `buf` and returns the writer's buffered
    /// bytes, i.e. the serialized message.
    /// Asserts that buf is large enough to fit the relay message.
    pub fn toBytes(self: Relay, buf: []u8) []u8 {
        var writer: Writer = .fixed(buf);
        writer.writeInt(u16, @intFromEnum(Message.relay), .big) catch unreachable;
        // Length field, but unread. Will switch to checksum.
        // NOTE(review): the 4 presumably counts the destination bytes that
        // follow this field; confirm against the Saprus spec.
        writer.writeInt(u16, @intCast(self.payload.len + 4), .big) catch unreachable;
        writer.writeAll(&self.dest.bytes) catch unreachable;
        writer.writeAll(self.payload) catch unreachable;
        return writer.buffered();
    }

    test toBytes {
        var out_buf: [1024]u8 = undefined;
        const relay: Relay = .init(
            .fromBytes(&.{ 172, 18, 1, 30 }),
            "remove event logged",
        );

        // zig fmt: off
        const wire = [_]u8{
            0x00, 0x3c, 0x00, 0x17, 0xac, 0x12, 0x01, 0x1e, 0x72,
            0x65, 0x6d, 0x6f, 0x76, 0x65, 0x20, 0x65, 0x76, 0x65,
            0x6e, 0x74, 0x20, 0x6c, 0x6f, 0x67, 0x67, 0x65, 0x64
        };
        // zig fmt: on

        try expectEqualMessageBuffers(&wire, relay.toBytes(&out_buf));
    }
};
/// A connection message: a 14-byte header (source, destination, sequence
/// number, id, reserved byte, option bits) followed by the payload.
const Connection = struct {
    src: u16,
    dest: u16,
    seq: u32,
    id: u32,
    /// Reserved header byte. Defaults to zero so that `toBytes` never
    /// serializes uninitialized memory when the caller leaves it unset.
    reserved: u8 = 0,
    options: Options = .{},
    payload: []const u8,

    /// Option bits carried in the last header byte.
    /// Only `management` is interpreted in this file (see `management()`);
    /// the other bits are carried through unchanged.
    pub const Options = packed struct(u8) {
        opt1: bool = false,
        opt2: bool = false,
        opt3: bool = false,
        opt4: bool = false,
        opt5: bool = false,
        opt6: bool = false,
        opt7: bool = false,
        management: bool = false,
    };

    /// Serializes the message into `buf` and returns the writer's buffered
    /// bytes, i.e. the serialized message.
    /// Asserts that buf is large enough to fit the connection message.
    pub fn toBytes(self: Connection, buf: []u8) []u8 {
        var out: Writer = .fixed(buf);
        out.writeInt(u16, @intFromEnum(Message.connection), .big) catch unreachable;
        // Saprus length field, unread. 14 is the size of the header fields
        // written below: src (2) + dest (2) + seq (4) + id (4) + reserved (1)
        // + options (1).
        out.writeInt(u16, @intCast(self.payload.len + 14), .big) catch unreachable;
        out.writeInt(u16, self.src, .big) catch unreachable;
        out.writeInt(u16, self.dest, .big) catch unreachable;
        out.writeInt(u32, self.seq, .big) catch unreachable;
        out.writeInt(u32, self.id, .big) catch unreachable;
        out.writeByte(self.reserved) catch unreachable;
        out.writeStruct(self.options, .big) catch unreachable;
        out.writeAll(self.payload) catch unreachable;
        return out.buffered();
    }

    /// If the current message is a management message, return what kind.
    /// Else return null.
    ///
    /// A management payload must be the standard base64 encoding of exactly
    /// one byte: `'P'` (`"UA=="`) for ping or `'p'` (`"cA=="`) for pong.
    ///
    /// Errors:
    /// - `error.InvalidMessage` if the payload is not valid base64 or does
    ///   not decode to exactly one byte (this includes an empty payload).
    /// - `error.UnknownSaprusType` if the decoded byte is not a known kind.
    pub fn management(self: Connection) MessageParseError!?Management {
        if (!self.options.management) return null;

        const decoder = std.base64.standard.Decoder;

        // The payload arrives off the wire, so validate its decoded size
        // before decoding. `decode` writes past a destination that is too
        // small, and leaves it untouched when the input is empty.
        const decoded_len = decoder.calcSizeForSlice(self.payload) catch return error.InvalidMessage;
        if (decoded_len != 1) return error.InvalidMessage;

        // Zero-initialized so the switch below never reads undefined memory.
        var kind: [1]u8 = .{0};
        decoder.decode(&kind, self.payload) catch return error.InvalidMessage;

        return switch (kind[0]) {
            'P' => .ping,
            'p' => .pong,
            else => error.UnknownSaprusType,
        };
    }

    /// Kinds of management message recognized by `management()`.
    pub const Management = enum {
        ping,
        pong,
    };
};
test "Round trip" {
    // Connection message carrying the 7-byte payload "6::data".
    const wire = [_]u8{
        0x0, 0xe9, 0x0, 0x15, 0x30, 0x80, 0xf0, 0xf0, 0x0,
        0x0, 0x0,  0x0, 0x0,  0x0,  0x0,  0x0,  0x0,  0x64,
        0x36, 0x3a, 0x3a, 0x64, 0x61, 0x74, 0x61,
    };
    const parsed = try parse(&wire);

    // One spare byte, so the result has to be a subslice of the buffer.
    var out_buf: [wire.len + 1]u8 = undefined;
    const encoded = parsed.connection.toBytes(&out_buf);

    try expectEqualMessageBuffers(&wire, encoded);
}
/// Compares two serialized messages while ignoring bytes 2..4, the
/// length / checksum field, because that is undefined.
/// Both buffers must be at least 4 bytes long.
fn expectEqualMessageBuffers(expected: []const u8, actual: []const u8) !void {
    const expectEqualSlices = std.testing.expectEqualSlices;

    // The two bytes before the skipped field.
    try expectEqualSlices(u8, expected[0..2], actual[0..2]);
    // Everything after the skipped field.
    try expectEqualSlices(u8, expected[4..], actual[4..]);
}
const std = @import("std");
const Allocator = std.mem.Allocator;
const Writer = std.Io.Writer;
const Reader = std.Io.Reader;
// Reference every declaration in this file so that tests nested inside the
// declarations above are included when this file is tested.
test {
std.testing.refAllDecls(@This());
}