This commit is contained in:
2025-12-02 22:37:50 -05:00
parent 4afdf32beb
commit 826da348a5
4 changed files with 244 additions and 87 deletions

View File

@@ -24,7 +24,7 @@ pub const ClientState = struct {
connect: Message.Connect, connect: Message.Connect,
in: *std.Io.Reader, in: *std.Io.Reader,
out: *std.Io.Writer, out: *std.Io.Writer,
) ClientState { ) !ClientState {
var res: ClientState = .{ var res: ClientState = .{
.id = id, .id = id,
.connect = connect, .connect = connect,
@@ -37,12 +37,10 @@ pub const ClientState = struct {
}; };
res.send_queue = .init(&res.send_queue_buffer); res.send_queue = .init(&res.send_queue_buffer);
res.recv_queue = .init(&res.recv_queue_buffer); res.recv_queue = .init(&res.recv_queue_buffer);
const write_task = io.async(processWrite, .{ &res, io, out }); // res.send_queue = .init(&.{});
// @compileLog(@TypeOf(write_task)); // res.recv_queue = .init(&.{});
const read_task = io.async(processRead, .{ &res, io, allocator, in }); res.write_task = try io.concurrent(processWrite, .{ &res, io, out });
// @compileLog(@TypeOf(read_task)); res.read_task = try io.concurrent(processRead, .{ &res, io, allocator, in });
res.write_task = write_task;
res.read_task = read_task;
return res; return res;
} }
@@ -53,13 +51,23 @@ pub const ClientState = struct {
out: *std.Io.Writer, out: *std.Io.Writer,
) void { ) void {
while (true) { while (true) {
const message = self.recv_queue.getOne(io) catch break; const message = self.recv_queue.getOne(io) catch continue;
switch (message) { switch (message) {
.@"+ok" => writeOk(out) catch break, .@"+ok" => {
.pong => writePong(out) catch break, writeOk(out) catch break;
.info => |info| writeInfo(out, info) catch break, },
.msg => |m| writeMsg(out, m) catch break, .pong => {
else => std.debug.panic("unimplemented write", .{}), writePong(out) catch break;
},
.info => |info| {
writeInfo(out, info) catch break;
},
.msg => |m| {
writeMsg(out, m) catch break;
},
else => {
std.debug.panic("unimplemented write", .{});
},
} }
} }
} }
@@ -70,19 +78,27 @@ pub const ClientState = struct {
allocator: std.mem.Allocator, allocator: std.mem.Allocator,
in: *std.Io.Reader, in: *std.Io.Reader,
) void { ) void {
io.sleep(.fromMilliseconds(100), .real) catch @panic("couldn't sleep");
while (true) { while (true) {
std.debug.print("waiting for message\n", .{});
const next_message = Message.next(allocator, in) catch |err| switch (err) { const next_message = Message.next(allocator, in) catch |err| switch (err) {
error.EndOfStream => { error.EndOfStream => break,
break;
},
else => { else => {
std.debug.panic("guh: {any}\n", .{err}); std.debug.panic("guh: {any}\n", .{err});
break; break;
// return err; // return err;
}, },
}; };
self.send_queue.putOne(io, next_message) catch break; std.debug.print("got message {}\n", .{next_message});
// std.debug.print("queue: {any}\n", .{self.send_queue});
self.send_queue.putOneUncancelable(io, next_message); //catch {
// std.debug.print("in catch\n\n\n", .{});
// std.debug.print("queue: {any}\n", .{self.send_queue});
// };
} }
std.debug.print("no more messages\n", .{});
} }
pub fn deinit(self: *ClientState, alloc: std.mem.Allocator) void { pub fn deinit(self: *ClientState, alloc: std.mem.Allocator) void {
@@ -102,7 +118,10 @@ pub const ClientState = struct {
return (try self.recv_queue.put(io, &.{msg}, 0)) > 0; return (try self.recv_queue.put(io, &.{msg}, 0)) > 0;
} }
pub fn next(self: *ClientState, io: std.Io) std.Io.Cancelable!Message { pub fn next(self: *ClientState, io: std.Io) !Message {
std.debug.print("in client awaiting next message\n", .{});
errdefer std.debug.print("actually it was canceled\n", .{});
defer std.debug.print("client returning next message!\n", .{});
return self.send_queue.getOne(io); return self.send_queue.getOne(io);
} }
}; };
@@ -113,11 +132,13 @@ fn writeOk(out: *std.Io.Writer) !void {
} }
fn writePong(out: *std.Io.Writer) !void { fn writePong(out: *std.Io.Writer) !void {
std.debug.print("writing pong\n", .{});
_ = try out.write("PONG\r\n"); _ = try out.write("PONG\r\n");
try out.flush(); try out.flush();
} }
pub fn writeInfo(out: *std.Io.Writer, info: Message.ServerInfo) !void { pub fn writeInfo(out: *std.Io.Writer, info: Message.ServerInfo) !void {
std.debug.print("writing info: {any}\n", .{info});
_ = try out.write("INFO "); _ = try out.write("INFO ");
try std.json.Stringify.value(info, .{}, out); try std.json.Stringify.value(info, .{}, out);
_ = try out.write("\r\n"); _ = try out.write("\r\n");
@@ -138,3 +159,61 @@ fn writeMsg(out: *std.Io.Writer, msg: Message.Msg) !void {
); );
try out.flush(); try out.flush();
} }
test {
const io = std.testing.io;
const gpa = std.testing.allocator;
var from_client: std.Io.Reader = .fixed(
"CONNECT {\"verbose\":false,\"pedantic\":false,\"tls_required\":false,\"name\":\"NATS CLI Version v0.2.4\",\"lang\":\"go\",\"version\":\"1.43.0\",\"protocol\":1,\"echo\":true,\"headers\":true,\"no_responders\":true}\r\n" ++
"PING\r\n",
);
var from_client_buf: [1024]Message = undefined;
var from_client_queue: std.Io.Queue(Message) = .init(&from_client_buf);
while (Message.next(gpa, &from_client)) |msg| {
try from_client_queue.putOne(io, msg);
} else |_| {}
for (0..2) |_| {
var msg = try from_client_queue.getOne(io);
std.debug.print("Message: {any}\n", .{msg});
switch (msg) {
.connect => |*c| {
c.deinit();
},
else => {},
}
}
// const connect = (Message.next(gpa, &from_client) catch unreachable).connect;
// var to_client_alloc: std.Io.Writer.Allocating = .init(gpa);
// defer to_client_alloc.deinit();
// var to_client = to_client_alloc.writer;
// var client: ClientState = try .init(io, gpa, 0, connect, &from_client, &to_client);
// defer client.deinit(gpa);
// {
// var get_next = io.concurrent(ClientState.next, .{ &client, io }) catch unreachable;
// defer if (get_next.cancel(io)) |_| {} else |_| @panic("fail");
// var timeout = io.concurrent(std.Io.sleep, .{ io, .fromMilliseconds(1000), .awake }) catch unreachable;
// defer timeout.cancel(io) catch {};
// switch (try io.select(.{
// .get_next = &get_next,
// .timeout = &timeout,
// })) {
// .get_next => |next| {
// std.debug.print("next is {any}\n", .{next});
// try std.testing.expect((next catch |err| return err) == .ping);
// },
// .timeout => {
// std.debug.print("reached timeout\n", .{});
// return error.TestUnexpectedResult;
// },
// }
// }
}

View File

@@ -29,7 +29,10 @@ pub fn main(gpa: std.mem.Allocator, server_config: ServerInfo) !void {
while (true) : (id +%= 1) { while (true) : (id +%= 1) {
if (server.clients.contains(id)) continue; if (server.clients.contains(id)) continue;
const stream = try tcp_server.accept(io); const stream = try tcp_server.accept(io);
_ = io.async(handleConnection, .{ &server, gpa, io, id, stream }); _ = io.concurrent(handleConnection, .{ &server, gpa, io, id, stream }) catch {
std.debug.print("could not start concurrent handler for {d}\n", .{id});
stream.close(io);
};
} }
} }
@@ -66,7 +69,7 @@ fn handleConnection(
var connect_arena: std.heap.ArenaAllocator = .init(allocator); var connect_arena: std.heap.ArenaAllocator = .init(allocator);
defer connect_arena.deinit(); defer connect_arena.deinit();
const connect = (Message.next(connect_arena.allocator(), in) catch return).connect; const connect = (Message.next(connect_arena.allocator(), in) catch return).connect;
var client_state: ClientState = .init(io, allocator, id, connect, in, out); var client_state: ClientState = try .init(io, allocator, id, connect, in, out);
try server.addClient(allocator, id, client_state); try server.addClient(allocator, id, client_state);
defer server.removeClient(allocator, id); defer server.removeClient(allocator, id);
@@ -133,16 +136,24 @@ fn subscribe(server: *Server, gpa: std.mem.Allocator, id: usize, msg: Message.Su
try server.subscriptions.put(gpa, msg.subject, subs_for_subject); try server.subscriptions.put(gpa, msg.subject, subs_for_subject);
} }
fn processClient(server: *Server, gpa: std.mem.Allocator, io: std.Io, client_state: *ClientState) !void { pub fn processClient(server: *Server, gpa: std.mem.Allocator, io: std.Io, client_state: *ClientState) !void {
defer std.debug.print("done processing client??\n", .{});
defer client_state.deinit(gpa); defer client_state.deinit(gpa);
std.debug.print("processing client: {d}\n", .{client_state.id});
while (true) { while (true) {
switch (try client_state.next(io)) { std.debug.print("awaiting next message from client\n", .{});
switch (client_state.next(io)) {
.ping => { .ping => {
std.debug.print("got a ping! sending a pong.\n", .{});
for (0..5) |_| { for (0..5) |_| {
if (try client_state.send(io, .pong)) break; if (try client_state.send(io, .pong)) {
std.debug.print("sent pong\n", .{});
break;
}
std.debug.print("trying to send a pong again.\n", .{});
} else { } else {
std.debug.print("could not pong to client {}\n", .{client_state.id}); std.debug.print("could not pong to client {d}\n", .{client_state.id});
} }
}, },
.@"pub" => |msg| { .@"pub" => |msg| {
@@ -158,6 +169,8 @@ fn processClient(server: *Server, gpa: std.mem.Allocator, io: std.Io, client_sta
std.debug.panic("Unimplemented message: {any}\n", .{msg}); std.debug.panic("Unimplemented message: {any}\n", .{msg});
}, },
} }
std.debug.print("processed message from client\n", .{});
} }
// while (!io.cancelRequested()) { // while (!io.cancelRequested()) {
@@ -247,3 +260,42 @@ pub fn createId() []const u8 {
pub fn createName() []const u8 { pub fn createName() []const u8 {
return "SERVERNAME"; return "SERVERNAME";
} }
// TESTING
// fn initTestServer() Server {
// return .{
// .info = .{
// .server_id = "ABCD",
// .server_name = "test server",
// .version = "0.1.2",
// .max_payload = 1234,
// },
// };
// }
// fn initTestClient(
// io: std.Io,
// allocator: std.mem.Allocator,
// id: usize,
// data_from: []const u8,
// ) !struct {
// Client,
// *std.Io.Reader,
// *std.Io.Writer,
// } {
// return .init(io, allocator, id, .{}, in, out);
// }
// test {
// const gpa = std.testing.allocator;
// const io = std.testing.io;
// const server = initTestServer();
// const client: Client = .init(
// io,
// gpa,
// 1,
// .{},
// );
// }

View File

@@ -31,9 +31,9 @@ pub const MessageType = enum {
} }
}; };
pub const Message = union(enum) { pub const Message = union(MessageType) {
info: ServerInfo, info: ServerInfo,
connect: Connect, connect: AllocatedConnect,
@"pub": Pub, @"pub": Pub,
hpub: void, hpub: void,
sub: Sub, sub: Sub,
@@ -71,6 +71,14 @@ pub const Message = union(enum) {
/// feature. /// feature.
proto: u32 = 1, proto: u32 = 1,
}; };
pub const AllocatedConnect = struct {
allocator: std.heap.ArenaAllocator,
connect: Connect,
pub fn deinit(self: *AllocatedConnect) void {
self.allocator.deinit();
}
};
pub const Connect = struct { pub const Connect = struct {
verbose: bool = false, verbose: bool = false,
pedantic: bool = false, pedantic: bool = false,
@@ -136,8 +144,20 @@ pub const Message = union(enum) {
/// An error should be handled by cleaning up this connection. /// An error should be handled by cleaning up this connection.
pub fn next(alloc: std.mem.Allocator, in: *std.Io.Reader) !Message { pub fn next(alloc: std.mem.Allocator, in: *std.Io.Reader) !Message {
// errdefer |err| {
// std.debug.print("Error occurred: {}\n", .{err});
// // Get the error return trace
// if (@errorReturnTrace()) |trace| {
// std.debug.print("Error return trace:\n", .{});
// std.debug.dumpStackTrace(trace);
// } else {
// std.debug.print("No error return trace available\n", .{});
// }
// }
var operation_string: std.ArrayList(u8) = blk: { var operation_string: std.ArrayList(u8) = blk: {
var buf: ["CONTINUE".len]u8 = undefined; var buf: ["CONTINUE".len + 1]u8 = undefined;
break :blk .initBuffer(&buf); break :blk .initBuffer(&buf);
}; };
@@ -149,15 +169,15 @@ pub const Message = union(enum) {
} else |err| return err; } else |err| return err;
const operation = parse(operation_string.items) orelse { const operation = parse(operation_string.items) orelse {
std.debug.print("operation: '{s}'\n", .{operation_string.items});
std.debug.print("buffered: '{s}'", .{in.buffered()});
return error.InvalidOperation; return error.InvalidOperation;
}; };
switch (operation) { switch (operation) {
.connect => { .connect => {
// TODO: should be ARENA allocator // TODO: should be ARENA allocator
var connect_string_writer_allocating: std.Io.Writer.Allocating = try .initCapacity(alloc, 1024); var connect_arena_allocator: std.heap.ArenaAllocator = .init(alloc);
const connect_allocator = connect_arena_allocator.allocator();
const connect_string_writer_allocating: std.Io.Writer.Allocating = try .initCapacity(connect_allocator, 1024);
var connect_string_writer = connect_string_writer_allocating.writer; var connect_string_writer = connect_string_writer_allocating.writer;
try in.discardAll(1); // throw away space try in.discardAll(1); // throw away space
@@ -167,9 +187,9 @@ pub const Message = union(enum) {
std.debug.assert(std.mem.eql(u8, try in.take(3), "}\r\n")); // discard '}\r\n' std.debug.assert(std.mem.eql(u8, try in.take(3), "}\r\n")); // discard '}\r\n'
// TODO: should be CONNECTION allocator // TODO: should be CONNECTION allocator
const res = try std.json.parseFromSliceLeaky(Connect, alloc, connect_string_writer.buffered(), .{ .allocate = .alloc_always }); const res = try std.json.parseFromSliceLeaky(Connect, connect_allocator, connect_string_writer.buffered(), .{ .allocate = .alloc_always });
return .{ .connect = res }; return .{ .connect = .{ .allocator = connect_arena_allocator, .connect = res } };
}, },
.@"pub" => { .@"pub" => {
try in.discardAll(1); // throw away space try in.discardAll(1); // throw away space
@@ -311,67 +331,70 @@ fn parsePub(in: *std.Io.Reader) !Message.Pub {
} }
// try returning error in debug mode, only null in release? // try returning error in debug mode, only null in release?
pub fn parseNextMessage(alloc: std.mem.Allocator, in: *std.Io.Reader) ?Message { // pub fn parseNextMessage(alloc: std.mem.Allocator, in: *std.Io.Reader) ?Message {
const message_type: MessageType = blk: { // const message_type: MessageType = blk: {
var word: ["CONNECT".len]u8 = undefined; // var word: ["CONNECT".len]u8 = undefined;
var len: usize = 0; // var len: usize = 0;
for (&word, 0..) |*b, i| { // for (&word, 0..) |*b, i| {
const byte = in.takeByte() catch return null; // const byte = in.takeByte() catch return null;
if (std.ascii.isUpper(byte)) { // if (std.ascii.isUpper(byte)) {
b.* = byte; // b.* = byte;
len = i + 1; // len = i + 1;
} else break; // } else break;
} // }
break :blk Message.parse(word[0..len]) orelse return null; // break :blk Message.parse(word[0..len]) orelse return null;
}; // };
// defer in.toss(2); // CRLF // // defer in.toss(2); // CRLF
return switch (message_type) { // return switch (message_type) {
.connect => blk: { // .connect => blk: {
const value: ?Message = .{ .connect = parseJsonMessage(Message.Connect, alloc, in) catch return null }; // const value: ?Message = .{ .connect = parseJsonMessage(Message.Connect, alloc, in) catch return null };
break :blk value; // break :blk value;
}, // },
.@"pub" => .{ .@"pub" = parsePub(in) catch |err| std.debug.panic("{}", .{err}) }, // .@"pub" => .{ .@"pub" = parsePub(in) catch |err| std.debug.panic("{}", .{err}) },
.ping => .ping, // .ping => .ping,
else => null, // else => null,
}; // };
} // }
test parseNextMessage { // test parseNextMessage {
const input = "CONNECT {\"verbose\":false,\"pedantic\":false,\"tls_required\":false,\"name\":\"NATS CLI Version v0.2.4\",\"lang\":\"go\",\"version\":\"1.43.0\",\"protocol\":1,\"echo\":true,\"headers\":true,\"no_responders\":true}\r\nPUB hi 3\r\nfoo\r\n"; // const input = "CONNECT {\"verbose\":false,\"pedantic\":false,\"tls_required\":false,\"name\":\"NATS CLI Version v0.2.4\",\"lang\":\"go\",\"version\":\"1.43.0\",\"protocol\":1,\"echo\":true,\"headers\":true,\"no_responders\":true}\r\nPUB hi 3\r\nfoo\r\n";
var reader: std.Io.Reader = .fixed(input); // var reader: std.Io.Reader = .fixed(input);
var arena: std.heap.ArenaAllocator = .init(std.testing.allocator); // var arena: std.heap.ArenaAllocator = .init(std.testing.allocator);
defer arena.deinit(); // defer arena.deinit();
const gpa = arena.allocator(); // const gpa = arena.allocator();
{ // {
const msg: Message = try Message.next(gpa, &reader); // const msg: Message = try Message.next(gpa, &reader);
const expected: Message = .{ .connect = .{ // const expected: Message = .{ .connect = .{
.verbose = false, // .connect = .{
.pedantic = false, // .verbose = false,
.tls_required = false, // .pedantic = false,
.name = "NATS CLI Version v0.2.4", // .tls_required = false,
.lang = "go", // .name = try gpa.dupe(u8, "NATS CLI Version v0.2.4"),
.version = "1.43.0", // .lang = try gpa.dupe(u8, "go"),
.protocol = 1, // .version = try gpa.dupe(u8, "1.43.0"),
.echo = true, // .protocol = 1,
.headers = true, // .echo = true,
.no_responders = true, // .headers = true,
} }; // .no_responders = true,
// },
// .allocator = arena,
// } };
try std.testing.expectEqualDeep(expected, msg); // try std.testing.expectEqualDeep(expected, msg);
} // }
{ // {
const msg: Message = try Message.next(gpa, &reader); // const msg: Message = try Message.next(gpa, &reader);
const expected: Message = .{ .@"pub" = .{ // const expected: Message = .{ .@"pub" = .{
.subject = "hi", // .subject = "hi",
.payload = "foo", // .payload = "foo",
} }; // } };
try std.testing.expectEqualDeep(expected, msg); // try std.testing.expectEqualDeep(expected, msg);
} // }
} // }
// test "MessageType.parse performance" { // test "MessageType.parse performance" {
// // Measure perf for parseMemEql // // Measure perf for parseMemEql

3
src/server/test.zig Normal file
View File

@@ -0,0 +1,3 @@
const std = @import("std");
const Server = @import("./main.zig");
const Client = @import("./client.zig").ClientState;