add fetching message history

This commit is contained in:
2026-04-11 20:03:09 +02:00
parent 1c7d0a691d
commit 1c58954613
6 changed files with 323 additions and 37 deletions
+5
View File
@@ -4,6 +4,7 @@ import (
"fmt"
"strconv"
"strings"
"time"
"github.com/google/uuid"
)
@@ -32,3 +33,7 @@ func ConvertStringToRgb(str string) ([3]uint8, error) {
func ConvertStringUuid(str string) (uuid.UUID, error) {
return uuid.Parse(str)
}
func ConvertStringTimestamp(str string) (time.Time, error) {
return time.Parse(time.RFC3339, str)
}
+36 -1
View File
@@ -57,7 +57,7 @@ func DbInit(ctx context.Context) {
CREATE TABLE IF NOT EXISTS messages (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
sender_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
receiver_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
receiver_id UUID NOT NULL REFERENCES user_connections(id) ON DELETE CASCADE,
created_at TIMESTAMP NOT NULL DEFAULT NOW(),
content TEXT NOT NULL,
is_group_message BOOLEAN DEFAULT FALSE
@@ -217,12 +217,47 @@ func DbConnectionUpdateState(ctx context.Context, conn *Connection) error {
}
func DbMessageSave(ctx context.Context, message *Message) error {
if message.Id != (uuid.UUID{}) {
_, err := dbConn.Exec(ctx, `
INSERT INTO messages (id, sender_id, receiver_id, created_at, content, is_group_message) VALUES ($1, $2, $3, $4, $5, $6)
`, message.Id, message.Sender, message.Receiver, message.CreatedAt, message.Content, message.IsGroupMessage)
return err
}
return dbConn.QueryRow(ctx, `
INSERT INTO messages (sender_id, receiver_id, created_at, content, is_group_message) VALUES ($1, $2, $3, $4, $5)
RETURNING id
`, message.Sender, message.Receiver, message.CreatedAt, message.Content, message.IsGroupMessage).Scan(&message.Id)
}
func DbConnectionGetMessagesBefore(ctx context.Context, before time.Time, connection uuid.UUID, cap uint32) ([]*Message, error) {
rows, err := dbConn.Query(ctx, `
SELECT id, sender_id, receiver_id, created_at, content, is_group_message
FROM (
SELECT id, sender_id, receiver_id, created_at, content, is_group_message
FROM messages
WHERE receiver_id = $1
AND created_at < $2
ORDER BY created_at DESC
LIMIT $3
) sub
ORDER BY created_at ASC
`, connection, before, cap)
if err != nil {
return nil, err
}
defer rows.Close()
messages := make([]*Message, 0, cap)
for rows.Next() {
msg := &Message{}
if err = rows.Scan(&msg.Id, &msg.Sender, &msg.Receiver, &msg.CreatedAt, &msg.Content, &msg.IsGroupMessage); err != nil {
return nil, err
}
messages = append(messages, msg)
}
return messages, rows.Err()
}
func DbGroupSave(ctx context.Context, group *Group) error {
err := dbConn.QueryRow(ctx, `
INSERT INTO chat_groups (name, creator_id, owner_id, enable_client_colors, color_red, color_green, color_blue, created_at)
+81 -2
View File
@@ -14,6 +14,7 @@ import (
"go-socket/Enums/ConnectionState"
"github.com/google/uuid"
"golang.org/x/crypto/bcrypt"
)
@@ -297,23 +298,101 @@ func HttpHandleUserMessage(response http.ResponseWriter, request *http.Request)
return
}
message := &Message{
Id: uuid.New(),
Content: msgContent,
CreatedAt: time.Now(),
Sender: user.Id,
Receiver: target.Id,
Receiver: conn.Id,
IsGroupMessage: false,
}
WsMessageSendToUser(target, message)
err = DbMessageSave(ctx, message)
if err != nil {
http.Error(response, "internal server error", http.StatusInternalServerError)
return
}
WsSendToUser(target, message)
response.WriteHeader(http.StatusAccepted)
}
func HttpHandleUserGetMessages(response http.ResponseWriter, request *http.Request) {
if !isMethodAllowed(&response, request) {
return
}
ctx := request.Context()
user, err := getUserByToken(ctx, request.FormValue("token"))
if err != nil {
http.Error(response, "invalid token", http.StatusUnauthorized)
return
}
connectionId, err := ConvertStringUuid(request.FormValue("connectionid"))
if err != nil {
http.Error(response, "invalid connectionid", http.StatusBadRequest)
return
}
before, err := ConvertStringTimestamp(request.FormValue("before"))
if err != nil {
before = time.Now()
}
messagesCap, err := ConvertStringUint32(request.FormValue("messages"))
if err != nil {
messagesCap = MaxDirectMsgCache
}
conn, ok := CacheGetConnection(user, connectionId)
if !ok {
http.Error(response, "invalid connectionid", http.StatusBadRequest)
return
}
buffer, bufferSize := conn.GetSortedMessagesBuff()
var validBufCount uint32
for validBufCount < bufferSize && buffer[validBufCount].CreatedAt.Before(before) {
validBufCount++
}
var messages []*Message
if validBufCount >= messagesCap {
start := validBufCount - messagesCap
messages = make([]*Message, messagesCap)
for i := uint32(0); i < messagesCap; i++ {
messages[i] = buffer[start+i]
}
} else {
remaining := messagesCap - validBufCount
cutoff := before
if validBufCount > 0 {
cutoff = buffer[0].CreatedAt
}
dbMessages, err := DbConnectionGetMessagesBefore(ctx, cutoff, connectionId, remaining)
if err != nil {
http.Error(response, "internal server error", http.StatusInternalServerError)
return
}
messages = make([]*Message, 0, uint32(len(dbMessages))+validBufCount)
messages = append(messages, dbMessages...)
for i := uint32(0); i < validBufCount; i++ {
messages = append(messages, buffer[i])
}
}
json, err := json2.Marshal(messages)
if err != nil {
http.Error(response, "internal server error", http.StatusInternalServerError)
return
}
response.WriteHeader(http.StatusOK)
response.Write(json)
}
func HttpHandleUserNewConnection(response http.ResponseWriter, request *http.Request) {
if !isMethodAllowed(&response, request) {
return
+139 -16
View File
@@ -1,13 +1,20 @@
import asyncio
import json
import threading
import requests
import websockets
from datetime import datetime
from prompt_toolkit import PromptSession
from prompt_toolkit.key_binding import KeyBindings
from prompt_toolkit.keys import Keys
URI = "ws://localhost:8080/ws"
BASE_URL = "http://localhost:8080"
WS_URI = "ws://localhost:8080/ws"
RETRY_DELAY = 2
token: str | None = None
user_id: int | None = None
bindings = KeyBindings()
@bindings.add(Keys.Enter)
@@ -22,9 +29,113 @@ send_queue: asyncio.Queue = None
loop: asyncio.AbstractEventLoop = None
shutdown_event: asyncio.Event = None
# ── helpers ──────────────────────────────────────────────────────────────────
def post(path, **data):
return requests.post(f"{BASE_URL}{path}", data=data)
def fmt_msg(m: dict) -> str:
sender = m.get("sender", "?")
content = m.get("content", "")
ts = m.get("createdAt", "")
try:
dt = datetime.fromisoformat(ts.replace("Z", "+00:00"))
ts = dt.strftime("%H:%M:%S")
except Exception:
pass
return f"[{ts}] <{sender}> {content}"
# ── commands ─────────────────────────────────────────────────────────────────
def cmd_login(args):
global token, user_id
if len(args) < 2:
print("usage: /login <username> <password>")
return
r = post("/token", username=args[0], password=args[1])
if r.ok:
data = r.json()
token = data["token"]
user_id = data["userId"]
print(f"logged in user_id={user_id}")
asyncio.run_coroutine_threadsafe(
send_queue.put(json.dumps({"token": token})), loop
)
else:
print(f"login failed: {r.text}")
def cmd_send(args):
if not token:
print("not logged in")
return
if len(args) < 2:
print("usage: /send <connectionid> <message…>")
return
conn_id = args[0]
content = " ".join(args[1:])
r = post("/message", token=token, connectionid=conn_id, msgContent=content)
print("sent" if r.ok else f"error: {r.text}")
def cmd_history(args):
if not token:
print("not logged in")
return
if not args:
print("usage: /history <connectionid> [count]")
return
data = {"token": token, "connectionid": args[0]}
if len(args) > 1:
data["messages"] = args[1]
if len(args) > 2:
data["before"] = args[2]
r = requests.post(f"{BASE_URL}/messages", data=data)
if r.ok:
msgs = r.json() or []
if not msgs:
print("no messages")
for m in msgs:
print(fmt_msg(m))
else:
print(f"error: {r.text}")
def cmd_connections(args):
if not token:
print("not logged in")
return
r = post("/connections", token=token)
if r.ok:
for c in (r.json() or {}).values():
print(f" {c['id']} requestor={c['requestorId']} recipient={c['recipientId']} state={c['state']}")
else:
print(f"error: {r.text}")
COMMANDS = {
"/login": cmd_login,
"/send": cmd_send,
"/history": cmd_history,
"/connections": cmd_connections,
}
HELP = """
/login <user> <pass> authenticate
/connections list your connections
/send <connectionid> <message…> send a DM
/history <connectionid> [count] [before] fetch message history
"""
# ── websocket ─────────────────────────────────────────────────────────────────
async def receiver(ws):
async for message in ws:
print(f"\n[SERVER] {message}", flush=True)
async for raw in ws:
try:
data = json.loads(raw)
# pushed DM
if "content" in data and "sender" in data:
print(f"\n{fmt_msg(data)}", flush=True)
else:
print(f"\n[SERVER] {json.dumps(data, indent=2)}", flush=True)
except json.JSONDecodeError:
print(f"\n[SERVER] {raw}", flush=True)
async def sender(ws):
while True:
@@ -42,12 +153,15 @@ async def run():
while not shutdown_event.is_set():
try:
async with websockets.connect(URI) as ws:
print(f"Connected to {URI}")
async with websockets.connect(WS_URI) as ws:
print(f"connected to {WS_URI}")
if not input_thread_started:
print("Alt+Enter = newline | Enter = send | Ctrl+C = quit\n")
print("Alt+Enter = newline | Enter = send | /help | Ctrl+C = quit\n")
input_thread.start()
input_thread_started = True
# re-auth after reconnect
if token:
await ws.send(json.dumps({"token": token}))
recv_task = asyncio.create_task(receiver(ws))
send_task = asyncio.create_task(sender(ws))
@@ -57,41 +171,50 @@ async def run():
[recv_task, send_task, shutdown_task],
return_when=asyncio.FIRST_COMPLETED,
)
for t in pending:
t.cancel()
if shutdown_event.is_set():
return
# reconnect
send_queue = asyncio.Queue()
print(f"\n[Disconnected] Reconnecting in {RETRY_DELAY}s...")
print(f"\n[disconnected] reconnecting in {RETRY_DELAY}s…")
except (OSError, websockets.exceptions.WebSocketException):
if shutdown_event.is_set():
return
print(f"[Waiting for server] Retrying in {RETRY_DELAY}s...", flush=True)
print(f"[waiting for server] retrying in {RETRY_DELAY}s…", flush=True)
try:
await asyncio.wait_for(shutdown_event.wait(), timeout=RETRY_DELAY)
except asyncio.TimeoutError:
pass
# ── input loop ────────────────────────────────────────────────────────────────
def input_loop():
session = PromptSession(
key_bindings=bindings,
multiline=True,
)
session = PromptSession(key_bindings=bindings, multiline=True)
while True:
try:
text = session.prompt(">>> ")
if text is not None:
text = session.prompt(">>> ").strip()
if not text:
continue
if text == "/help":
print(HELP)
continue
parts = text.split()
cmd = parts[0]
if cmd in COMMANDS:
COMMANDS[cmd](parts[1:])
else:
# raw JSON passthrough
asyncio.run_coroutine_threadsafe(send_queue.put(text), loop)
except (EOFError, KeyboardInterrupt):
asyncio.run_coroutine_threadsafe(shutdown_event.set(), loop)
break
# ── main ──────────────────────────────────────────────────────────────────────
def main():
global loop, shutdown_event
loop = asyncio.new_event_loop()
+32 -2
View File
@@ -24,20 +24,50 @@ type User struct {
}
type Connection struct {
Mu sync.RWMutex `json:"-"`
Id uuid.UUID `json:"id"`
CreatedAt time.Time `json:"createdAt"`
MessagesBuf [MaxDirectMsgCache]*Message `json:"-"`
MessagesBuff [MaxDirectMsgCache]*Message `json:"-"`
NextBuffIdx uint32 `json:"-"`
HaveOverflowed bool `json:"-"`
RequestorId uint32 `json:"requestorId"`
RecipientId uint32 `json:"recipientId"`
State ConnectionState.ConnectionState `json:"state"`
}
func (conn *Connection) AddMessageToBuff(message *Message) {
conn.Mu.Lock()
defer conn.Mu.Unlock()
conn.MessagesBuff[conn.NextBuffIdx%MaxDirectMsgCache] = message
conn.NextBuffIdx++
if conn.NextBuffIdx >= MaxDirectMsgCache {
conn.HaveOverflowed = true
}
}
// GetSortedMessagesBuff returns arr, length
func (conn *Connection) GetSortedMessagesBuff() (*[MaxDirectMsgCache]*Message, uint32) {
conn.Mu.RLock()
defer conn.Mu.RUnlock()
if !conn.HaveOverflowed {
return &conn.MessagesBuff, conn.NextBuffIdx
}
sorted := new([MaxDirectMsgCache]*Message)
for i := uint32(0); i < MaxDirectMsgCache; i++ {
sorted[i] = conn.MessagesBuff[(conn.NextBuffIdx+i)%MaxDirectMsgCache]
}
return sorted, MaxDirectMsgCache
}
type Message struct {
Id uuid.UUID `json:"id"`
Content string `json:"content"`
CreatedAt time.Time `json:"createdAt"`
Sender uint32 `json:"sender"`
Receiver uint32 `json:"receiver"`
Receiver uuid.UUID `json:"receiver"`
IsGroupMessage bool `json:"isGroupMessage"`
}
+23 -9
View File
@@ -2,6 +2,7 @@ package main
import (
"context"
"encoding/json"
"errors"
"log"
"net/http"
@@ -54,6 +55,26 @@ func ServeWsConnection(responseWriter http.ResponseWriter, request *http.Request
}
}
func sendMessageStructCloseIfTimeout(user *User, message *Message) {
if user.WsConn == nil {
return
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
jsonMsg, err := json.Marshal(message)
if err != nil {
log.Printf("json marshal error: %v", err)
return
}
err = wsjson.Write(ctx, user.WsConn, jsonMsg)
if err != nil {
log.Printf("json write error: %v", err)
return
}
}
func sendMessageCloseIfTimeout(user *User, message *map[string]any) {
if user.WsConn == nil {
return
@@ -84,15 +105,8 @@ func sendToAllMessageCloseIfTimeout(message *map[string]any) {
}
}
func WsSendToUser(to *User, message *Message) {
var msg = map[string]any{
"type": WsMessageFrom.DirectMessage,
"id": message.Id,
"from": message.Sender,
"created": message.CreatedAt,
"content": message.Content,
}
sendMessageCloseIfTimeout(to, &msg)
func WsMessageSendToUser(to *User, message *Message) {
sendMessageStructCloseIfTimeout(to, message)
}
func WsSendToGroup(group *Group, sender *User, message string) error {