diff --git a/clientTest.py b/clientTest.py index 7f114c3..e62b6e7 100644 --- a/clientTest.py +++ b/clientTest.py @@ -6,6 +6,7 @@ from prompt_toolkit.key_binding import KeyBindings from prompt_toolkit.keys import Keys URI = "ws://localhost:8080/ws" +RETRY_DELAY = 2 bindings = KeyBindings() @@ -19,6 +20,7 @@ def newline(event): send_queue: asyncio.Queue = None loop: asyncio.AbstractEventLoop = None +shutdown_event: asyncio.Event = None async def receiver(ws): async for message in ws: @@ -34,14 +36,47 @@ async def sender(ws): async def run(): global send_queue send_queue = asyncio.Queue() - async with websockets.connect(URI) as ws: - print(f"Connected to {URI}") - print("Alt+Enter = newline | Enter = send | Ctrl+C = quit\n") - recv_task = asyncio.create_task(receiver(ws)) - send_task = asyncio.create_task(sender(ws)) - input_thread = threading.Thread(target=input_loop, daemon=True) - input_thread.start() - await asyncio.gather(recv_task, send_task) + + input_thread = threading.Thread(target=input_loop, daemon=True) + input_thread_started = False + + while not shutdown_event.is_set(): + try: + async with websockets.connect(URI) as ws: + print(f"Connected to {URI}") + if not input_thread_started: + print("Alt+Enter = newline | Enter = send | Ctrl+C = quit\n") + input_thread.start() + input_thread_started = True + + recv_task = asyncio.create_task(receiver(ws)) + send_task = asyncio.create_task(sender(ws)) + shutdown_task = asyncio.create_task(shutdown_event.wait()) + + done, pending = await asyncio.wait( + [recv_task, send_task, shutdown_task], + return_when=asyncio.FIRST_COMPLETED, + ) + + for t in pending: + t.cancel() + + if shutdown_event.is_set(): + return + + # reconnect + send_queue = asyncio.Queue() + print(f"\n[Disconnected] Reconnecting in {RETRY_DELAY}s...") + + except (OSError, websockets.exceptions.WebSocketException): + if shutdown_event.is_set(): + return + print(f"[Waiting for server] Retrying in {RETRY_DELAY}s...", flush=True) + + try: + await asyncio.wait_for(shutdown_event.wait(), timeout=RETRY_DELAY) + except asyncio.TimeoutError: + pass def input_loop(): session = PromptSession( @@ -54,17 +89,20 @@ def input_loop(): if text is not None: asyncio.run_coroutine_threadsafe(send_queue.put(text), loop) except (EOFError, KeyboardInterrupt): - asyncio.run_coroutine_threadsafe(send_queue.put(None), loop) + asyncio.run_coroutine_threadsafe(shutdown_event.set(), loop) break def main(): - global loop + global loop, shutdown_event loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) + shutdown_event = asyncio.Event() try: loop.run_until_complete(run()) except KeyboardInterrupt: pass + finally: + loop.close() if __name__ == "__main__": main() diff --git a/go-socket b/go-socket index 44c8106..4b14da8 100755 Binary files a/go-socket and b/go-socket differ diff --git a/main.go b/main.go index 4bf477f..394f566 100644 --- a/main.go +++ b/main.go @@ -62,7 +62,7 @@ func main() { log.Printf("received: %v\n", msg) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - if err := wsjson.Write(ctx, conn, map[string]any{"echo": msg}); err != nil { + if err := wsjson.Write(ctx, conn, msg); err != nil { log.Println("write error:", err) } }, diff --git a/tokens.go b/tokens.go new file mode 100644 index 0000000..5155820 --- /dev/null +++ b/tokens.go @@ -0,0 +1,23 @@ +package main + +import ( + _ "context" + "crypto/rand" + "crypto/sha256" + "encoding/hex" + _ "fmt" + _ "time" +) + +func hashToken(token string) string { + var hash = sha256.Sum256([]byte(token)) + return hex.EncodeToString(hash[:]) +} + +func generateToken() string { + var bytes = make([]byte, 32) + if _, err := rand.Read(bytes); err != nil { + panic(err) + } + return hex.EncodeToString(bytes) +}