diff --git a/clientTest.py b/clientTest.py new file mode 100644 index 0000000..7f114c3 --- /dev/null +++ b/clientTest.py @@ -0,0 +1,70 @@ +import asyncio +import threading +import websockets +from prompt_toolkit import PromptSession +from prompt_toolkit.key_binding import KeyBindings +from prompt_toolkit.keys import Keys + +URI = "ws://localhost:8080/ws" + +bindings = KeyBindings() + +@bindings.add(Keys.Enter) +def submit(event): + event.current_buffer.validate_and_handle() + +@bindings.add(Keys.Escape, Keys.Enter) +def newline(event): + event.current_buffer.insert_text("\n") + +send_queue: asyncio.Queue = None +loop: asyncio.AbstractEventLoop = None + +async def receiver(ws): + async for message in ws: + print(f"\n[SERVER] {message}", flush=True) + +async def sender(ws): + while True: + msg = await send_queue.get() + if msg is None: + break + await ws.send(msg) + +async def run(): + global send_queue + send_queue = asyncio.Queue() + async with websockets.connect(URI) as ws: + print(f"Connected to {URI}") + print("Alt+Enter = newline | Enter = send | Ctrl+C = quit\n") + recv_task = asyncio.create_task(receiver(ws)) + send_task = asyncio.create_task(sender(ws)) + input_thread = threading.Thread(target=input_loop, daemon=True) + input_thread.start() + await asyncio.gather(recv_task, send_task) + +def input_loop(): + session = PromptSession( + key_bindings=bindings, + multiline=True, + ) + while True: + try: + text = session.prompt(">>> ") + if text is not None: + asyncio.run_coroutine_threadsafe(send_queue.put(text), loop) + except (EOFError, KeyboardInterrupt): + asyncio.run_coroutine_threadsafe(send_queue.put(None), loop) + break + +def main(): + global loop + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + loop.run_until_complete(run()) + except KeyboardInterrupt: + pass + +if __name__ == "__main__": + main() diff --git a/go-socket b/go-socket index 853a0a3..44c8106 100755 Binary files a/go-socket and b/go-socket differ diff --git a/main.go b/main.go index ea78ed5..4bf477f 100644 --- a/main.go +++ b/main.go @@ -30,8 +30,7 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { s.OnOpen(conn) } - ctx, cancel := context.WithTimeout(r.Context(), 30*time.Second) - defer cancel() + ctx := r.Context() var readErr error for {