244 lines
8.0 KiB
Python
244 lines
8.0 KiB
Python
import asyncio
|
||
import json
|
||
import threading
|
||
import requests
|
||
import websockets
|
||
from datetime import datetime
|
||
from prompt_toolkit import PromptSession
|
||
from prompt_toolkit.key_binding import KeyBindings
|
||
from prompt_toolkit.keys import Keys
|
||
|
||
BASE_URL = "http://localhost:8080"
|
||
WS_URI = "ws://localhost:8080/ws"
|
||
RETRY_DELAY = 2
|
||
|
||
token: str | None = None
|
||
user_id: int | None = None
|
||
|
||
# Prompt key bindings: plain Enter submits the buffer, while Alt+Enter
# (delivered by the terminal as Escape followed by Enter) inserts a newline,
# which is what makes the multiline prompt usable.
bindings = KeyBindings()


@bindings.add(Keys.Enter)
def submit(event):
    """Submit the current input via the buffer's validate-and-accept path."""
    buffer = event.current_buffer
    buffer.validate_and_handle()


@bindings.add(Keys.Escape, Keys.Enter)
def newline(event):
    """Insert a literal line break instead of submitting."""
    buffer = event.current_buffer
    buffer.insert_text("\n")
|
||
|
||
send_queue: asyncio.Queue = None
|
||
loop: asyncio.AbstractEventLoop = None
|
||
shutdown_event: asyncio.Event = None
|
||
|
||
# ── helpers ──────────────────────────────────────────────────────────────────
|
||
|
||
# Seconds before an HTTP call is abandoned. Without a timeout, `requests` can
# block forever on an unresponsive server, which would freeze the prompt
# thread that runs every /command.
REQUEST_TIMEOUT = 10


def post(path, **data):
    """POST ``data`` as form fields to ``BASE_URL + path`` and return the response.

    HTTP error statuses are returned (callers check ``.ok``), not raised.
    Transport failures raise ``requests.RequestException``, including
    ``requests.Timeout`` after REQUEST_TIMEOUT seconds.
    """
    return requests.post(f"{BASE_URL}{path}", data=data, timeout=REQUEST_TIMEOUT)
|
||
|
||
def fmt_msg(m: dict) -> str:
    """Render a message dict as ``[HH:MM:SS] <sender> content``.

    Reads the "sender", "content" and "createdAt" keys. A missing sender shows
    as "?", missing content as "", and a missing or null createdAt as "".
    createdAt is parsed as ISO 8601; a trailing "Z" is rewritten to "+00:00"
    because datetime.fromisoformat before Python 3.11 rejects "Z". The result
    is shown as time-of-day in the timestamp's own offset. A value that cannot
    be parsed is shown unchanged.
    """
    sender = m.get("sender", "?")
    content = m.get("content", "")
    ts = m.get("createdAt")
    if ts is None:  # absent key or JSON null
        ts = ""
    try:
        dt = datetime.fromisoformat(ts.replace("Z", "+00:00"))
    except (AttributeError, TypeError, ValueError):
        # Best effort: non-string or non-ISO timestamps are displayed raw.
        pass
    else:
        ts = dt.strftime("%H:%M:%S")
    return f"[{ts}] <{sender}> {content}"
|
||
|
||
# ── commands ─────────────────────────────────────────────────────────────────
|
||
|
||
def cmd_login(args):
    """Handle ``/login <username> <password>``.

    Posts the credentials to /token. On success, stores the module-level
    ``token`` and ``user_id``, then queues a ``{"token": ...}`` frame so the
    websocket sender authenticates the live connection. On failure, prints the
    server's response body.
    """
    global token, user_id
    if len(args) < 2:
        print("usage: /login <username> <password>")
        return
    username, password = args[0], args[1]
    resp = post("/token", username=username, password=password)
    if not resp.ok:
        print(f"login failed: {resp.text}")
        return
    body = resp.json()
    token = body["token"]
    user_id = body["userId"]
    print(f"logged in user_id={user_id}")
    # This runs on the prompt thread, so hand the frame to the loop thread.
    auth_frame = json.dumps({"token": token})
    asyncio.run_coroutine_threadsafe(send_queue.put(auth_frame), loop)
|
||
|
||
def cmd_send(args):
    """Handle ``/send <connectionid> <message…>`` by posting to /message.

    Everything after the connection id is joined with single spaces to form
    the message body.
    """
    if not token:
        print("not logged in")
        return
    if len(args) < 2:
        print("usage: /send <connectionid> <message…>")
        return
    conn_id, *words = args
    message = " ".join(words)
    resp = post("/message", token=token, connectionid=conn_id, msgContent=message)
    if resp.ok:
        print("sent")
    else:
        print(f"error: {resp.text}")
|
||
|
||
def cmd_history(args):
    """Handle ``/history <connectionid> [count] [before]``.

    Posts to /get/connection/messages and prints each returned message with
    fmt_msg, or "no messages" when the list is empty. The optional count and
    before values are forwarded as the "messages" and "before" form fields,
    exactly as typed.
    """
    if not token:
        print("not logged in")
        return
    if not args:
        # Keep the usage line in sync with HELP, which documents [before] too.
        print("usage: /history <connectionid> [count] [before]")
        return
    data = {"token": token, "connectionid": args[0]}
    if len(args) > 1:
        data["messages"] = args[1]
    if len(args) > 2:
        data["before"] = args[2]
    # Go through the shared helper like every other command, so all HTTP
    # calls share one code path.
    r = post("/get/connection/messages", **data)
    if not r.ok:
        print(f"error: {r.text}")
        return
    msgs = r.json() or []
    if not msgs:
        print("no messages")
    for m in msgs:
        print(fmt_msg(m))
|
||
|
||
def cmd_connections(args):
    """Handle ``/connections``: list the user's connections from /get/connections.

    The response is treated as a JSON object whose values are connection
    records with "id", "requestorId", "recipientId" and "state" keys. ``args``
    is ignored.
    """
    if not token:
        print("not logged in")
        return
    resp = post("/get/connections", token=token)
    if not resp.ok:
        print(f"error: {resp.text}")
        return
    connections = resp.json() or {}
    for conn in connections.values():
        print(
            f" {conn['id']} requestor={conn['requestorId']} "
            f"recipient={conn['recipientId']} state={conn['state']}"
        )
|
||
|
||
def cmd_delconnection(args):
    """Handle ``/delconnection <connectionid>`` by posting to /del/connection."""
    if not token:
        print("not logged in")
        return
    if not args:
        print("usage: /delconnection <connectionid>")
        return
    conn_id = args[0]
    resp = post("/del/connection", token=token, connectionid=conn_id)
    if resp.ok:
        print("deleted")
    else:
        print(f"error: {resp.text}")
|
||
|
||
# Slash-command dispatch table: command word -> handler, called with the
# list of arguments that follow the command word. "/help" is handled directly
# by the input loop, so it is not listed here.
COMMANDS = {
    "/login": cmd_login,
    "/send": cmd_send,
    "/history": cmd_history,
    "/connections": cmd_connections,
    "/delconnection": cmd_delconnection,
}
|
||
|
||
# Text printed by /help. Input that is not a known slash command is sent over
# the websocket unchanged, so that behaviour is documented here as well.
HELP = """
/login <user> <pass> – authenticate
/connections – list your connections
/send <connectionid> <message…> – send a DM
/history <connectionid> [count] [before] – fetch message history
/delconnection <connectionid> – delete a connection
/help – show this help
<anything else> – sent verbatim over the websocket (e.g. raw JSON)
"""
|
||
|
||
# ── websocket ─────────────────────────────────────────────────────────────────
|
||
|
||
async def receiver(ws):
    """Print every frame received on ``ws`` until the iteration ends.

    Frames that decode to a JSON object with both "content" and "sender" keys
    are direct messages and are printed via fmt_msg. Any other JSON value is
    pretty-printed with a [SERVER] prefix, and frames that are not JSON are
    printed raw with the same prefix.
    """
    async for raw in ws:
        try:
            data = json.loads(raw)
        except ValueError:
            # Covers json.JSONDecodeError and the UnicodeDecodeError that
            # json.loads raises for bytes frames that are not valid UTF-8.
            print(f"\n[SERVER] {raw}", flush=True)
            continue
        # json.loads can return a list, number, string, bool or None. Only a
        # dict can be a pushed DM, and an `in` test on a number would raise
        # TypeError and kill this task.
        if isinstance(data, dict) and "content" in data and "sender" in data:
            print(f"\n{fmt_msg(data)}", flush=True)
        else:
            print(f"\n[SERVER] {json.dumps(data, indent=2)}", flush=True)
|
||
|
||
async def sender(ws):
    """Forward frames from the module-level send_queue to ``ws``.

    Stops when a ``None`` sentinel is dequeued.
    """
    while (outgoing := await send_queue.get()) is not None:
        await ws.send(outgoing)
|
||
|
||
async def run():
    """Connect to WS_URI and keep the client running until shutdown_event is set.

    Starts the prompt thread after the first successful connection and re-sends
    the auth token on every (re)connect. After a failed connect or a dropped
    connection, it retries every RETRY_DELAY seconds.
    """
    global send_queue
    send_queue = asyncio.Queue()

    input_thread = threading.Thread(target=input_loop, daemon=True)
    input_thread_started = False

    while not shutdown_event.is_set():
        try:
            async with websockets.connect(WS_URI) as ws:
                print(f"connected to {WS_URI}")
                if not input_thread_started:
                    print("Alt+Enter = newline | Enter = send | /help | Ctrl+C = quit\n")
                    input_thread.start()
                    input_thread_started = True

                # re-auth after reconnect
                if token:
                    await ws.send(json.dumps({"token": token}))

                await _serve_connection(ws)

                if shutdown_event.is_set():
                    return

                # Start the next connection with an empty queue; frames queued
                # for the dead connection are dropped.
                send_queue = asyncio.Queue()
                print(f"\n[disconnected] reconnecting in {RETRY_DELAY}s…")

        # asyncio.TimeoutError (connect timeout) is only an OSError on 3.11+.
        except (OSError, asyncio.TimeoutError, websockets.exceptions.WebSocketException):
            if shutdown_event.is_set():
                return
            print(f"[waiting for server] retrying in {RETRY_DELAY}s…", flush=True)

        # Sleep before retrying, but wake immediately if shutdown is requested.
        try:
            await asyncio.wait_for(shutdown_event.wait(), timeout=RETRY_DELAY)
        except asyncio.TimeoutError:
            pass


async def _serve_connection(ws):
    """Pump frames on ``ws`` until it closes, the sender stops, or shutdown is requested."""
    tasks = [
        asyncio.create_task(receiver(ws)),
        asyncio.create_task(sender(ws)),
        asyncio.create_task(shutdown_event.wait()),
    ]
    try:
        await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
    finally:
        for task in tasks:
            task.cancel()
        # Await every task so the cancellations actually finish and any
        # exception (e.g. a ConnectionClosedError from receiver) is retrieved,
        # instead of being reported as "never retrieved" / "destroyed but
        # pending" later.
        await asyncio.gather(*tasks, return_exceptions=True)
|
||
|
||
# ── input loop ────────────────────────────────────────────────────────────────
|
||
|
||
def input_loop():
    """Read lines from the terminal and dispatch them (runs on its own thread).

    "/help" prints HELP, known slash commands run their handler, and anything
    else is queued verbatim for the websocket. EOF or Ctrl+C at the prompt asks
    the event loop to shut down and ends the thread.
    """
    session = PromptSession(key_bindings=bindings, multiline=True)
    while True:
        try:
            text = session.prompt(">>> ").strip()
        except (EOFError, KeyboardInterrupt):
            # Event.set() is not a coroutine and asyncio.Event is not
            # thread-safe, so ask the loop thread to call it.
            loop.call_soon_threadsafe(shutdown_event.set)
            break
        if not text:
            continue
        if text == "/help":
            print(HELP)
            continue
        _dispatch(text)


def _dispatch(text):
    """Run the slash command in ``text``, or forward ``text`` to the websocket."""
    cmd = text.split(maxsplit=1)[0]
    if cmd not in COMMANDS:
        # raw JSON passthrough
        asyncio.run_coroutine_threadsafe(send_queue.put(text), loop)
        return
    # /send keeps its message body as one untouched argument, so newlines
    # entered with Alt+Enter (and repeated spaces) survive. Every other
    # command takes whitespace-separated arguments.
    parts = text.split(maxsplit=2) if cmd == "/send" else text.split()
    try:
        COMMANDS[cmd](parts[1:])
    except requests.RequestException as exc:
        # An unreachable or hung server must not kill this thread; it is the
        # only way the user can type.
        print(f"request failed: {exc}")
|
||
|
||
# ── main ──────────────────────────────────────────────────────────────────────
|
||
|
||
def main():
    """Create the event loop and shared state, run the client, then clean up."""
    global loop, shutdown_event
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    shutdown_event = asyncio.Event()
    try:
        loop.run_until_complete(run())
    except KeyboardInterrupt:
        pass
    finally:
        try:
            # After Ctrl+C in the main thread, run() and its per-connection
            # tasks may still be pending. Cancel and drain them (as
            # asyncio.run does) so the websocket closes cleanly and nothing
            # is destroyed while pending.
            pending = asyncio.all_tasks(loop)
            for task in pending:
                task.cancel()
            if pending:
                loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
            loop.run_until_complete(loop.shutdown_asyncgens())
        finally:
            asyncio.set_event_loop(None)
            loop.close()
|
||
|
||
if __name__ == "__main__":
    # Start the client only when executed as a script, not when imported.
    main()
|