add fetching message history
This commit is contained in:
@@ -4,6 +4,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
)
|
)
|
||||||
@@ -32,3 +33,7 @@ func ConvertStringToRgb(str string) ([3]uint8, error) {
|
|||||||
func ConvertStringUuid(str string) (uuid.UUID, error) {
|
func ConvertStringUuid(str string) (uuid.UUID, error) {
|
||||||
return uuid.Parse(str)
|
return uuid.Parse(str)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ConvertStringTimestamp(str string) (time.Time, error) {
|
||||||
|
return time.Parse(time.RFC3339, str)
|
||||||
|
}
|
||||||
|
|||||||
+36
-1
@@ -57,7 +57,7 @@ func DbInit(ctx context.Context) {
|
|||||||
CREATE TABLE IF NOT EXISTS messages (
|
CREATE TABLE IF NOT EXISTS messages (
|
||||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||||
sender_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
sender_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||||
receiver_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
receiver_id UUID NOT NULL REFERENCES user_connections(id) ON DELETE CASCADE,
|
||||||
created_at TIMESTAMP NOT NULL DEFAULT NOW(),
|
created_at TIMESTAMP NOT NULL DEFAULT NOW(),
|
||||||
content TEXT NOT NULL,
|
content TEXT NOT NULL,
|
||||||
is_group_message BOOLEAN DEFAULT FALSE
|
is_group_message BOOLEAN DEFAULT FALSE
|
||||||
@@ -217,12 +217,47 @@ func DbConnectionUpdateState(ctx context.Context, conn *Connection) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func DbMessageSave(ctx context.Context, message *Message) error {
|
func DbMessageSave(ctx context.Context, message *Message) error {
|
||||||
|
if message.Id != (uuid.UUID{}) {
|
||||||
|
_, err := dbConn.Exec(ctx, `
|
||||||
|
INSERT INTO messages (id, sender_id, receiver_id, created_at, content, is_group_message) VALUES ($1, $2, $3, $4, $5, $6)
|
||||||
|
`, message.Id, message.Sender, message.Receiver, message.CreatedAt, message.Content, message.IsGroupMessage)
|
||||||
|
return err
|
||||||
|
}
|
||||||
return dbConn.QueryRow(ctx, `
|
return dbConn.QueryRow(ctx, `
|
||||||
INSERT INTO messages (sender_id, receiver_id, created_at, content, is_group_message) VALUES ($1, $2, $3, $4, $5)
|
INSERT INTO messages (sender_id, receiver_id, created_at, content, is_group_message) VALUES ($1, $2, $3, $4, $5)
|
||||||
RETURNING id
|
RETURNING id
|
||||||
`, message.Sender, message.Receiver, message.CreatedAt, message.Content, message.IsGroupMessage).Scan(&message.Id)
|
`, message.Sender, message.Receiver, message.CreatedAt, message.Content, message.IsGroupMessage).Scan(&message.Id)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func DbConnectionGetMessagesBefore(ctx context.Context, before time.Time, connection uuid.UUID, cap uint32) ([]*Message, error) {
|
||||||
|
rows, err := dbConn.Query(ctx, `
|
||||||
|
SELECT id, sender_id, receiver_id, created_at, content, is_group_message
|
||||||
|
FROM (
|
||||||
|
SELECT id, sender_id, receiver_id, created_at, content, is_group_message
|
||||||
|
FROM messages
|
||||||
|
WHERE receiver_id = $1
|
||||||
|
AND created_at < $2
|
||||||
|
ORDER BY created_at DESC
|
||||||
|
LIMIT $3
|
||||||
|
) sub
|
||||||
|
ORDER BY created_at ASC
|
||||||
|
`, connection, before, cap)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
messages := make([]*Message, 0, cap)
|
||||||
|
for rows.Next() {
|
||||||
|
msg := &Message{}
|
||||||
|
if err = rows.Scan(&msg.Id, &msg.Sender, &msg.Receiver, &msg.CreatedAt, &msg.Content, &msg.IsGroupMessage); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
messages = append(messages, msg)
|
||||||
|
}
|
||||||
|
return messages, rows.Err()
|
||||||
|
}
|
||||||
|
|
||||||
func DbGroupSave(ctx context.Context, group *Group) error {
|
func DbGroupSave(ctx context.Context, group *Group) error {
|
||||||
err := dbConn.QueryRow(ctx, `
|
err := dbConn.QueryRow(ctx, `
|
||||||
INSERT INTO chat_groups (name, creator_id, owner_id, enable_client_colors, color_red, color_green, color_blue, created_at)
|
INSERT INTO chat_groups (name, creator_id, owner_id, enable_client_colors, color_red, color_green, color_blue, created_at)
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import (
|
|||||||
|
|
||||||
"go-socket/Enums/ConnectionState"
|
"go-socket/Enums/ConnectionState"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -297,23 +298,101 @@ func HttpHandleUserMessage(response http.ResponseWriter, request *http.Request)
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
message := &Message{
|
message := &Message{
|
||||||
|
Id: uuid.New(),
|
||||||
Content: msgContent,
|
Content: msgContent,
|
||||||
CreatedAt: time.Now(),
|
CreatedAt: time.Now(),
|
||||||
Sender: user.Id,
|
Sender: user.Id,
|
||||||
Receiver: target.Id,
|
Receiver: conn.Id,
|
||||||
IsGroupMessage: false,
|
IsGroupMessage: false,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
WsMessageSendToUser(target, message)
|
||||||
|
|
||||||
err = DbMessageSave(ctx, message)
|
err = DbMessageSave(ctx, message)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(response, "internal server error", http.StatusInternalServerError)
|
http.Error(response, "internal server error", http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
WsSendToUser(target, message)
|
|
||||||
response.WriteHeader(http.StatusAccepted)
|
response.WriteHeader(http.StatusAccepted)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func HttpHandleUserGetMessages(response http.ResponseWriter, request *http.Request) {
|
||||||
|
if !isMethodAllowed(&response, request) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ctx := request.Context()
|
||||||
|
user, err := getUserByToken(ctx, request.FormValue("token"))
|
||||||
|
if err != nil {
|
||||||
|
http.Error(response, "invalid token", http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
connectionId, err := ConvertStringUuid(request.FormValue("connectionid"))
|
||||||
|
if err != nil {
|
||||||
|
http.Error(response, "invalid connectionid", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
before, err := ConvertStringTimestamp(request.FormValue("before"))
|
||||||
|
if err != nil {
|
||||||
|
before = time.Now()
|
||||||
|
}
|
||||||
|
|
||||||
|
messagesCap, err := ConvertStringUint32(request.FormValue("messages"))
|
||||||
|
if err != nil {
|
||||||
|
messagesCap = MaxDirectMsgCache
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, ok := CacheGetConnection(user, connectionId)
|
||||||
|
if !ok {
|
||||||
|
http.Error(response, "invalid connectionid", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
buffer, bufferSize := conn.GetSortedMessagesBuff()
|
||||||
|
|
||||||
|
var validBufCount uint32
|
||||||
|
for validBufCount < bufferSize && buffer[validBufCount].CreatedAt.Before(before) {
|
||||||
|
validBufCount++
|
||||||
|
}
|
||||||
|
|
||||||
|
var messages []*Message
|
||||||
|
|
||||||
|
if validBufCount >= messagesCap {
|
||||||
|
start := validBufCount - messagesCap
|
||||||
|
messages = make([]*Message, messagesCap)
|
||||||
|
for i := uint32(0); i < messagesCap; i++ {
|
||||||
|
messages[i] = buffer[start+i]
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
remaining := messagesCap - validBufCount
|
||||||
|
cutoff := before
|
||||||
|
if validBufCount > 0 {
|
||||||
|
cutoff = buffer[0].CreatedAt
|
||||||
|
}
|
||||||
|
dbMessages, err := DbConnectionGetMessagesBefore(ctx, cutoff, connectionId, remaining)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(response, "internal server error", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
messages = make([]*Message, 0, uint32(len(dbMessages))+validBufCount)
|
||||||
|
messages = append(messages, dbMessages...)
|
||||||
|
for i := uint32(0); i < validBufCount; i++ {
|
||||||
|
messages = append(messages, buffer[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
json, err := json2.Marshal(messages)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(response, "internal server error", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.WriteHeader(http.StatusOK)
|
||||||
|
response.Write(json)
|
||||||
|
}
|
||||||
|
|
||||||
func HttpHandleUserNewConnection(response http.ResponseWriter, request *http.Request) {
|
func HttpHandleUserNewConnection(response http.ResponseWriter, request *http.Request) {
|
||||||
if !isMethodAllowed(&response, request) {
|
if !isMethodAllowed(&response, request) {
|
||||||
return
|
return
|
||||||
|
|||||||
+139
-16
@@ -1,13 +1,20 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
import json
|
||||||
import threading
|
import threading
|
||||||
|
import requests
|
||||||
import websockets
|
import websockets
|
||||||
|
from datetime import datetime
|
||||||
from prompt_toolkit import PromptSession
|
from prompt_toolkit import PromptSession
|
||||||
from prompt_toolkit.key_binding import KeyBindings
|
from prompt_toolkit.key_binding import KeyBindings
|
||||||
from prompt_toolkit.keys import Keys
|
from prompt_toolkit.keys import Keys
|
||||||
|
|
||||||
URI = "ws://localhost:8080/ws"
|
BASE_URL = "http://localhost:8080"
|
||||||
|
WS_URI = "ws://localhost:8080/ws"
|
||||||
RETRY_DELAY = 2
|
RETRY_DELAY = 2
|
||||||
|
|
||||||
|
token: str | None = None
|
||||||
|
user_id: int | None = None
|
||||||
|
|
||||||
bindings = KeyBindings()
|
bindings = KeyBindings()
|
||||||
|
|
||||||
@bindings.add(Keys.Enter)
|
@bindings.add(Keys.Enter)
|
||||||
@@ -22,9 +29,113 @@ send_queue: asyncio.Queue = None
|
|||||||
loop: asyncio.AbstractEventLoop = None
|
loop: asyncio.AbstractEventLoop = None
|
||||||
shutdown_event: asyncio.Event = None
|
shutdown_event: asyncio.Event = None
|
||||||
|
|
||||||
|
# ── helpers ──────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def post(path, **data):
|
||||||
|
return requests.post(f"{BASE_URL}{path}", data=data)
|
||||||
|
|
||||||
|
def fmt_msg(m: dict) -> str:
|
||||||
|
sender = m.get("sender", "?")
|
||||||
|
content = m.get("content", "")
|
||||||
|
ts = m.get("createdAt", "")
|
||||||
|
try:
|
||||||
|
dt = datetime.fromisoformat(ts.replace("Z", "+00:00"))
|
||||||
|
ts = dt.strftime("%H:%M:%S")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return f"[{ts}] <{sender}> {content}"
|
||||||
|
|
||||||
|
# ── commands ─────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def cmd_login(args):
|
||||||
|
global token, user_id
|
||||||
|
if len(args) < 2:
|
||||||
|
print("usage: /login <username> <password>")
|
||||||
|
return
|
||||||
|
r = post("/token", username=args[0], password=args[1])
|
||||||
|
if r.ok:
|
||||||
|
data = r.json()
|
||||||
|
token = data["token"]
|
||||||
|
user_id = data["userId"]
|
||||||
|
print(f"logged in user_id={user_id}")
|
||||||
|
asyncio.run_coroutine_threadsafe(
|
||||||
|
send_queue.put(json.dumps({"token": token})), loop
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
print(f"login failed: {r.text}")
|
||||||
|
|
||||||
|
def cmd_send(args):
|
||||||
|
if not token:
|
||||||
|
print("not logged in")
|
||||||
|
return
|
||||||
|
if len(args) < 2:
|
||||||
|
print("usage: /send <connectionid> <message…>")
|
||||||
|
return
|
||||||
|
conn_id = args[0]
|
||||||
|
content = " ".join(args[1:])
|
||||||
|
r = post("/message", token=token, connectionid=conn_id, msgContent=content)
|
||||||
|
print("sent" if r.ok else f"error: {r.text}")
|
||||||
|
|
||||||
|
def cmd_history(args):
|
||||||
|
if not token:
|
||||||
|
print("not logged in")
|
||||||
|
return
|
||||||
|
if not args:
|
||||||
|
print("usage: /history <connectionid> [count]")
|
||||||
|
return
|
||||||
|
data = {"token": token, "connectionid": args[0]}
|
||||||
|
if len(args) > 1:
|
||||||
|
data["messages"] = args[1]
|
||||||
|
if len(args) > 2:
|
||||||
|
data["before"] = args[2]
|
||||||
|
r = requests.post(f"{BASE_URL}/messages", data=data)
|
||||||
|
if r.ok:
|
||||||
|
msgs = r.json() or []
|
||||||
|
if not msgs:
|
||||||
|
print("no messages")
|
||||||
|
for m in msgs:
|
||||||
|
print(fmt_msg(m))
|
||||||
|
else:
|
||||||
|
print(f"error: {r.text}")
|
||||||
|
|
||||||
|
def cmd_connections(args):
|
||||||
|
if not token:
|
||||||
|
print("not logged in")
|
||||||
|
return
|
||||||
|
r = post("/connections", token=token)
|
||||||
|
if r.ok:
|
||||||
|
for c in (r.json() or {}).values():
|
||||||
|
print(f" {c['id']} requestor={c['requestorId']} recipient={c['recipientId']} state={c['state']}")
|
||||||
|
else:
|
||||||
|
print(f"error: {r.text}")
|
||||||
|
|
||||||
|
COMMANDS = {
|
||||||
|
"/login": cmd_login,
|
||||||
|
"/send": cmd_send,
|
||||||
|
"/history": cmd_history,
|
||||||
|
"/connections": cmd_connections,
|
||||||
|
}
|
||||||
|
|
||||||
|
HELP = """
|
||||||
|
/login <user> <pass> – authenticate
|
||||||
|
/connections – list your connections
|
||||||
|
/send <connectionid> <message…> – send a DM
|
||||||
|
/history <connectionid> [count] [before] – fetch message history
|
||||||
|
"""
|
||||||
|
|
||||||
|
# ── websocket ─────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
async def receiver(ws):
|
async def receiver(ws):
|
||||||
async for message in ws:
|
async for raw in ws:
|
||||||
print(f"\n[SERVER] {message}", flush=True)
|
try:
|
||||||
|
data = json.loads(raw)
|
||||||
|
# pushed DM
|
||||||
|
if "content" in data and "sender" in data:
|
||||||
|
print(f"\n{fmt_msg(data)}", flush=True)
|
||||||
|
else:
|
||||||
|
print(f"\n[SERVER] {json.dumps(data, indent=2)}", flush=True)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
print(f"\n[SERVER] {raw}", flush=True)
|
||||||
|
|
||||||
async def sender(ws):
|
async def sender(ws):
|
||||||
while True:
|
while True:
|
||||||
@@ -42,12 +153,15 @@ async def run():
|
|||||||
|
|
||||||
while not shutdown_event.is_set():
|
while not shutdown_event.is_set():
|
||||||
try:
|
try:
|
||||||
async with websockets.connect(URI) as ws:
|
async with websockets.connect(WS_URI) as ws:
|
||||||
print(f"Connected to {URI}")
|
print(f"connected to {WS_URI}")
|
||||||
if not input_thread_started:
|
if not input_thread_started:
|
||||||
print("Alt+Enter = newline | Enter = send | Ctrl+C = quit\n")
|
print("Alt+Enter = newline | Enter = send | /help | Ctrl+C = quit\n")
|
||||||
input_thread.start()
|
input_thread.start()
|
||||||
input_thread_started = True
|
input_thread_started = True
|
||||||
|
# re-auth after reconnect
|
||||||
|
if token:
|
||||||
|
await ws.send(json.dumps({"token": token}))
|
||||||
|
|
||||||
recv_task = asyncio.create_task(receiver(ws))
|
recv_task = asyncio.create_task(receiver(ws))
|
||||||
send_task = asyncio.create_task(sender(ws))
|
send_task = asyncio.create_task(sender(ws))
|
||||||
@@ -57,41 +171,50 @@ async def run():
|
|||||||
[recv_task, send_task, shutdown_task],
|
[recv_task, send_task, shutdown_task],
|
||||||
return_when=asyncio.FIRST_COMPLETED,
|
return_when=asyncio.FIRST_COMPLETED,
|
||||||
)
|
)
|
||||||
|
|
||||||
for t in pending:
|
for t in pending:
|
||||||
t.cancel()
|
t.cancel()
|
||||||
|
|
||||||
if shutdown_event.is_set():
|
if shutdown_event.is_set():
|
||||||
return
|
return
|
||||||
|
|
||||||
# reconnect
|
|
||||||
send_queue = asyncio.Queue()
|
send_queue = asyncio.Queue()
|
||||||
print(f"\n[Disconnected] Reconnecting in {RETRY_DELAY}s...")
|
print(f"\n[disconnected] reconnecting in {RETRY_DELAY}s…")
|
||||||
|
|
||||||
except (OSError, websockets.exceptions.WebSocketException):
|
except (OSError, websockets.exceptions.WebSocketException):
|
||||||
if shutdown_event.is_set():
|
if shutdown_event.is_set():
|
||||||
return
|
return
|
||||||
print(f"[Waiting for server] Retrying in {RETRY_DELAY}s...", flush=True)
|
print(f"[waiting for server] retrying in {RETRY_DELAY}s…", flush=True)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await asyncio.wait_for(shutdown_event.wait(), timeout=RETRY_DELAY)
|
await asyncio.wait_for(shutdown_event.wait(), timeout=RETRY_DELAY)
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
# ── input loop ────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
def input_loop():
|
def input_loop():
|
||||||
session = PromptSession(
|
session = PromptSession(key_bindings=bindings, multiline=True)
|
||||||
key_bindings=bindings,
|
|
||||||
multiline=True,
|
|
||||||
)
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
text = session.prompt(">>> ")
|
text = session.prompt(">>> ").strip()
|
||||||
if text is not None:
|
if not text:
|
||||||
|
continue
|
||||||
|
if text == "/help":
|
||||||
|
print(HELP)
|
||||||
|
continue
|
||||||
|
parts = text.split()
|
||||||
|
cmd = parts[0]
|
||||||
|
if cmd in COMMANDS:
|
||||||
|
COMMANDS[cmd](parts[1:])
|
||||||
|
else:
|
||||||
|
# raw JSON passthrough
|
||||||
asyncio.run_coroutine_threadsafe(send_queue.put(text), loop)
|
asyncio.run_coroutine_threadsafe(send_queue.put(text), loop)
|
||||||
except (EOFError, KeyboardInterrupt):
|
except (EOFError, KeyboardInterrupt):
|
||||||
asyncio.run_coroutine_threadsafe(shutdown_event.set(), loop)
|
asyncio.run_coroutine_threadsafe(shutdown_event.set(), loop)
|
||||||
break
|
break
|
||||||
|
|
||||||
|
# ── main ──────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
global loop, shutdown_event
|
global loop, shutdown_event
|
||||||
loop = asyncio.new_event_loop()
|
loop = asyncio.new_event_loop()
|
||||||
|
|||||||
+32
-2
@@ -24,20 +24,50 @@ type User struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type Connection struct {
|
type Connection struct {
|
||||||
|
Mu sync.RWMutex `json:"-"`
|
||||||
Id uuid.UUID `json:"id"`
|
Id uuid.UUID `json:"id"`
|
||||||
CreatedAt time.Time `json:"createdAt"`
|
CreatedAt time.Time `json:"createdAt"`
|
||||||
MessagesBuf [MaxDirectMsgCache]*Message `json:"-"`
|
MessagesBuff [MaxDirectMsgCache]*Message `json:"-"`
|
||||||
|
NextBuffIdx uint32 `json:"-"`
|
||||||
|
HaveOverflowed bool `json:"-"`
|
||||||
RequestorId uint32 `json:"requestorId"`
|
RequestorId uint32 `json:"requestorId"`
|
||||||
RecipientId uint32 `json:"recipientId"`
|
RecipientId uint32 `json:"recipientId"`
|
||||||
State ConnectionState.ConnectionState `json:"state"`
|
State ConnectionState.ConnectionState `json:"state"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (conn *Connection) AddMessageToBuff(message *Message) {
|
||||||
|
conn.Mu.Lock()
|
||||||
|
defer conn.Mu.Unlock()
|
||||||
|
|
||||||
|
conn.MessagesBuff[conn.NextBuffIdx%MaxDirectMsgCache] = message
|
||||||
|
conn.NextBuffIdx++
|
||||||
|
if conn.NextBuffIdx >= MaxDirectMsgCache {
|
||||||
|
conn.HaveOverflowed = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSortedMessagesBuff returns arr, length
|
||||||
|
func (conn *Connection) GetSortedMessagesBuff() (*[MaxDirectMsgCache]*Message, uint32) {
|
||||||
|
conn.Mu.RLock()
|
||||||
|
defer conn.Mu.RUnlock()
|
||||||
|
|
||||||
|
if !conn.HaveOverflowed {
|
||||||
|
return &conn.MessagesBuff, conn.NextBuffIdx
|
||||||
|
}
|
||||||
|
|
||||||
|
sorted := new([MaxDirectMsgCache]*Message)
|
||||||
|
for i := uint32(0); i < MaxDirectMsgCache; i++ {
|
||||||
|
sorted[i] = conn.MessagesBuff[(conn.NextBuffIdx+i)%MaxDirectMsgCache]
|
||||||
|
}
|
||||||
|
return sorted, MaxDirectMsgCache
|
||||||
|
}
|
||||||
|
|
||||||
type Message struct {
|
type Message struct {
|
||||||
Id uuid.UUID `json:"id"`
|
Id uuid.UUID `json:"id"`
|
||||||
Content string `json:"content"`
|
Content string `json:"content"`
|
||||||
CreatedAt time.Time `json:"createdAt"`
|
CreatedAt time.Time `json:"createdAt"`
|
||||||
Sender uint32 `json:"sender"`
|
Sender uint32 `json:"sender"`
|
||||||
Receiver uint32 `json:"receiver"`
|
Receiver uuid.UUID `json:"receiver"`
|
||||||
IsGroupMessage bool `json:"isGroupMessage"`
|
IsGroupMessage bool `json:"isGroupMessage"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
+23
-9
@@ -2,6 +2,7 @@ package main
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -54,6 +55,26 @@ func ServeWsConnection(responseWriter http.ResponseWriter, request *http.Request
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func sendMessageStructCloseIfTimeout(user *User, message *Message) {
|
||||||
|
if user.WsConn == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
jsonMsg, err := json.Marshal(message)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("json marshal error: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = wsjson.Write(ctx, user.WsConn, jsonMsg)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("json write error: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func sendMessageCloseIfTimeout(user *User, message *map[string]any) {
|
func sendMessageCloseIfTimeout(user *User, message *map[string]any) {
|
||||||
if user.WsConn == nil {
|
if user.WsConn == nil {
|
||||||
return
|
return
|
||||||
@@ -84,15 +105,8 @@ func sendToAllMessageCloseIfTimeout(message *map[string]any) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func WsSendToUser(to *User, message *Message) {
|
func WsMessageSendToUser(to *User, message *Message) {
|
||||||
var msg = map[string]any{
|
sendMessageStructCloseIfTimeout(to, message)
|
||||||
"type": WsMessageFrom.DirectMessage,
|
|
||||||
"id": message.Id,
|
|
||||||
"from": message.Sender,
|
|
||||||
"created": message.CreatedAt,
|
|
||||||
"content": message.Content,
|
|
||||||
}
|
|
||||||
sendMessageCloseIfTimeout(to, &msg)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func WsSendToGroup(group *Group, sender *User, message string) error {
|
func WsSendToGroup(group *Group, sender *User, message string) error {
|
||||||
|
|||||||
Reference in New Issue
Block a user