More robust parsing and error propagation

This commit is contained in:
2026-01-08 21:41:49 -05:00
parent d8488fde49
commit ed99115969
2 changed files with 117 additions and 59 deletions

View File

@@ -197,28 +197,49 @@ fn handleConnection(
defer client.recv_queue_write_lock.unlock(io); defer client.recv_queue_write_lock.unlock(io);
_ = try client.from_client.take(2); _ = try client.from_client.take(2);
try client.recv_queue.putAll(io, "PONG\r\n"); try client.recv_queue.putAll(io, "PONG\r\n");
// try client.send(io, "PONG\r\n");
}, },
.PUB => { .PUB => {
@branchHint(.likely); @branchHint(.likely);
// log.debug("received a pub msg", .{}); // log.debug("received a pub msg", .{});
try server.publishMessage(io, server_allocator, &client, .@"pub"); server.publishMessage(io, server_allocator, &client, .@"pub") catch |err| switch (err) {
error.WriteFailed => return writer.err.?,
error.ReadFailed => return reader.err.?,
error.EndOfStream => return error.ClientDisconnected,
else => |e| return e,
};
}, },
.HPUB => { .HPUB => {
@branchHint(.likely); @branchHint(.likely);
try server.publishMessage(io, server_allocator, &client, .hpub); server.publishMessage(io, server_allocator, &client, .hpub) catch |err| switch (err) {
error.WriteFailed => return writer.err.?,
error.ReadFailed => return reader.err.?,
error.EndOfStream => return error.ClientDisconnected,
else => |e| return e,
};
}, },
.SUB => { .SUB => {
try server.subscribe(io, server_allocator, &client, id); server.subscribe(io, server_allocator, &client, id) catch |err| switch (err) {
error.ReadFailed => return reader.err.?,
error.EndOfStream => return error.ClientDisconnected,
else => |e| return e,
};
}, },
.UNSUB => { .UNSUB => {
try server.unsubscribe(io, server_allocator, client, id); server.unsubscribe(io, server_allocator, client, id) catch |err| switch (err) {
error.ReadFailed => return reader.err.?,
error.EndOfStream => return error.ClientDisconnected,
else => |e| return e,
};
}, },
.CONNECT => { .CONNECT => {
if (client.connect) |*current| { if (client.connect) |*current| {
current.deinit(server_allocator); current.deinit(server_allocator);
} }
client.connect = try parse.connect(server_allocator, client.from_client); client.connect = parse.connect(server_allocator, client.from_client) catch |err| switch (err) {
error.ReadFailed => return reader.err.?,
error.EndOfStream => return error.ClientDisconnected,
else => |e| return e,
};
}, },
else => |e| { else => |e| {
panic("Unimplemented message: {any}\n", .{e}); panic("Unimplemented message: {any}\n", .{e});
@@ -291,11 +312,6 @@ fn publishMessage(
.hpub => hpubmsg.@"pub", .hpub => hpubmsg.@"pub",
}; };
// const subject = switch (pub_or_hpub) {
// .PUB => |pb| pb.subject,
// .HPUB => |hp| hp.@"pub".subject,
// else => unreachable,
// };
try server.subs_lock.lock(io); try server.subs_lock.lock(io);
defer server.subs_lock.unlock(io); defer server.subs_lock.unlock(io);
var published_queue_groups: ArrayList([]const u8) = .empty; var published_queue_groups: ArrayList([]const u8) = .empty;
@@ -380,6 +396,7 @@ fn subscribe(
.queue_lock = &client.recv_queue_write_lock, .queue_lock = &client.recv_queue_write_lock,
.queue = client.recv_queue, .queue = client.recv_queue,
}); });
log.debug("Client {d} subscribed to {s}", .{ id, msg.subject });
} }
fn unsubscribe( fn unsubscribe(
@@ -397,8 +414,10 @@ fn unsubscribe(
const i = len - from_end - 1; const i = len - from_end - 1;
const sub = server.subscriptions.items[i]; const sub = server.subscriptions.items[i];
if (sub.client_id == id and eql(u8, sub.sid, msg.sid)) { if (sub.client_id == id and eql(u8, sub.sid, msg.sid)) {
log.debug("Client {d} unsubscribed from {s}", .{ id, server.subscriptions.items[i].subject });
sub.deinit(gpa); sub.deinit(gpa);
_ = server.subscriptions.swapRemove(i); _ = server.subscriptions.swapRemove(i);
break;
} }
} }
} }

View File

@@ -45,6 +45,8 @@ pub fn control(in: *Reader) !message.Control {
break :blk min_len; break :blk min_len;
}; };
std.debug.assert(in.buffer.len >= longest_ctrl); std.debug.assert(in.buffer.len >= longest_ctrl);
// Wait until at least enough text to parse the shortest control value is available
try in.fill(3);
while (true) { while (true) {
var iter = std.mem.tokenizeAny(u8, in.buffered(), " \t\r"); var iter = std.mem.tokenizeAny(u8, in.buffered(), " \t\r");
if (iter.next()) |str| { if (iter.next()) |str| {
@@ -55,6 +57,7 @@ pub fn control(in: *Reader) !message.Control {
return error.InvalidControl; return error.InvalidControl;
} }
} }
log.debug("filling more in control.", .{});
try in.fillMore(); try in.fillMore();
} }
} }
@@ -146,7 +149,7 @@ pub fn @"pub"(in: *Reader) !Message.Pub {
if (iter.next()) |bytes_str| { if (iter.next()) |bytes_str| {
const bytes = try parseUnsigned(usize, bytes_str, 10); const bytes = try parseUnsigned(usize, bytes_str, 10);
if (in.buffered()[iter.index] == '\r') { if (in.buffered().len > iter.index and in.buffered()[iter.index] == '\r') {
if (in.buffered().len < iter.index + bytes + 4) { if (in.buffered().len < iter.index + bytes + 4) {
try in.fill(iter.index + bytes + 4); try in.fill(iter.index + bytes + 4);
// Fill may shift buffer, so we have to retokenize it. // Fill may shift buffer, so we have to retokenize it.
@@ -287,7 +290,7 @@ pub fn sub(in: *Reader) !Message.Sub {
const queue_group = second; const queue_group = second;
if (iter.next()) |sid| { if (iter.next()) |sid| {
if (in.buffered()[iter.index] == '\r') { if (in.buffered().len > iter.index and in.buffered()[iter.index] == '\r') {
if (in.buffered().len < iter.index + 2) { if (in.buffered().len < iter.index + 2) {
try in.fill(iter.index + 2); try in.fill(iter.index + 2);
// Fill may shift buffer, so we have to retokenize it. // Fill may shift buffer, so we have to retokenize it.
@@ -380,9 +383,10 @@ pub fn unsub(in: *Reader) !Message.Unsub {
// See: https://docs.nats.io/reference/reference-protocols/nats-protocol#syntax-1 // See: https://docs.nats.io/reference/reference-protocols/nats-protocol#syntax-1
while (true) { while (true) {
var iter = std.mem.tokenizeAny(u8, in.buffered(), " \t\r"); var iter = std.mem.tokenizeAny(u8, in.buffered(), " \t\r\n");
if (iter.next()) |sid| { if (iter.next()) |sid| {
if (in.buffered().len > iter.index) {
if (in.buffered()[iter.index] == '\r') { if (in.buffered()[iter.index] == '\r') {
if (in.buffered().len < iter.index + 2) { if (in.buffered().len < iter.index + 2) {
try in.fill(iter.index + 2); try in.fill(iter.index + 2);
@@ -398,7 +402,7 @@ pub fn unsub(in: *Reader) !Message.Unsub {
}; };
} }
if (iter.next()) |max_msgs_str| { if (iter.next()) |max_msgs_str| {
if (in.buffered()[iter.index] == '\r') { if (in.buffered().len > iter.index and in.buffered()[iter.index] == '\r') {
const max_msgs = try parseUnsigned(usize, max_msgs_str, 10); const max_msgs = try parseUnsigned(usize, max_msgs_str, 10);
if (in.buffered().len < iter.index + 2) { if (in.buffered().len < iter.index + 2) {
@@ -417,8 +421,23 @@ pub fn unsub(in: *Reader) !Message.Unsub {
} }
} }
} }
}
try in.fillMore(); in.fillMore() catch |err| switch (err) {
error.EndOfStream => {
iter.reset();
const sid = iter.next() orelse return error.EndOfStream;
const max_msgs = if (iter.next()) |max_msgs_str| blk: {
log.debug("max_msgs: {any}", .{max_msgs_str});
break :blk try parseUnsigned(usize, max_msgs_str, 10);
} else null;
return .{
.sid = sid,
.max_msgs = max_msgs,
};
},
else => |e| return e,
};
} }
} }
@@ -478,6 +497,20 @@ test unsub {
try unsub(&in.interface), try unsub(&in.interface),
); );
} }
{
var buf: [64]u8 = undefined;
var in: std.testing.Reader = .init(&buf, &.{
.{ .buffer = " 1\r" },
.{ .buffer = "\n" },
});
try std.testing.expectEqualDeep(
Message.Unsub{
.sid = "1",
.max_msgs = null,
},
try unsub(&in.interface),
);
}
} }
/// The return value is owned by the reader passed to this function. /// The return value is owned by the reader passed to this function.
@@ -521,8 +554,7 @@ pub fn hpub(in: *Reader) !Message.HPub {
const reply_to = second; const reply_to = second;
const header_bytes_str = third; const header_bytes_str = third;
if (iter.next()) |total_bytes_str| { if (iter.next()) |total_bytes_str| {
if (in.buffered().len > iter.index) { if (in.buffered().len > iter.index and in.buffered()[iter.index] == '\r') {
if (in.buffered()[iter.index] == '\r') {
const header_bytes = try parseUnsigned(usize, header_bytes_str, 10); const header_bytes = try parseUnsigned(usize, header_bytes_str, 10);
const total_bytes = try parseUnsigned(usize, total_bytes_str, 10); const total_bytes = try parseUnsigned(usize, total_bytes_str, 10);
@@ -547,7 +579,6 @@ pub fn hpub(in: *Reader) !Message.HPub {
} }
} }
} }
}
try in.fillMore(); try in.fillMore();
} }
@@ -593,7 +624,12 @@ test hpub {
// TODO: more tests // TODO: more tests
} }
pub fn connect(alloc: Allocator, in: *Reader) !Message.Connect { pub fn connect(alloc: Allocator, in: *Reader) error{
EndOfStream,
ReadFailed,
OutOfMemory,
InvalidStream,
}!Message.Connect {
// for storing the json string // for storing the json string
var connect_string_writer_allocating: AllocatingWriter = .init(alloc); var connect_string_writer_allocating: AllocatingWriter = .init(alloc);
defer connect_string_writer_allocating.deinit(); defer connect_string_writer_allocating.deinit();
@@ -607,18 +643,21 @@ pub fn connect(alloc: Allocator, in: *Reader) !Message.Connect {
try in.discardAll(1); // throw away space try in.discardAll(1); // throw away space
// Should read the next JSON object to the fixed buffer writer. // Should read the next JSON object to the fixed buffer writer.
_ = try in.streamDelimiter(connect_string_writer, '}'); _ = in.streamDelimiter(connect_string_writer, '}') catch |err| switch (err) {
try connect_string_writer.writeByte('}'); error.WriteFailed => return error.OutOfMemory,
else => |e| return e,
};
connect_string_writer.writeByte('}') catch return error.OutOfMemory;
try expectStreamBytes(in, "}\r\n"); // discard '}\r\n' try expectStreamBytes(in, "}\r\n"); // discard '}\r\n'
const connect_str = try connect_string_writer_allocating.toOwnedSlice(); const connect_str = try connect_string_writer_allocating.toOwnedSlice();
defer alloc.free(connect_str); defer alloc.free(connect_str);
const res = try std.json.parseFromSliceLeaky( const res = std.json.parseFromSliceLeaky(
Message.Connect, Message.Connect,
connect_allocator, connect_allocator,
connect_str, connect_str,
.{ .allocate = .alloc_always }, .{ .allocate = .alloc_always },
); ) catch return error.InvalidStream;
return res.dupe(alloc); return res.dupe(alloc);
} }