diff --git a/cache.go b/cache.go new file mode 100644 index 0000000..4219dfa --- /dev/null +++ b/cache.go @@ -0,0 +1,57 @@ +package main + +import ( + "context" + "errors" + "sync" +) + +var Groups map[uint32]ChatGroup +var ConnectedClients map[uint32]map[*Client]struct{} + +func InitCache() { + groups, err := GetAllChatGroups(context.Background()) + if err != nil { + panic(err) + } + + for _, group := range groups { + Groups[group.Id] = group + } +} + +func GetGroupById(groupId uint32) (*ChatGroup, error) { + group, ok := Groups[groupId] + if !ok { + return nil, errors.New("group not found") + } + return &group, nil +} + +func AddOrUpdateGroupToCache(mu *sync.Mutex, group ChatGroup) { + mu.Lock() + defer mu.Unlock() + Groups[group.Id] = group +} + +func RemoveGroupFromCache(mu *sync.Mutex, groupId uint32) { + mu.Lock() + defer mu.Unlock() + delete(Groups, groupId) +} + +func AddOrUpdateConnectedClientToCache(mu *sync.Mutex, client *Client) { + mu.Lock() + defer mu.Unlock() + for _, groupId := range client.User.MemberGroupsId { + ConnectedClients[groupId][client] = struct{}{} + } +} + +func RemoveConnectedClientFromCache(mu *sync.Mutex, client *Client) { + mu.Lock() + defer mu.Unlock() + for _, groupId := range client.User.MemberGroupsId { + delete(ConnectedClients[groupId], client) + } +} diff --git a/database.go b/database.go index ab58a8f..777e09e 100644 --- a/database.go +++ b/database.go @@ -22,13 +22,13 @@ func InitDatabase(ctx context.Context) { id SERIAL PRIMARY KEY, name VARCHAR(20) UNIQUE NOT NULL, pass_hash VARCHAR(60) NOT NULL, - color VARCHAR(3) DEFAULT NULL, + color VARCHAR(3) DEFAULT NULL ); CREATE TABLE IF NOT EXISTS chat_groups ( id SERIAL PRIMARY KEY, name VARCHAR(48) NOT NULL, - creator_id INTEGER NOT NULL REFERENCES users(id) ON DELETE SET NULL, - owner_id INTEGER NOT NULL REFERENCES users(id) ON DELETE SET NULL, + creator_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE, + owner_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE, enable_user_colors BOOLEAN NOT NULL DEFAULT true, group_color VARCHAR(3), created_at TIMESTAMP NOT NULL DEFAULT NOW() @@ -43,7 +43,9 @@ func InitDatabase(ctx context.Context) { if err != nil { panic(err) } + dbConnection = conn + InitCache() } func AddNewUser(ctx context.Context, user *User) (uint32, error) { @@ -65,7 +67,7 @@ func AddNewUser(ctx context.Context, user *User) (uint32, error) { } user.Password = string(hashed) } - if len(user.Color) != 1 && len(user.Color) != 3 { + if user.Color == ([3]byte{}) { return 0, errors.New("color invalid") } err = dbConnection.QueryRow(ctx, ` @@ -86,6 +88,24 @@ func isPassValid(ctx context.Context, id uint32, plainPassword string) bool { return bcrypt.CompareHashAndPassword([]byte(controlHash), []byte(plainPassword)) == nil } +func GetAllUsers(ctx context.Context) ([]User, error) { + rows, err := dbConnection.Query(ctx, "SELECT id, name, color FROM users") + if err != nil { + return nil, err + } + defer rows.Close() + + var users []User + for rows.Next() { + var user User + if err := rows.Scan(&user.Id, &user.Name, &user.Color); err != nil { + return nil, err + } + users = append(users, user) + } + return users, rows.Err() +} + func GetUserDataById(ctx context.Context, id uint32) (User, error) { var user User err := dbConnection.QueryRow(ctx, "SELECT id, name, pass_hash, color FROM users WHERE id = $1", id). @@ -120,9 +140,29 @@ func CreateChatGroupWithoutMembers(ctx context.Context, group *ChatGroup) (uint3 VALUES ($1, $2, $3, $4) RETURNING id `, group.Name, group.CreatorId, group.OwnerId, group.CreatedAt).Scan(&id) + + AddOrUpdateGroupToCache(&mu, *group) return id, err } +func GetAllChatGroups(ctx context.Context) ([]ChatGroup, error) { + rows, err := dbConnection.Query(ctx, "SELECT id, name, creator_id, owner_id, enable_user_colors, group_color, created_at FROM chat_groups") + if err != nil { + return nil, err + } + defer rows.Close() + + var groups []ChatGroup + for rows.Next() { + var group ChatGroup + if err := rows.Scan(&group.Id, &group.Name, &group.CreatorId, &group.OwnerId, &group.EnableUserColors, &group.Color, &group.CreatedAt); err != nil { + return nil, err + } + groups = append(groups, group) + } + return groups, rows.Err() +} + func GetChatGroupWithoutMembers(ctx context.Context, id uint32) (ChatGroup, error) { var group ChatGroup err := dbConnection.QueryRow(ctx, `SELECT name, creator_id, owner_id, enable_user_colors, group_color, created_at FROM chat_groups WHERE id = $1`, diff --git a/http.go b/http.go index af0af2f..066694e 100644 --- a/http.go +++ b/http.go @@ -68,7 +68,7 @@ func LoginHandler(response http.ResponseWriter, request *http.Request) { password := request.FormValue("password") respondBadLogin := func() { - http.Error(response, "bad login", http.StatusConflict) + http.Error(response, "bad login", http.StatusUnauthorized) } if len(username) < 2 { @@ -101,46 +101,11 @@ func CreateGroupHandler(response http.ResponseWriter, request *http.Request) { if !isMethodAllowed(&response, request) { return } - var anyAuthDone bool = false - var user User ctx := request.Context() - username := request.FormValue("username") - password := request.FormValue("password") - respondBadLogin := func() { - http.Error(response, "bad login", http.StatusConflict) - } - - if len(password) > 0 { - if len(username) < 2 { - http.Error(response, "no or too short nick", http.StatusBadRequest) - return - } - - tmp, err := GetUserDataByName(ctx, username) - if err != nil { - respondBadLogin() - return - } - - if bcrypt.CompareHashAndPassword([]byte(tmp.Password), []byte(password)) != nil { - respondBadLogin() - return - } - user = tmp - anyAuthDone = true - } else if token := request.FormValue("token"); len(token) > 0 { - tmp, err := GetUserFromToken(token) - if err != nil { - respondBadLogin() - return - } - user = tmp - anyAuthDone = true - } - - if !anyAuthDone { - http.Error(response, "no login or token", http.StatusBadRequest) + user, err := GetUserFromToken(request.FormValue("token")) + if err != nil { + http.Error(response, "invalid token", http.StatusUnauthorized) return } @@ -150,7 +115,7 @@ func CreateGroupHandler(response http.ResponseWriter, request *http.Request) { return } - _, err := CreateChatGroupWithoutMembers(ctx, &ChatGroup{ + _, err = CreateChatGroupWithoutMembers(ctx, &ChatGroup{ Name: groupName, CreatorId: user.Id, OwnerId: user.Id, @@ -163,3 +128,33 @@ func CreateGroupHandler(response http.ResponseWriter, request *http.Request) { } response.WriteHeader(http.StatusCreated) } + +func SendMessageHandler(response http.ResponseWriter, request *http.Request) { + groupId := request.PathValue("groupid") + if groupId == "" { + http.Error(response, "no group id", http.StatusBadRequest) + } + + var user User + var err error + token := request.FormValue("token") + if token == "" { + http.Error(response, "no token", http.StatusBadRequest) + return + } + if user, err = GetUserFromToken(token); err != nil || user == nil { + http.Error(response, "invalid token", http.StatusUnauthorized) + return + } + + content := request.FormValue("content") + if content == "" { + http.Error(response, "no content", http.StatusBadRequest) + return + } + + var isInGroup bool + for _, groupId := range user.MemberGroupsId { + + } +} diff --git a/main.go b/main.go index 9460254..a34e5d4 100644 --- a/main.go +++ b/main.go @@ -10,10 +10,12 @@ func main() { InitDatabase(context.Background()) srv := &wsServer{ OnOpen: func(c *Client) { + AddOrUpdateConnectedClientToCache(&mu, c) log.Println("client connected") }, OnClose: func(c *Client, err error) { log.Println("client disconnected:", err) + RemoveConnectedClientFromCache(&mu, c) }, OnMessage: func(c *Client, msg map[string]any) { log.Printf("received: %v\n", msg) @@ -27,8 +29,10 @@ func main() { http.Handle("/ws", srv) log.Println("server listening on :8080") - http.HandleFunc("POST /register", RegisterHandler) - http.HandleFunc("POST /login", LoginHandler) - http.HandleFunc("POST /create/group", CreateGroupHandler) + http.HandleFunc("POST /new/account", RegisterHandler) + http.HandleFunc("POST /new/token", LoginHandler) + http.HandleFunc("POST /new/group", CreateGroupHandler) + http.HandleFunc("POST /new/messageto/{groupid}", SendMessageHandler) + log.Fatal(http.ListenAndServe(":8080", nil)) } diff --git a/structures.go b/structures.go index 9a80eb6..08fde35 100644 --- a/structures.go +++ b/structures.go @@ -7,9 +7,10 @@ import ( ) type User struct { + MemberGroupsId []uint32 Name string Password string - Color string + Color [3]byte Id uint32 IsPasswordHashed bool } diff --git a/tokens.go b/tokens.go index c8fea1c..93bb2a0 100644 --- a/tokens.go +++ b/tokens.go @@ -7,7 +7,6 @@ import ( "time" "github.com/golang-jwt/jwt/v5" - _ "github.com/golang-jwt/jwt/v5" ) var secretKey = []byte("replace-with-env-variable") @@ -22,7 +21,7 @@ func GetToken(user *User) (string, error) { token := jwt.NewWithClaims(jwt.SigningMethodHS256, UserClaims{ Name: user.Name, - Color: user.Color, + Color: string(user.Color[:]), RegisteredClaims: jwt.RegisteredClaims{ Subject: strconv.Itoa(int(user.Id)), ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)), @@ -54,9 +53,11 @@ func GetUserFromToken(tokenString string) (User, error) { return User{}, fmt.Errorf("invalid subject: %w", err) } + var color [3]byte + copy(color[:], claims.Color) return User{ Id: uint32(id), Name: claims.Name, - Color: claims.Color, + Color: color, }, nil } diff --git a/wsServer.go b/wsServer.go index 4a59014..52a360c 100644 --- a/wsServer.go +++ b/wsServer.go @@ -2,6 +2,7 @@ package main import ( "context" + "errors" "log" "net/http" "sync" @@ -18,8 +19,7 @@ type wsServer struct { } var ( - clients []*Client - mu sync.Mutex + mu sync.Mutex ) func (s *wsServer) ServeHTTP(responseWriter http.ResponseWriter, request *http.Request) { @@ -33,9 +33,6 @@ func (s *wsServer) ServeHTTP(responseWriter http.ResponseWriter, request *http.R defer conn.CloseNow() client := &Client{conn: conn} - mu.Lock() - clients = append(clients, client) - mu.Unlock() ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -61,37 +58,27 @@ func (s *wsServer) ServeHTTP(responseWriter http.ResponseWriter, request *http.R s.OnClose(client, readErr) } - mu.Lock() - for i, c := range clients { - if c == client { - clients[i] = clients[len(clients)-1] - clients = clients[:len(clients)-1] - break - } - } - mu.Unlock() - conn.Close(websocket.StatusNormalClosure, "done") } -func sendAndCloseIfFails(conn *websocket.Conn, message map[string]any) { +func sendAndCloseIfFails(conn *websocket.Conn, message *map[string]any) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() if err := wsjson.Write(ctx, conn, message); err != nil { conn.Close(websocket.StatusGoingAway, "Write error") } } +func sendToGroup(id uint32, excludedUserId uint32, message *map[string]any) error { + if _, ok := Groups[id]; !ok { + return errors.New("Group Not Found") + } -func sendToAllExceptAndCloseIfFails(client *Client, message map[string]any) { - for _, other := range clients { - if other != client { - sendAndCloseIfFails(other.conn, message) + for client := range ConnectedClients[id] { + if client.User.Id != excludedUserId { + sendAndCloseIfFails(client.conn, message) } } -} - -func sendToGroup() { - + return nil } func handleUnauthenticatedMessage(client *Client, msg map[string]any) { @@ -102,23 +89,15 @@ func handleUnauthenticatedMessage(client *Client, msg map[string]any) { return } client.User = &user - sendAndCloseIfFails(client.conn, map[string]any{ + m := map[string]any{ "authAs": user.Name, - }) + } + sendAndCloseIfFails(client.conn, &m) + AddOrUpdateConnectedClientToCache(&mu, client) log.Println("New User authenticated as: " + user.Name) } func handleAuthenticatedMessage(client *Client, msg map[string]any) { - message := msg["message"].(string) - if message == "" { - sendAndCloseIfFails(client.conn, map[string]any{ - "error": "no message", - }) - return - } - - sendToAllExceptAndCloseIfFails(client, map[string]any{ - "username": client.User.Name, - "message": message, - }) + m := map[string]any{"temporary": "unauthorized"} + sendAndCloseIfFails(client.conn, &m) }