From 1c58954613f8f71ecc1f3155d1dbea9c97209c95 Mon Sep 17 00:00:00 2001 From: Sisi Date: Sat, 11 Apr 2026 20:03:09 +0200 Subject: [PATCH] add fetching message history --- convertions.go | 5 ++ database.go | 37 +++++++- http.go | 83 +++++++++++++++++- machine-client/clientTest.py | 159 +++++++++++++++++++++++++++++++---- structs.go | 44 ++++++++-- wsServer.go | 32 +++++-- 6 files changed, 323 insertions(+), 37 deletions(-) diff --git a/convertions.go b/convertions.go index f2c295e..dcd6110 100644 --- a/convertions.go +++ b/convertions.go @@ -4,6 +4,7 @@ import ( "fmt" "strconv" "strings" + "time" "github.com/google/uuid" ) @@ -32,3 +33,7 @@ func ConvertStringToRgb(str string) ([3]uint8, error) { func ConvertStringUuid(str string) (uuid.UUID, error) { return uuid.Parse(str) } + +func ConvertStringTimestamp(str string) (time.Time, error) { + return time.Parse(time.RFC3339, str) +} diff --git a/database.go b/database.go index 72a44f5..ccb1154 100644 --- a/database.go +++ b/database.go @@ -57,7 +57,7 @@ func DbInit(ctx context.Context) { CREATE TABLE IF NOT EXISTS messages ( id UUID PRIMARY KEY DEFAULT gen_random_uuid(), sender_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE, - receiver_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE, + receiver_id UUID NOT NULL REFERENCES user_connections(id) ON DELETE CASCADE, created_at TIMESTAMP NOT NULL DEFAULT NOW(), content TEXT NOT NULL, is_group_message BOOLEAN DEFAULT FALSE @@ -217,12 +217,47 @@ func DbConnectionUpdateState(ctx context.Context, conn *Connection) error { } func DbMessageSave(ctx context.Context, message *Message) error { + if message.Id != (uuid.UUID{}) { + _, err := dbConn.Exec(ctx, ` + INSERT INTO messages (id, sender_id, receiver_id, created_at, content, is_group_message) VALUES ($1, $2, $3, $4, $5, $6) + `, message.Id, message.Sender, message.Receiver, message.CreatedAt, message.Content, message.IsGroupMessage) + return err + } return dbConn.QueryRow(ctx, ` INSERT INTO messages (sender_id, receiver_id, created_at, content, is_group_message) VALUES ($1, $2, $3, $4, $5) RETURNING id `, message.Sender, message.Receiver, message.CreatedAt, message.Content, message.IsGroupMessage).Scan(&message.Id) } +func DbConnectionGetMessagesBefore(ctx context.Context, before time.Time, connection uuid.UUID, cap uint32) ([]*Message, error) { + rows, err := dbConn.Query(ctx, ` + SELECT id, sender_id, receiver_id, created_at, content, is_group_message + FROM ( + SELECT id, sender_id, receiver_id, created_at, content, is_group_message + FROM messages + WHERE receiver_id = $1 + AND created_at < $2 + ORDER BY created_at DESC + LIMIT $3 + ) sub + ORDER BY created_at ASC + `, connection, before, cap) + if err != nil { + return nil, err + } + defer rows.Close() + + messages := make([]*Message, 0, cap) + for rows.Next() { + msg := &Message{} + if err = rows.Scan(&msg.Id, &msg.Sender, &msg.Receiver, &msg.CreatedAt, &msg.Content, &msg.IsGroupMessage); err != nil { + return nil, err + } + messages = append(messages, msg) + } + return messages, rows.Err() +} + func DbGroupSave(ctx context.Context, group *Group) error { err := dbConn.QueryRow(ctx, ` INSERT INTO chat_groups (name, creator_id, owner_id, enable_client_colors, color_red, color_green, color_blue, created_at) diff --git a/http.go b/http.go index 24fb619..c253933 100644 --- a/http.go +++ b/http.go @@ -14,6 +14,7 @@ import ( "go-socket/Enums/ConnectionState" + "github.com/google/uuid" "golang.org/x/crypto/bcrypt" ) @@ -297,23 +298,101 @@ func HttpHandleUserMessage(response http.ResponseWriter, request *http.Request) return } message := &Message{ + Id: uuid.New(), Content: msgContent, CreatedAt: time.Now(), Sender: user.Id, - Receiver: target.Id, + Receiver: conn.Id, IsGroupMessage: false, } + WsMessageSendToUser(target, message) + err = DbMessageSave(ctx, message) if err != nil { http.Error(response, "internal server error", http.StatusInternalServerError) return } - WsSendToUser(target, message) response.WriteHeader(http.StatusAccepted) } +func HttpHandleUserGetMessages(response http.ResponseWriter, request *http.Request) { + if !isMethodAllowed(&response, request) { + return + } + ctx := request.Context() + user, err := getUserByToken(ctx, request.FormValue("token")) + if err != nil { + http.Error(response, "invalid token", http.StatusUnauthorized) + return + } + + connectionId, err := ConvertStringUuid(request.FormValue("connectionid")) + if err != nil { + http.Error(response, "invalid connectionid", http.StatusBadRequest) + return + } + + before, err := ConvertStringTimestamp(request.FormValue("before")) + if err != nil { + before = time.Now() + } + + messagesCap, err := ConvertStringUint32(request.FormValue("messages")) + if err != nil { + messagesCap = MaxDirectMsgCache + } + + conn, ok := CacheGetConnection(user, connectionId) + if !ok { + http.Error(response, "invalid connectionid", http.StatusBadRequest) + return + } + + buffer, bufferSize := conn.GetSortedMessagesBuff() + + var validBufCount uint32 + for validBufCount < bufferSize && buffer[validBufCount].CreatedAt.Before(before) { + validBufCount++ + } + + var messages []*Message + + if validBufCount >= messagesCap { + start := validBufCount - messagesCap + messages = make([]*Message, messagesCap) + for i := uint32(0); i < messagesCap; i++ { + messages[i] = buffer[start+i] + } + } else { + remaining := messagesCap - validBufCount + cutoff := before + if validBufCount > 0 { + cutoff = buffer[0].CreatedAt + } + dbMessages, err := DbConnectionGetMessagesBefore(ctx, cutoff, connectionId, remaining) + if err != nil { + http.Error(response, "internal server error", http.StatusInternalServerError) + return + } + messages = make([]*Message, 0, uint32(len(dbMessages))+validBufCount) + messages = append(messages, dbMessages...) + for i := uint32(0); i < validBufCount; i++ { + messages = append(messages, buffer[i]) + } + } + + json, err := json2.Marshal(messages) + if err != nil { + http.Error(response, "internal server error", http.StatusInternalServerError) + return + } + + response.WriteHeader(http.StatusOK) + response.Write(json) +} + func HttpHandleUserNewConnection(response http.ResponseWriter, request *http.Request) { if !isMethodAllowed(&response, request) { return diff --git a/machine-client/clientTest.py b/machine-client/clientTest.py index e62b6e7..023cc89 100644 --- a/machine-client/clientTest.py +++ b/machine-client/clientTest.py @@ -1,13 +1,20 @@ import asyncio +import json import threading +import requests import websockets +from datetime import datetime from prompt_toolkit import PromptSession from prompt_toolkit.key_binding import KeyBindings from prompt_toolkit.keys import Keys -URI = "ws://localhost:8080/ws" +BASE_URL = "http://localhost:8080" +WS_URI = "ws://localhost:8080/ws" RETRY_DELAY = 2 +token: str | None = None +user_id: int | None = None + bindings = KeyBindings() @bindings.add(Keys.Enter) @@ -22,9 +29,113 @@ send_queue: asyncio.Queue = None loop: asyncio.AbstractEventLoop = None shutdown_event: asyncio.Event = None +# ── helpers ────────────────────────────────────────────────────────────────── + +def post(path, **data): + return requests.post(f"{BASE_URL}{path}", data=data) + +def fmt_msg(m: dict) -> str: + sender = m.get("sender", "?") + content = m.get("content", "") + ts = m.get("createdAt", "") + try: + dt = datetime.fromisoformat(ts.replace("Z", "+00:00")) + ts = dt.strftime("%H:%M:%S") + except Exception: + pass + return f"[{ts}] <{sender}> {content}" + +# ── commands ───────────────────────────────────────────────────────────────── + +def cmd_login(args): + global token, user_id + if len(args) < 2: + print("usage: /login ") + return + r = post("/token", username=args[0], password=args[1]) + if r.ok: + data = r.json() + token = data["token"] + user_id = data["userId"] + print(f"logged in user_id={user_id}") + asyncio.run_coroutine_threadsafe( + send_queue.put(json.dumps({"token": token})), loop + ) + else: + print(f"login failed: {r.text}") + +def cmd_send(args): + if not token: + print("not logged in") + return + if len(args) < 2: + print("usage: /send ") + return + conn_id = args[0] + content = " ".join(args[1:]) + r = post("/message", token=token, connectionid=conn_id, msgContent=content) + print("sent" if r.ok else f"error: {r.text}") + +def cmd_history(args): + if not token: + print("not logged in") + return + if not args: + print("usage: /history [count]") + return + data = {"token": token, "connectionid": args[0]} + if len(args) > 1: + data["messages"] = args[1] + if len(args) > 2: + data["before"] = args[2] + r = requests.post(f"{BASE_URL}/messages", data=data) + if r.ok: + msgs = r.json() or [] + if not msgs: + print("no messages") + for m in msgs: + print(fmt_msg(m)) + else: + print(f"error: {r.text}") + +def cmd_connections(args): + if not token: + print("not logged in") + return + r = post("/connections", token=token) + if r.ok: + for c in (r.json() or {}).values(): + print(f" {c['id']} requestor={c['requestorId']} recipient={c['recipientId']} state={c['state']}") + else: + print(f"error: {r.text}") + +COMMANDS = { + "/login": cmd_login, + "/send": cmd_send, + "/history": cmd_history, + "/connections": cmd_connections, +} + +HELP = """ + /login – authenticate + /connections – list your connections + /send – send a DM + /history [count] [before] – fetch message history +""" + +# ── websocket ───────────────────────────────────────────────────────────────── + async def receiver(ws): - async for message in ws: - print(f"\n[SERVER] {message}", flush=True) + async for raw in ws: + try: + data = json.loads(raw) + # pushed DM + if "content" in data and "sender" in data: + print(f"\n{fmt_msg(data)}", flush=True) + else: + print(f"\n[SERVER] {json.dumps(data, indent=2)}", flush=True) + except json.JSONDecodeError: + print(f"\n[SERVER] {raw}", flush=True) async def sender(ws): while True: @@ -42,56 +153,68 @@ async def run(): while not shutdown_event.is_set(): try: - async with websockets.connect(URI) as ws: - print(f"Connected to {URI}") + async with websockets.connect(WS_URI) as ws: + print(f"connected to {WS_URI}") if not input_thread_started: - print("Alt+Enter = newline | Enter = send | Ctrl+C = quit\n") + print("Alt+Enter = newline | Enter = send | /help | Ctrl+C = quit\n") input_thread.start() input_thread_started = True + # re-auth after reconnect + if token: + await ws.send(json.dumps({"token": token})) - recv_task = asyncio.create_task(receiver(ws)) - send_task = asyncio.create_task(sender(ws)) + recv_task = asyncio.create_task(receiver(ws)) + send_task = asyncio.create_task(sender(ws)) shutdown_task = asyncio.create_task(shutdown_event.wait()) done, pending = await asyncio.wait( [recv_task, send_task, shutdown_task], return_when=asyncio.FIRST_COMPLETED, ) - for t in pending: t.cancel() if shutdown_event.is_set(): return - # reconnect send_queue = asyncio.Queue() - print(f"\n[Disconnected] Reconnecting in {RETRY_DELAY}s...") + print(f"\n[disconnected] reconnecting in {RETRY_DELAY}s…") except (OSError, websockets.exceptions.WebSocketException): if shutdown_event.is_set(): return - print(f"[Waiting for server] Retrying in {RETRY_DELAY}s...", flush=True) + print(f"[waiting for server] retrying in {RETRY_DELAY}s…", flush=True) try: await asyncio.wait_for(shutdown_event.wait(), timeout=RETRY_DELAY) except asyncio.TimeoutError: pass +# ── input loop ──────────────────────────────────────────────────────────────── + def input_loop(): - session = PromptSession( - key_bindings=bindings, - multiline=True, - ) + session = PromptSession(key_bindings=bindings, multiline=True) while True: try: - text = session.prompt(">>> ") - if text is not None: + text = session.prompt(">>> ").strip() + if not text: + continue + if text == "/help": + print(HELP) + continue + parts = text.split() + cmd = parts[0] + if cmd in COMMANDS: + COMMANDS[cmd](parts[1:]) + else: + # raw JSON passthrough asyncio.run_coroutine_threadsafe(send_queue.put(text), loop) except (EOFError, KeyboardInterrupt): asyncio.run_coroutine_threadsafe(shutdown_event.set(), loop) break +# ── main ────────────────────────────────────────────────────────────────────── + def main(): global loop, shutdown_event loop = asyncio.new_event_loop() diff --git a/structs.go b/structs.go index 9371d04..ea80f56 100644 --- a/structs.go +++ b/structs.go @@ -24,12 +24,42 @@ type User struct { } type Connection struct { - Id uuid.UUID `json:"id"` - CreatedAt time.Time `json:"createdAt"` - MessagesBuf [MaxDirectMsgCache]*Message `json:"-"` - RequestorId uint32 `json:"requestorId"` - RecipientId uint32 `json:"recipientId"` - State ConnectionState.ConnectionState `json:"state"` + Mu sync.RWMutex `json:"-"` + Id uuid.UUID `json:"id"` + CreatedAt time.Time `json:"createdAt"` + MessagesBuff [MaxDirectMsgCache]*Message `json:"-"` + NextBuffIdx uint32 `json:"-"` + HaveOverflowed bool `json:"-"` + RequestorId uint32 `json:"requestorId"` + RecipientId uint32 `json:"recipientId"` + State ConnectionState.ConnectionState `json:"state"` +} + +func (conn *Connection) AddMessageToBuff(message *Message) { + conn.Mu.Lock() + defer conn.Mu.Unlock() + + conn.MessagesBuff[conn.NextBuffIdx%MaxDirectMsgCache] = message + conn.NextBuffIdx++ + if conn.NextBuffIdx >= MaxDirectMsgCache { + conn.HaveOverflowed = true + } +} + +// GetSortedMessagesBuff returns arr, length +func (conn *Connection) GetSortedMessagesBuff() (*[MaxDirectMsgCache]*Message, uint32) { + conn.Mu.RLock() + defer conn.Mu.RUnlock() + + if !conn.HaveOverflowed { + return &conn.MessagesBuff, conn.NextBuffIdx + } + + sorted := new([MaxDirectMsgCache]*Message) + for i := uint32(0); i < MaxDirectMsgCache; i++ { + sorted[i] = conn.MessagesBuff[(conn.NextBuffIdx+i)%MaxDirectMsgCache] + } + return sorted, MaxDirectMsgCache } type Message struct { @@ -37,7 +67,7 @@ type Message struct { Content string `json:"content"` CreatedAt time.Time `json:"createdAt"` Sender uint32 `json:"sender"` - Receiver uint32 `json:"receiver"` + Receiver uuid.UUID `json:"receiver"` IsGroupMessage bool `json:"isGroupMessage"` } diff --git a/wsServer.go b/wsServer.go index 8557ef0..31294a7 100644 --- a/wsServer.go +++ b/wsServer.go @@ -2,6 +2,7 @@ package main import ( "context" + "encoding/json" "errors" "log" "net/http" @@ -54,6 +55,26 @@ func ServeWsConnection(responseWriter http.ResponseWriter, request *http.Request } } +func sendMessageStructCloseIfTimeout(user *User, message *Message) { + if user.WsConn == nil { + return + } + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + jsonMsg, err := json.Marshal(message) + if err != nil { + log.Printf("json marshal error: %v", err) + return + } + err = wsjson.Write(ctx, user.WsConn, jsonMsg) + if err != nil { + log.Printf("json write error: %v", err) + return + } +} + func sendMessageCloseIfTimeout(user *User, message *map[string]any) { if user.WsConn == nil { return @@ -84,15 +105,8 @@ func sendToAllMessageCloseIfTimeout(message *map[string]any) { } } -func WsSendToUser(to *User, message *Message) { - var msg = map[string]any{ - "type": WsMessageFrom.DirectMessage, - "id": message.Id, - "from": message.Sender, - "created": message.CreatedAt, - "content": message.Content, - } - sendMessageCloseIfTimeout(to, &msg) +func WsMessageSendToUser(to *User, message *Message) { + sendMessageStructCloseIfTimeout(to, message) } func WsSendToGroup(group *Group, sender *User, message string) error {