add cache for groups and clients, http sending message remains to be add

This commit is contained in:
2026-03-15 14:15:24 +01:00
parent 76fbb8b970
commit c97b21a39e
7 changed files with 166 additions and 89 deletions
+57
View File
@@ -0,0 +1,57 @@
package main
import (
"context"
"errors"
"sync"
)
var Groups map[uint32]ChatGroup
var ConnectedClients map[uint32]map[*Client]struct{}
func InitCache() {
groups, err := GetAllChatGroups(context.Background())
if err != nil {
panic(err)
}
for _, group := range groups {
Groups[group.Id] = group
}
}
func GetGroupById(groupId uint32) (*ChatGroup, error) {
group, ok := Groups[groupId]
if !ok {
return nil, errors.New("group not found")
}
return &group, nil
}
func AddOrUpdateGroupToCache(mu *sync.Mutex, group ChatGroup) {
mu.Lock()
defer mu.Unlock()
Groups[group.Id] = group
}
func RemoveGroupFromCache(mu *sync.Mutex, groupId uint32) {
mu.Lock()
defer mu.Unlock()
delete(Groups, groupId)
}
func AddOrUpdateConnectedClientToCache(mu *sync.Mutex, client *Client) {
mu.Lock()
defer mu.Unlock()
for _, groupId := range client.User.MemberGroupsId {
ConnectedClients[groupId][client] = struct{}{}
}
}
func RemoveConnectedClientFromCache(mu *sync.Mutex, client *Client) {
mu.Lock()
defer mu.Unlock()
for _, groupId := range client.User.MemberGroupsId {
delete(ConnectedClients[groupId], client)
}
}
+44 -4
View File
@@ -22,13 +22,13 @@ func InitDatabase(ctx context.Context) {
id SERIAL PRIMARY KEY,
name VARCHAR(20) UNIQUE NOT NULL,
pass_hash VARCHAR(60) NOT NULL,
color VARCHAR(3) DEFAULT NULL,
color VARCHAR(3) DEFAULT NULL
);
CREATE TABLE IF NOT EXISTS chat_groups (
id SERIAL PRIMARY KEY,
name VARCHAR(48) NOT NULL,
creator_id INTEGER NOT NULL REFERENCES users(id) ON DELETE SET NULL,
owner_id INTEGER NOT NULL REFERENCES users(id) ON DELETE SET NULL,
creator_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
owner_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
enable_user_colors BOOLEAN NOT NULL DEFAULT true,
group_color VARCHAR(3),
created_at TIMESTAMP NOT NULL DEFAULT NOW()
@@ -43,7 +43,9 @@ func InitDatabase(ctx context.Context) {
if err != nil {
panic(err)
}
dbConnection = conn
InitCache()
}
func AddNewUser(ctx context.Context, user *User) (uint32, error) {
@@ -65,7 +67,7 @@ func AddNewUser(ctx context.Context, user *User) (uint32, error) {
}
user.Password = string(hashed)
}
if len(user.Color) != 1 && len(user.Color) != 3 {
if user.Color == ([3]byte{}) {
return 0, errors.New("color invalid")
}
err = dbConnection.QueryRow(ctx, `
@@ -86,6 +88,24 @@ func isPassValid(ctx context.Context, id uint32, plainPassword string) bool {
return bcrypt.CompareHashAndPassword([]byte(controlHash), []byte(plainPassword)) == nil
}
func GetAllUsers(ctx context.Context) ([]User, error) {
rows, err := dbConnection.Query(ctx, "SELECT id, name, color FROM users")
if err != nil {
return nil, err
}
defer rows.Close()
var users []User
for rows.Next() {
var user User
if err := rows.Scan(&user.Id, &user.Name, &user.Color); err != nil {
return nil, err
}
users = append(users, user)
}
return users, rows.Err()
}
func GetUserDataById(ctx context.Context, id uint32) (User, error) {
var user User
err := dbConnection.QueryRow(ctx, "SELECT id, name, pass_hash, color FROM users WHERE id = $1", id).
@@ -120,9 +140,29 @@ func CreateChatGroupWithoutMembers(ctx context.Context, group *ChatGroup) (uint3
VALUES ($1, $2, $3, $4)
RETURNING id
`, group.Name, group.CreatorId, group.OwnerId, group.CreatedAt).Scan(&id)
AddOrUpdateGroupToCache(&mu, *group)
return id, err
}
func GetAllChatGroups(ctx context.Context) ([]ChatGroup, error) {
rows, err := dbConnection.Query(ctx, "SELECT id, name, creator_id, owner_id, enable_user_colors, group_color, created_at FROM chat_groups")
if err != nil {
return nil, err
}
defer rows.Close()
var groups []ChatGroup
for rows.Next() {
var group ChatGroup
if err := rows.Scan(&group.Id, &group.Name, &group.CreatorId, &group.OwnerId, &group.EnableUserColors, &group.Color, &group.CreatedAt); err != nil {
return nil, err
}
groups = append(groups, group)
}
return groups, rows.Err()
}
func GetChatGroupWithoutMembers(ctx context.Context, id uint32) (ChatGroup, error) {
var group ChatGroup
err := dbConnection.QueryRow(ctx, `SELECT name, creator_id, owner_id, enable_user_colors, group_color, created_at FROM chat_groups WHERE id = $1`,
+34 -39
View File
@@ -68,7 +68,7 @@ func LoginHandler(response http.ResponseWriter, request *http.Request) {
password := request.FormValue("password")
respondBadLogin := func() {
http.Error(response, "bad login", http.StatusConflict)
http.Error(response, "bad login", http.StatusUnauthorized)
}
if len(username) < 2 {
@@ -101,46 +101,11 @@ func CreateGroupHandler(response http.ResponseWriter, request *http.Request) {
if !isMethodAllowed(&response, request) {
return
}
var anyAuthDone bool = false
var user User
ctx := request.Context()
username := request.FormValue("username")
password := request.FormValue("password")
respondBadLogin := func() {
http.Error(response, "bad login", http.StatusConflict)
}
if len(password) > 0 {
if len(username) < 2 {
http.Error(response, "no or too short nick", http.StatusBadRequest)
return
}
tmp, err := GetUserDataByName(ctx, username)
user, err := GetUserFromToken(request.FormValue("token"))
if err != nil {
respondBadLogin()
return
}
if bcrypt.CompareHashAndPassword([]byte(tmp.Password), []byte(password)) != nil {
respondBadLogin()
return
}
user = tmp
anyAuthDone = true
} else if token := request.FormValue("token"); len(token) > 0 {
tmp, err := GetUserFromToken(token)
if err != nil {
respondBadLogin()
return
}
user = tmp
anyAuthDone = true
}
if !anyAuthDone {
http.Error(response, "no login or token", http.StatusBadRequest)
http.Error(response, "invalid token", http.StatusUnauthorized)
return
}
@@ -150,7 +115,7 @@ func CreateGroupHandler(response http.ResponseWriter, request *http.Request) {
return
}
_, err := CreateChatGroupWithoutMembers(ctx, &ChatGroup{
_, err = CreateChatGroupWithoutMembers(ctx, &ChatGroup{
Name: groupName,
CreatorId: user.Id,
OwnerId: user.Id,
@@ -163,3 +128,33 @@ func CreateGroupHandler(response http.ResponseWriter, request *http.Request) {
}
response.WriteHeader(http.StatusCreated)
}
func SendMessageHandler(response http.ResponseWriter, request *http.Request) {
groupId := request.PathValue("groupid")
if groupId == "" {
http.Error(response, "no group id", http.StatusBadRequest)
}
var user User
var err error
token := request.FormValue("token")
if token == "" {
http.Error(response, "no token", http.StatusBadRequest)
return
}
if user, err = GetUserFromToken(token); err != nil || user == nil {
http.Error(response, "invalid token", http.StatusUnauthorized)
return
}
content := request.FormValue("content")
if content == "" {
http.Error(response, "no content", http.StatusBadRequest)
return
}
var isInGroup bool
for _, groupId := range user.MemberGroupsId {
}
}
+7 -3
View File
@@ -10,10 +10,12 @@ func main() {
InitDatabase(context.Background())
srv := &wsServer{
OnOpen: func(c *Client) {
AddOrUpdateConnectedClientToCache(&mu, c)
log.Println("client connected")
},
OnClose: func(c *Client, err error) {
log.Println("client disconnected:", err)
RemoveConnectedClientFromCache(&mu, c)
},
OnMessage: func(c *Client, msg map[string]any) {
log.Printf("received: %v\n", msg)
@@ -27,8 +29,10 @@ func main() {
http.Handle("/ws", srv)
log.Println("server listening on :8080")
http.HandleFunc("POST /register", RegisterHandler)
http.HandleFunc("POST /login", LoginHandler)
http.HandleFunc("POST /create/group", CreateGroupHandler)
http.HandleFunc("POST /new/account", RegisterHandler)
http.HandleFunc("POST /new/token", LoginHandler)
http.HandleFunc("POST /new/group", CreateGroupHandler)
http.HandleFunc("POST /new/messageto/{groupid}", SendMessageHandler)
log.Fatal(http.ListenAndServe(":8080", nil))
}
+2 -1
View File
@@ -7,9 +7,10 @@ import (
)
type User struct {
MemberGroupsId []uint32
Name string
Password string
Color string
Color [3]byte
Id uint32
IsPasswordHashed bool
}
+4 -3
View File
@@ -7,7 +7,6 @@ import (
"time"
"github.com/golang-jwt/jwt/v5"
_ "github.com/golang-jwt/jwt/v5"
)
var secretKey = []byte("replace-with-env-variable")
@@ -22,7 +21,7 @@ func GetToken(user *User) (string, error) {
token := jwt.NewWithClaims(jwt.SigningMethodHS256,
UserClaims{
Name: user.Name,
Color: user.Color,
Color: string(user.Color[:]),
RegisteredClaims: jwt.RegisteredClaims{
Subject: strconv.Itoa(int(user.Id)),
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)),
@@ -54,9 +53,11 @@ func GetUserFromToken(tokenString string) (User, error) {
return User{}, fmt.Errorf("invalid subject: %w", err)
}
var color [3]byte
copy(color[:], claims.Color)
return User{
Id: uint32(id),
Name: claims.Name,
Color: claims.Color,
Color: color,
}, nil
}
+17 -38
View File
@@ -2,6 +2,7 @@ package main
import (
"context"
"errors"
"log"
"net/http"
"sync"
@@ -18,7 +19,6 @@ type wsServer struct {
}
var (
clients []*Client
mu sync.Mutex
)
@@ -33,9 +33,6 @@ func (s *wsServer) ServeHTTP(responseWriter http.ResponseWriter, request *http.R
defer conn.CloseNow()
client := &Client{conn: conn}
mu.Lock()
clients = append(clients, client)
mu.Unlock()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
@@ -61,37 +58,27 @@ func (s *wsServer) ServeHTTP(responseWriter http.ResponseWriter, request *http.R
s.OnClose(client, readErr)
}
mu.Lock()
for i, c := range clients {
if c == client {
clients[i] = clients[len(clients)-1]
clients = clients[:len(clients)-1]
break
}
}
mu.Unlock()
conn.Close(websocket.StatusNormalClosure, "done")
}
func sendAndCloseIfFails(conn *websocket.Conn, message map[string]any) {
func sendAndCloseIfFails(conn *websocket.Conn, message *map[string]any) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := wsjson.Write(ctx, conn, message); err != nil {
conn.Close(websocket.StatusGoingAway, "Write error")
}
}
func sendToAllExceptAndCloseIfFails(client *Client, message map[string]any) {
for _, other := range clients {
if other != client {
sendAndCloseIfFails(other.conn, message)
}
}
func sendToGroup(id uint32, excludedUserId uint32, message *map[string]any) error {
if _, ok := Groups[id]; !ok {
return errors.New("Group Not Found")
}
func sendToGroup() {
for client := range ConnectedClients[id] {
if client.User.Id != excludedUserId {
sendAndCloseIfFails(client.conn, message)
}
}
return nil
}
func handleUnauthenticatedMessage(client *Client, msg map[string]any) {
@@ -102,23 +89,15 @@ func handleUnauthenticatedMessage(client *Client, msg map[string]any) {
return
}
client.User = &user
sendAndCloseIfFails(client.conn, map[string]any{
m := map[string]any{
"authAs": user.Name,
})
}
sendAndCloseIfFails(client.conn, &m)
AddOrUpdateConnectedClientToCache(&mu, client)
log.Println("New User authenticated as: " + user.Name)
}
func handleAuthenticatedMessage(client *Client, msg map[string]any) {
message := msg["message"].(string)
if message == "" {
sendAndCloseIfFails(client.conn, map[string]any{
"error": "no message",
})
return
}
sendToAllExceptAndCloseIfFails(client, map[string]any{
"username": client.User.Name,
"message": message,
})
m := map[string]any{"temporary": "unauthorized"}
sendAndCloseIfFails(client.conn, &m)
}