diff --git a/database.go b/database.go index 3a0e525..b5b6cd0 100644 --- a/database.go +++ b/database.go @@ -89,21 +89,21 @@ func DbUserDelete(ctx context.Context, id uint32) error { return err } -func DbGetUserByName(ctx context.Context, user *User) error { +func DbUserGetByName(ctx context.Context, user *User) error { err := dbConn.QueryRow(ctx, ` SELECT id, name, pass_hash, pronouns, color_red, color_green, color_blue, created_at FROM users WHERE name = $1 `, user.Name).Scan(&user.Id, &user.Name, &user.PasswordHash, &user.Pronouns, &user.Color[0], &user.Color[1], &user.Color[2], &user.CreatedAt) return err } -func DbGetUserById(ctx context.Context, user *User) error { +func DbUserGetById(ctx context.Context, user *User) error { err := dbConn.QueryRow(ctx, ` SELECT name, pass_hash, pronouns, color_red, color_green, color_blue, created_at FROM users WHERE id = $1 `, user.Id).Scan(&user.Name, &user.PasswordHash, &user.Pronouns, &user.Color[0], &user.Color[1], &user.Color[2], &user.CreatedAt) return err } -func DbGetUserGroups(ctx context.Context, user *User) error { +func DbUserGetGroups(ctx context.Context, user *User) error { rows, err := dbConn.Query(ctx, ` SELECT group_id FROM chat_group_members WHERE user_id = $1 `, user.Id) @@ -123,41 +123,39 @@ func DbGetUserGroups(ctx context.Context, user *User) error { return rows.Err() } -func DbGetUserConnections(ctx context.Context, user *User) error { +func DbUserGetConnections(ctx context.Context, user *User) error { rows, err := dbConn.Query(ctx, ` - SELECT requestor_id, recipient_id, is_accepted, created_at FROM user_connections WHERE requestor_id = $1 or recipient_id = $1 - `) + SELECT + CASE WHEN requestor_id = $1 THEN recipient_id ELSE requestor_id END AS other_id, + requestor_id = $1 AS is_from_user, + is_accepted, + created_at + FROM user_connections + WHERE requestor_id = $1 OR recipient_id = $1 + `, user.Id) if err != nil { return err } + user.Connections = make(map[uint32]*Connection) defer rows.Close() for rows.Next() { var ( - requestorId uint32 - recipientId uint32 - isAccepted bool - createdAt time.Time + otherId uint32 + isFromUser bool + isAccepted bool + createdAt time.Time ) - err = rows.Scan(&requestorId, &recipientId, &isAccepted, &createdAt) + err = rows.Scan(&otherId, &isFromUser, &isAccepted, &createdAt) if err != nil { return err } - if requestorId == user.Id { - user.Connections[recipientId] = &Connection{ - CreatedAt: createdAt, - With: recipientId, - IsFromUser: true, - IsAccepted: isAccepted, - } - } else { - user.Connections[requestorId] = &Connection{ - CreatedAt: createdAt, - With: requestorId, - IsFromUser: false, - IsAccepted: isAccepted, - } + user.Connections[otherId] = &Connection{ + CreatedAt: createdAt, + With: otherId, + IsFromUser: isFromUser, + IsAccepted: isAccepted, } } @@ -210,14 +208,14 @@ func DbGroupDelete(ctx context.Context, group *Group) error { return err } -func DbGetGroupById(ctx context.Context, group *Group) error { +func DbGroupGetById(ctx context.Context, group *Group) error { err := dbConn.QueryRow(ctx, ` SELECT name, creator_id, owner_id, enable_client_colors, color_red, color_green, color_blue, created_at FROM chat_groups WHERE id = $1 `, group.Id).Scan(&group.Name, &group.CreatorId, &group.OwnerId, &group.EnableUserColors, &group.Color[0], &group.Color[1], &group.Color[2], &group.CreatedAt) return err } -func DbGetGroupMembers(ctx context.Context, group *Group) error { +func DbGroupGetMembers(ctx context.Context, group *Group) error { rows, err := dbConn.Query(ctx, ` SELECT user_id FROM chat_group_members WHERE group_id = $1 `, group.Id) diff --git a/enums.go b/enums.go index aaf8639..d5969f7 100644 --- a/enums.go +++ b/enums.go @@ -1,8 +1,9 @@ package main -type WsServerResponse uint8 +type WsMessageToUserFrom uint8 const ( - BadMessage WsServerResponse = iota - InvalidCredentials + Server_ WsMessageToUserFrom = iota + DirectMessage_ + Group_ ) diff --git a/http.go b/http.go index 7434ea3..0dacf69 100644 --- a/http.go +++ b/http.go @@ -31,13 +31,13 @@ func getUser(ctx context.Context, token string) (*User, error) { user, err := CacheGetUserById(userId) if err != nil { user = &User{Id: userId} - if err = DbGetUserById(ctx, user); err != nil { + if err = DbUserGetById(ctx, user); err != nil { return nil, err } - if err = DbGetUserGroups(ctx, user); err != nil { + if err = DbUserGetGroups(ctx, user); err != nil { return nil, err } - if err = DbGetUserConnections(ctx, user); err != nil { + if err = DbUserGetConnections(ctx, user); err != nil { return nil, err } CacheSaveUser(user) @@ -50,10 +50,10 @@ func getGroup(ctx context.Context, groupId uint32) (*Group, error) { group, err := CacheGetGroup(groupId) if err != nil { group = &Group{Id: groupId} - if err = DbGetGroupById(ctx, group); err != nil { + if err = DbGroupGetById(ctx, group); err != nil { return nil, err } - if err = DbGetGroupMembers(ctx, group); err != nil { + if err = DbGroupGetMembers(ctx, group); err != nil { return nil, err } CacheSaveGroup(group) @@ -214,6 +214,33 @@ func HttpHandleUserMessage(response http.ResponseWriter, request *http.Request) return } + targetId, err := ConvertStringUint32(request.FormValue("recipientid")) + if err != nil { + http.Error(response, "invalid recipient id", http.StatusBadRequest) + return + } + + target, err := CacheGetUserById(targetId) + if err != nil { + target = &User{Id: targetId} + err = DbUserGetById(ctx, target) + if err != nil { + http.Error(response, "invalid recipient id", http.StatusBadRequest) + } + } + + if user.Connections[target.Id] == nil { + http.Error(response, "invalid recipient id", http.StatusBadRequest) + return + } + + message := request.FormValue("message") + if message == "" { + http.Error(response, "empty message", http.StatusBadRequest) + return + } + + WsSendToUser(user, target, message) } func HttpHandleNewToken(response http.ResponseWriter, request *http.Request) { @@ -243,15 +270,15 @@ func HttpHandleNewToken(response http.ResponseWriter, request *http.Request) { user, err = CacheGetUserByName(username) if err != nil { user = &User{Name: username} - if err = DbGetUserByName(ctx, user); err != nil { + if err = DbUserGetByName(ctx, user); err != nil { http.Error(response, "bad login1", http.StatusUnauthorized) return } - if err = DbGetUserGroups(ctx, user); err != nil { + if err = DbUserGetGroups(ctx, user); err != nil { http.Error(response, "bad login1", http.StatusUnauthorized) return } - if err = DbGetUserConnections(ctx, user); err != nil { + if err = DbUserGetConnections(ctx, user); err != nil { http.Error(response, "bad login1", http.StatusUnauthorized) return } @@ -309,7 +336,7 @@ func HttpHandeGroupCreate(response http.ResponseWriter, request *http.Request) { group.EnableUserColors = true } - err = DbGroupSaveWithoutUsers(ctx, &group) + err = DbGroupSave(ctx, &group) if err != nil { http.Error(response, err.Error(), http.StatusInternalServerError) return @@ -486,7 +513,7 @@ func HttpHandleGroupChangeOwner(response http.ResponseWriter, request *http.Requ newOwner, err := CacheGetUserByName(newOwnerName) if err != nil { newOwner = &User{Name: newOwnerName} - err = DbGetUserByName(ctx, newOwner) + err = DbUserGetByName(ctx, newOwner) if err != nil { http.Error(response, "user not in group", http.StatusBadRequest) return @@ -517,30 +544,33 @@ func HttpHandleGroupMessage(response http.ResponseWriter, request *http.Request) } ctx := request.Context() + user, err := getUser(ctx, request.FormValue("token")) if err != nil { - http.Error(response, "invalid token", http.StatusUnauthorized) + http.Error(response, "invalid token", http.StatusBadRequest) + return + } + groupIdStr := request.FormValue("groupid") + groupId, err := ConvertStringUint32(groupIdStr) + if err != nil { + http.Error(response, "no such group", http.StatusUnauthorized) return } - targetStr := request.FormValue("subject") - if targetStr == "" { - http.Error(response, "invalid subject", http.StatusBadRequest) - return - } - targetId, err := ConvertStringUint32(targetStr) - if err != nil { - http.Error(response, "invalid subject", http.StatusBadRequest) - return - } + group, err := getGroup(ctx, groupId) content := request.FormValue("content") if content == "" { - http.Error(response, "invalid content", http.StatusBadRequest) + http.Error(response, "empty message", http.StatusBadRequest) return } - err = WsSendToGroup(ctx, targetId, user.Id, content) + _, ok := group.Users[user.Id] + if !ok { + http.Error(response, "no such group", http.StatusUnauthorized) + } + + err = WsSendToGroup(group, user, content) if err != nil { http.Error(response, err.Error(), http.StatusBadRequest) return diff --git a/todo.txt b/todo.txt index 74c972e..170a88a 100644 --- a/todo.txt +++ b/todo.txt @@ -1,6 +1,6 @@ +connection request/accept direct messaging -friend list -who can send messages to who +who can send messages to who (configurable) chat history diff --git a/wsServer.go b/wsServer.go index b8387b6..fe03862 100644 --- a/wsServer.go +++ b/wsServer.go @@ -77,21 +77,16 @@ func sendToAllMessageCloseIfTimeout(message *map[string]any) { } } -func WsSendToGroup(ctx context.Context, groupId uint32, senderId uint32, message string) error { - group, err := CacheGetGroup(groupId) - if err != nil { - return errors.New("group invalid") - } - - sender, err := CacheGetUserById(senderId) - if err != nil { - sender = &User{Id: senderId} - err = DbUserSetById(ctx, sender) - if err != nil { - return errors.New("non existing sender") - } +func WsSendToUser(from *User, to *User, message string) { + var msg = map[string]any{ + "type": WsMessageToUserFrom(DirectMessage_), + "from": from.Id, + "content": message, } + sendMessageCloseIfTimeout(from, &msg) +} +func WsSendToGroup(group *Group, sender *User, message string) error { for groupUserId := range group.Users { groupUser, err := CacheGetUserById(groupUserId) if err != nil || groupUser.Id == sender.Id { @@ -99,9 +94,9 @@ func WsSendToGroup(ctx context.Context, groupId uint32, senderId uint32, message } var msg = map[string]any{ - "from": "group", - "group": group.Id, - "sender": sender.Name, + "type": WsMessageToUserFrom(Group_), + "from": group.Id, + "sender": sender.Id, "content": message, } sendMessageCloseIfTimeout(groupUser, &msg) @@ -118,7 +113,7 @@ func handleUnauthenticatedMessage(ctx context.Context, user *User, userMessage * token, ok := (*userMessage)["token"].(string) if !ok { var msg = map[string]any{ - "from": "server", + "type": WsMessageToUserFrom(Server_), "error": "no token in message", } sendMessageCloseIfTimeout(user, &msg) @@ -128,7 +123,7 @@ func handleUnauthenticatedMessage(ctx context.Context, user *User, userMessage * userId, err := TokenValidateGetId(token) if err != nil { var msg = map[string]any{ - "from": "server", + "type": WsMessageToUserFrom(Server_), "error": "invalid token", } sendMessageCloseIfTimeout(user, &msg) @@ -137,28 +132,12 @@ func handleUnauthenticatedMessage(ctx context.Context, user *User, userMessage * userFromCache, err := CacheGetUserById(userId) if err != nil { - dbUser := &User{Id: userId} - err = DbUserSetByIdWithoutGroupsConnections(ctx, dbUser) - if err != nil { - var msg = map[string]any{ - "from": "server", - "error": "invalid user data", - } - sendMessageCloseIfTimeout(user, &msg) - return false + var msg = map[string]any{ + "type": WsMessageToUserFrom(Server_), + "error": "user not found", } - err = DbUserSetGroups(ctx, dbUser) - if err != nil { - var msg = map[string]any{ - "from": "server", - "error": "invalid user data", - } - sendMessageCloseIfTimeout(user, &msg) - return false - } - dbUser.WsConn = user.WsConn - CacheSaveUser(dbUser) - userFromCache = dbUser + sendMessageCloseIfTimeout(user, &msg) + return false } userFromCache.WsConn = user.WsConn @@ -169,10 +148,19 @@ func handleUnauthenticatedMessage(ctx context.Context, user *User, userMessage * if err != nil { dbGroup := &Group{Id: groupId} - err = DbGroupSetById(ctx, dbGroup) + err = DbGroupGetById(ctx, dbGroup) if err != nil { var msg = map[string]any{ - "from": "server", + "type": "server", + "error": "invalid user data", + } + sendMessageCloseIfTimeout(user, &msg) + return false + } + err = DbGroupGetMembers(ctx, dbGroup) + if err != nil { + var msg = map[string]any{ + "type": "server", "error": "invalid user data", } sendMessageCloseIfTimeout(user, &msg)