diff --git a/src/Client.zig b/src/Client.zig index a8170a5..1709cab 100644 --- a/src/Client.zig +++ b/src/Client.zig @@ -100,7 +100,7 @@ pub fn connect(self: Client, io: Io, payload: []const u8) !SaprusConnection { var connection: SaprusMessage = .{ .connection = .{ .src = rand.intRangeAtMost(u16, 1025, std.math.maxInt(u16)), - .dest = rand.intRangeAtMost(u16, 1025, std.math.maxInt(u16)), + .dest = rand.intRangeAtMost(u16, 1025, std.math.maxInt(u16)), // Ignored, but good noise .seq = undefined, .id = undefined, .payload = payload, @@ -108,7 +108,7 @@ pub fn connect(self: Client, io: Io, payload: []const u8) !SaprusConnection { }; log.debug("Setting bpf filter to port {}", .{connection.connection.src}); - self.socket.attachSaprusPortFilter(connection.connection.src) catch |err| { + self.socket.attachSaprusPortFilter(null, connection.connection.src) catch |err| { log.err("Failed to set port filter: {t}", .{err}); return err; }; @@ -131,7 +131,17 @@ pub fn connect(self: Client, io: Io, payload: []const u8) !SaprusConnection { log.debug("Awaiting handshake response", .{}); // Ignore response from sentinel, just accept that we got one. - _ = try self.socket.receive(&res_buf); + const full_handshake_res = try self.socket.receive(&res_buf); + const handshake_res = saprusParse(full_handshake_res[42..]) catch |err| { + log.err("Parse error: {t}", .{err}); + return err; + }; + self.socket.attachSaprusPortFilter(handshake_res.connection.src, handshake_res.connection.dest) catch |err| { + log.err("Failed to set port filter: {t}", .{err}); + return err; + }; + connection.connection.dest = handshake_res.connection.src; + connection_bytes = connection.toBytes(&connection_buf); headers.udp.dst_port = udp_dest_port; headers.ip.id = rand.int(u16); @@ -153,6 +163,7 @@ pub fn connect(self: Client, io: Io, payload: []const u8) !SaprusConnection { const RawSocket = @import("./RawSocket.zig"); const SaprusMessage = @import("message.zig").Message; +const saprusParse = @import("message.zig").parse; const SaprusConnection = @import("Connection.zig"); const EthIpUdp = @import("./EthIpUdp.zig").EthIpUdp; diff --git a/src/Connection.zig b/src/Connection.zig index 90109af..bb81c38 100644 --- a/src/Connection.zig +++ b/src/Connection.zig @@ -28,25 +28,50 @@ pub fn init(socket: RawSocket, headers: EthIpUdp, connection: SaprusMessage) Con }; } -pub fn next(self: Connection, io: Io, buf: []u8) ![]const u8 { - _ = io; - log.debug("Awaiting connection message", .{}); - const res = try self.socket.receive(buf); - log.debug("Received {} byte connection message", .{res.len}); - const msg: SaprusMessage = try .parse(res[42..]); - const connection_res = msg.connection; +// 'p' as base64 +const pong = "cA=="; - log.debug("Payload was {s}", .{connection_res.payload}); +pub fn next(self: *Connection, io: Io, buf: []u8) ![]const u8 { + while (true) { + log.debug("Awaiting connection message", .{}); + const res = try self.socket.receive(buf); + log.debug("Received {} byte connection message", .{res.len}); + const msg = SaprusMessage.parse(res[42..]) catch |err| { + log.err("Failed to parse next message: {t}\n{x}\n{x}", .{ err, res[0..], res[42..] }); + return err; + }; - return connection_res.payload; + switch (msg) { + .connection => |con_res| { + if (try con_res.management()) |mgt| { + log.debug("Received management message {t}", .{mgt}); + switch (mgt) { + .ping => { + log.debug("Sending pong", .{}); + try self.send(io, .{ .management = true }, pong); + log.debug("Sent pong message", .{}); + }, + else => |m| log.debug("Received management message that I don't know how to handle: {t}", .{m}), + } + } else { + log.debug("Payload was {s}", .{con_res.payload}); + return con_res.payload; + } + }, + else => |m| { + std.debug.panic("Expected connection message, instead got {x}. This means there is an error with the BPF.", .{@intFromEnum(m)}); + }, + } + } } -pub fn send(self: *Connection, io: Io, buf: []const u8) !void { +pub fn send(self: *Connection, io: Io, options: SaprusMessage.Connection.Options, buf: []const u8) !void { const io_source: std.Random.IoSource = .{ .io = io }; const rand = io_source.interface(); log.debug("Sending connection message", .{}); + self.connection.connection.options = options; self.connection.connection.payload = buf; var connection_bytes_buf: [2048]u8 = undefined; const connection_bytes = self.connection.toBytes(&connection_bytes_buf); diff --git a/src/RawSocket.zig b/src/RawSocket.zig index 5732ce9..e43a8e4 100644 --- a/src/RawSocket.zig +++ b/src/RawSocket.zig @@ -133,7 +133,7 @@ pub fn receive(self: RawSocket, buf: []u8) ![]u8 { return buf[0..len]; } -pub fn attachSaprusPortFilter(self: RawSocket, port: u16) !void { +pub fn attachSaprusPortFilter(self: RawSocket, incoming_src_port: ?u16, incoming_dest_port: u16) !void { const BPF = std.os.linux.BPF; // BPF instruction structure for classic BPF const SockFilter = extern struct { @@ -149,11 +149,26 @@ pub fn attachSaprusPortFilter(self: RawSocket, port: u16) !void { }; // Build the filter program - const filter = [_]SockFilter{ + const filter = if (incoming_src_port) |inc_src| &[_]SockFilter{ // Load 2 bytes at offset 46 (absolute) .{ .code = BPF.LD | BPF.H | BPF.ABS, .jt = 0, .jf = 0, .k = 46 }, + // Jump if equal to port (skip 1 if true, skip 0 if false) + .{ .code = BPF.JMP | BPF.JEQ | BPF.K, .jt = 1, .jf = 0, .k = @as(u32, inc_src) }, + // Return 0x0 (fail) + .{ .code = BPF.RET | BPF.K, .jt = 0, .jf = 0, .k = 0x0 }, + // Load 2 bytes at offset 48 (absolute) + .{ .code = BPF.LD | BPF.H | BPF.ABS, .jt = 0, .jf = 0, .k = 48 }, // Jump if equal to port (skip 0 if true, skip 1 if false) - .{ .code = BPF.JMP | BPF.JEQ | BPF.K, .jt = 0, .jf = 1, .k = @as(u32, port) }, + .{ .code = BPF.JMP | BPF.JEQ | BPF.K, .jt = 0, .jf = 1, .k = @as(u32, incoming_dest_port) }, + // Return 0xffff (pass) + .{ .code = BPF.RET | BPF.K, .jt = 0, .jf = 0, .k = 0xffff }, + // Return 0x0 (fail) + .{ .code = BPF.RET | BPF.K, .jt = 0, .jf = 0, .k = 0x0 }, + } else &[_]SockFilter{ + // Load 2 bytes at offset 48 (absolute) + .{ .code = BPF.LD | BPF.H | BPF.ABS, .jt = 0, .jf = 0, .k = 48 }, + // Jump if equal to port (skip 0 if true, skip 1 if false) + .{ .code = BPF.JMP | BPF.JEQ | BPF.K, .jt = 0, .jf = 1, .k = @as(u32, incoming_dest_port) }, // Return 0xffff (pass) .{ .code = BPF.RET | BPF.K, .jt = 0, .jf = 0, .k = 0xffff }, // Return 0x0 (fail) @@ -161,8 +176,8 @@ pub fn attachSaprusPortFilter(self: RawSocket, port: u16) !void { }; const fprog = SockFprog{ - .len = filter.len, - .filter = &filter, + .len = @intCast(filter.len), + .filter = filter.ptr, }; // Attach filter to socket using setsockopt diff --git a/src/c_api.zig b/src/c_api.zig index 7f10c45..c2f3190 100644 --- a/src/c_api.zig +++ b/src/c_api.zig @@ -99,6 +99,6 @@ export fn zaprus_connection_send( const c: ?*zaprus.Connection = @ptrCast(@alignCast(connection)); const zc = c orelse return 1; - zc.send(io, payload[0..payload_len]) catch return 1; + zc.send(io, .{}, payload[0..payload_len]) catch return 1; return 0; } diff --git a/src/main.zig b/src/main.zig index 734357b..10dca33 100644 --- a/src/main.zig +++ b/src/main.zig @@ -191,6 +191,7 @@ pub fn main(init: std.process.Init) !void { error.SymLinkLoop, error.SystemResources, => blk: { + log.debug("Trying to execute command directly: {s}", .{connection_payload}); var argv_buf: [128][]const u8 = undefined; var argv: ArrayList([]const u8) = .initBuffer(&argv_buf); var payload_iter = std.mem.splitAny(u8, connection_payload, " \t\n"); @@ -229,7 +230,7 @@ pub fn main(init: std.process.Init) !void { error.EndOfStream => { cmd_output.print("{b64}", .{child_output_reader.interface.buffered()}) catch unreachable; if (cmd_output.end > 0) { - connection.send(init.io, cmd_output.buffered()) catch |e| { + connection.send(init.io, .{}, cmd_output.buffered()) catch |e| { log.debug("Failed to send connection chunk: {t}", .{e}); continue :next_message; }; @@ -238,7 +239,7 @@ pub fn main(init: std.process.Init) !void { }, }; cmd_output.print("{b64}", .{try child_output_reader.interface.takeArray(child_output_buf.len)}) catch unreachable; - connection.send(init.io, cmd_output.buffered()) catch |err| { + connection.send(init.io, .{}, cmd_output.buffered()) catch |err| { log.debug("Failed to send connection chunk: {t}", .{err}); continue :next_message; }; diff --git a/src/message.zig b/src/message.zig index e8ef268..0c1410d 100644 --- a/src/message.zig +++ b/src/message.zig @@ -169,11 +169,11 @@ const Connection = struct { seq: u32, id: u32, reserved: u8 = undefined, - options: Options = undefined, + options: Options = .{}, payload: []const u8, - /// Reserved option values. - /// Currently unused. + /// Option values. + /// Currently used! pub const Options = packed struct(u8) { opt1: bool = false, opt2: bool = false, @@ -182,7 +182,7 @@ const Connection = struct { opt5: bool = false, opt6: bool = false, opt7: bool = false, - opt8: bool = false, + management: bool = false, }; /// Asserts that buf is large enough to fit the connection message. @@ -199,6 +199,28 @@ const Connection = struct { out.writeAll(self.payload) catch unreachable; return out.buffered(); } + + /// If the current message is a management message, return what kind. + /// Else return null. + pub fn management(self: Connection) MessageParseError!?Management { + const b64_dec = std.base64.standard.Decoder; + if (self.options.management) { + var buf: [1]u8 = undefined; + _ = b64_dec.decode(&buf, self.payload) catch return error.InvalidMessage; + + return switch (buf[0]) { + 'P' => .ping, + 'p' => .pong, + else => error.UnknownSaprusType, + }; + } + return null; + } + + pub const Management = enum { + ping, + pong, + }; }; test "Round trip" { @@ -223,5 +245,5 @@ const Writer = std.Io.Writer; const Reader = std.Io.Reader; test { - std.testing.refAllDeclsRecursive(@This()); + std.testing.refAllDecls(@This()); } diff --git a/src/root.zig b/src/root.zig index c469021..aa78565 100644 --- a/src/root.zig +++ b/src/root.zig @@ -19,7 +19,6 @@ pub const Connection = @import("Connection.zig"); const msg = @import("message.zig"); -pub const PacketType = msg.PacketType; pub const MessageTypeError = msg.MessageTypeError; pub const MessageParseError = msg.MessageParseError; pub const Message = msg.Message;