fix db function naming, connections query, and nil map init

This commit is contained in:
2026-04-04 16:45:01 +02:00
parent 55095d5f02
commit e30a3077b1
5 changed files with 113 additions and 96 deletions
+25 -27
View File
@@ -89,21 +89,21 @@ func DbUserDelete(ctx context.Context, id uint32) error {
return err return err
} }
func DbGetUserByName(ctx context.Context, user *User) error { func DbUserGetByName(ctx context.Context, user *User) error {
err := dbConn.QueryRow(ctx, ` err := dbConn.QueryRow(ctx, `
SELECT id, name, pass_hash, pronouns, color_red, color_green, color_blue, created_at FROM users WHERE name = $1 SELECT id, name, pass_hash, pronouns, color_red, color_green, color_blue, created_at FROM users WHERE name = $1
`, user.Name).Scan(&user.Id, &user.Name, &user.PasswordHash, &user.Pronouns, &user.Color[0], &user.Color[1], &user.Color[2], &user.CreatedAt) `, user.Name).Scan(&user.Id, &user.Name, &user.PasswordHash, &user.Pronouns, &user.Color[0], &user.Color[1], &user.Color[2], &user.CreatedAt)
return err return err
} }
func DbGetUserById(ctx context.Context, user *User) error { func DbUserGetById(ctx context.Context, user *User) error {
err := dbConn.QueryRow(ctx, ` err := dbConn.QueryRow(ctx, `
SELECT name, pass_hash, pronouns, color_red, color_green, color_blue, created_at FROM users WHERE id = $1 SELECT name, pass_hash, pronouns, color_red, color_green, color_blue, created_at FROM users WHERE id = $1
`, user.Id).Scan(&user.Name, &user.PasswordHash, &user.Pronouns, &user.Color[0], &user.Color[1], &user.Color[2], &user.CreatedAt) `, user.Id).Scan(&user.Name, &user.PasswordHash, &user.Pronouns, &user.Color[0], &user.Color[1], &user.Color[2], &user.CreatedAt)
return err return err
} }
func DbGetUserGroups(ctx context.Context, user *User) error { func DbUserGetGroups(ctx context.Context, user *User) error {
rows, err := dbConn.Query(ctx, ` rows, err := dbConn.Query(ctx, `
SELECT group_id FROM chat_group_members WHERE user_id = $1 SELECT group_id FROM chat_group_members WHERE user_id = $1
`, user.Id) `, user.Id)
@@ -123,41 +123,39 @@ func DbGetUserGroups(ctx context.Context, user *User) error {
return rows.Err() return rows.Err()
} }
func DbGetUserConnections(ctx context.Context, user *User) error { func DbUserGetConnections(ctx context.Context, user *User) error {
rows, err := dbConn.Query(ctx, ` rows, err := dbConn.Query(ctx, `
SELECT requestor_id, recipient_id, is_accepted, created_at FROM user_connections WHERE requestor_id = $1 or recipient_id = $1 SELECT
`) CASE WHEN requestor_id = $1 THEN recipient_id ELSE requestor_id END AS other_id,
requestor_id = $1 AS is_from_user,
is_accepted,
created_at
FROM user_connections
WHERE requestor_id = $1 OR recipient_id = $1
`, user.Id)
if err != nil { if err != nil {
return err return err
} }
user.Connections = make(map[uint32]*Connection)
defer rows.Close() defer rows.Close()
for rows.Next() { for rows.Next() {
var ( var (
requestorId uint32 otherId uint32
recipientId uint32 isFromUser bool
isAccepted bool isAccepted bool
createdAt time.Time createdAt time.Time
) )
err = rows.Scan(&requestorId, &recipientId, &isAccepted, &createdAt) err = rows.Scan(&otherId, &isFromUser, &isAccepted, &createdAt)
if err != nil { if err != nil {
return err return err
} }
if requestorId == user.Id { user.Connections[otherId] = &Connection{
user.Connections[recipientId] = &Connection{ CreatedAt: createdAt,
CreatedAt: createdAt, With: otherId,
With: recipientId, IsFromUser: isFromUser,
IsFromUser: true, IsAccepted: isAccepted,
IsAccepted: isAccepted,
}
} else {
user.Connections[requestorId] = &Connection{
CreatedAt: createdAt,
With: requestorId,
IsFromUser: false,
IsAccepted: isAccepted,
}
} }
} }
@@ -210,14 +208,14 @@ func DbGroupDelete(ctx context.Context, group *Group) error {
return err return err
} }
func DbGetGroupById(ctx context.Context, group *Group) error { func DbGroupGetById(ctx context.Context, group *Group) error {
err := dbConn.QueryRow(ctx, ` err := dbConn.QueryRow(ctx, `
SELECT name, creator_id, owner_id, enable_client_colors, color_red, color_green, color_blue, created_at FROM chat_groups WHERE id = $1 SELECT name, creator_id, owner_id, enable_client_colors, color_red, color_green, color_blue, created_at FROM chat_groups WHERE id = $1
`, group.Id).Scan(&group.Name, &group.CreatorId, &group.OwnerId, &group.EnableUserColors, &group.Color[0], &group.Color[1], &group.Color[2], &group.CreatedAt) `, group.Id).Scan(&group.Name, &group.CreatorId, &group.OwnerId, &group.EnableUserColors, &group.Color[0], &group.Color[1], &group.Color[2], &group.CreatedAt)
return err return err
} }
func DbGetGroupMembers(ctx context.Context, group *Group) error { func DbGroupGetMembers(ctx context.Context, group *Group) error {
rows, err := dbConn.Query(ctx, ` rows, err := dbConn.Query(ctx, `
SELECT user_id FROM chat_group_members WHERE group_id = $1 SELECT user_id FROM chat_group_members WHERE group_id = $1
`, group.Id) `, group.Id)
+4 -3
View File
@@ -1,8 +1,9 @@
package main package main
type WsServerResponse uint8 type WsMessageToUserFrom uint8
const ( const (
BadMessage WsServerResponse = iota Server_ WsMessageToUserFrom = iota
InvalidCredentials DirectMessage_
Group_
) )
+53 -23
View File
@@ -31,13 +31,13 @@ func getUser(ctx context.Context, token string) (*User, error) {
user, err := CacheGetUserById(userId) user, err := CacheGetUserById(userId)
if err != nil { if err != nil {
user = &User{Id: userId} user = &User{Id: userId}
if err = DbGetUserById(ctx, user); err != nil { if err = DbUserGetById(ctx, user); err != nil {
return nil, err return nil, err
} }
if err = DbGetUserGroups(ctx, user); err != nil { if err = DbUserGetGroups(ctx, user); err != nil {
return nil, err return nil, err
} }
if err = DbGetUserConnections(ctx, user); err != nil { if err = DbUserGetConnections(ctx, user); err != nil {
return nil, err return nil, err
} }
CacheSaveUser(user) CacheSaveUser(user)
@@ -50,10 +50,10 @@ func getGroup(ctx context.Context, groupId uint32) (*Group, error) {
group, err := CacheGetGroup(groupId) group, err := CacheGetGroup(groupId)
if err != nil { if err != nil {
group = &Group{Id: groupId} group = &Group{Id: groupId}
if err = DbGetGroupById(ctx, group); err != nil { if err = DbGroupGetById(ctx, group); err != nil {
return nil, err return nil, err
} }
if err = DbGetGroupMembers(ctx, group); err != nil { if err = DbGroupGetMembers(ctx, group); err != nil {
return nil, err return nil, err
} }
CacheSaveGroup(group) CacheSaveGroup(group)
@@ -214,6 +214,33 @@ func HttpHandleUserMessage(response http.ResponseWriter, request *http.Request)
return return
} }
targetId, err := ConvertStringUint32(request.FormValue("recipientid"))
if err != nil {
http.Error(response, "invalid recipient id", http.StatusBadRequest)
return
}
target, err := CacheGetUserById(targetId)
if err != nil {
target = &User{Id: targetId}
err = DbUserGetById(ctx, target)
if err != nil {
http.Error(response, "invalid recipient id", http.StatusBadRequest)
}
}
if user.Connections[target.Id] == nil {
http.Error(response, "invalid recipient id", http.StatusBadRequest)
return
}
message := request.FormValue("message")
if message == "" {
http.Error(response, "empty message", http.StatusBadRequest)
return
}
WsSendToUser(user, target, message)
} }
func HttpHandleNewToken(response http.ResponseWriter, request *http.Request) { func HttpHandleNewToken(response http.ResponseWriter, request *http.Request) {
@@ -243,15 +270,15 @@ func HttpHandleNewToken(response http.ResponseWriter, request *http.Request) {
user, err = CacheGetUserByName(username) user, err = CacheGetUserByName(username)
if err != nil { if err != nil {
user = &User{Name: username} user = &User{Name: username}
if err = DbGetUserByName(ctx, user); err != nil { if err = DbUserGetByName(ctx, user); err != nil {
http.Error(response, "bad login1", http.StatusUnauthorized) http.Error(response, "bad login1", http.StatusUnauthorized)
return return
} }
if err = DbGetUserGroups(ctx, user); err != nil { if err = DbUserGetGroups(ctx, user); err != nil {
http.Error(response, "bad login1", http.StatusUnauthorized) http.Error(response, "bad login1", http.StatusUnauthorized)
return return
} }
if err = DbGetUserConnections(ctx, user); err != nil { if err = DbUserGetConnections(ctx, user); err != nil {
http.Error(response, "bad login1", http.StatusUnauthorized) http.Error(response, "bad login1", http.StatusUnauthorized)
return return
} }
@@ -309,7 +336,7 @@ func HttpHandeGroupCreate(response http.ResponseWriter, request *http.Request) {
group.EnableUserColors = true group.EnableUserColors = true
} }
err = DbGroupSaveWithoutUsers(ctx, &group) err = DbGroupSave(ctx, &group)
if err != nil { if err != nil {
http.Error(response, err.Error(), http.StatusInternalServerError) http.Error(response, err.Error(), http.StatusInternalServerError)
return return
@@ -486,7 +513,7 @@ func HttpHandleGroupChangeOwner(response http.ResponseWriter, request *http.Requ
newOwner, err := CacheGetUserByName(newOwnerName) newOwner, err := CacheGetUserByName(newOwnerName)
if err != nil { if err != nil {
newOwner = &User{Name: newOwnerName} newOwner = &User{Name: newOwnerName}
err = DbGetUserByName(ctx, newOwner) err = DbUserGetByName(ctx, newOwner)
if err != nil { if err != nil {
http.Error(response, "user not in group", http.StatusBadRequest) http.Error(response, "user not in group", http.StatusBadRequest)
return return
@@ -517,30 +544,33 @@ func HttpHandleGroupMessage(response http.ResponseWriter, request *http.Request)
} }
ctx := request.Context() ctx := request.Context()
user, err := getUser(ctx, request.FormValue("token")) user, err := getUser(ctx, request.FormValue("token"))
if err != nil { if err != nil {
http.Error(response, "invalid token", http.StatusUnauthorized) http.Error(response, "invalid token", http.StatusBadRequest)
return
}
groupIdStr := request.FormValue("groupid")
groupId, err := ConvertStringUint32(groupIdStr)
if err != nil {
http.Error(response, "no such group", http.StatusUnauthorized)
return return
} }
targetStr := request.FormValue("subject") group, err := getGroup(ctx, groupId)
if targetStr == "" {
http.Error(response, "invalid subject", http.StatusBadRequest)
return
}
targetId, err := ConvertStringUint32(targetStr)
if err != nil {
http.Error(response, "invalid subject", http.StatusBadRequest)
return
}
content := request.FormValue("content") content := request.FormValue("content")
if content == "" { if content == "" {
http.Error(response, "invalid content", http.StatusBadRequest) http.Error(response, "empty message", http.StatusBadRequest)
return return
} }
err = WsSendToGroup(ctx, targetId, user.Id, content) _, ok := group.Users[user.Id]
if !ok {
http.Error(response, "no such group", http.StatusUnauthorized)
}
err = WsSendToGroup(group, user, content)
if err != nil { if err != nil {
http.Error(response, err.Error(), http.StatusBadRequest) http.Error(response, err.Error(), http.StatusBadRequest)
return return
+2 -2
View File
@@ -1,6 +1,6 @@
connection request/accept
direct messaging direct messaging
friend list who can send messages to who (configurable)
who can send messages to who
chat history chat history
+29 -41
View File
@@ -77,21 +77,16 @@ func sendToAllMessageCloseIfTimeout(message *map[string]any) {
} }
} }
func WsSendToGroup(ctx context.Context, groupId uint32, senderId uint32, message string) error { func WsSendToUser(from *User, to *User, message string) {
group, err := CacheGetGroup(groupId) var msg = map[string]any{
if err != nil { "type": WsMessageToUserFrom(DirectMessage_),
return errors.New("group invalid") "from": from.Id,
} "content": message,
sender, err := CacheGetUserById(senderId)
if err != nil {
sender = &User{Id: senderId}
err = DbUserSetById(ctx, sender)
if err != nil {
return errors.New("non existing sender")
}
} }
sendMessageCloseIfTimeout(from, &msg)
}
func WsSendToGroup(group *Group, sender *User, message string) error {
for groupUserId := range group.Users { for groupUserId := range group.Users {
groupUser, err := CacheGetUserById(groupUserId) groupUser, err := CacheGetUserById(groupUserId)
if err != nil || groupUser.Id == sender.Id { if err != nil || groupUser.Id == sender.Id {
@@ -99,9 +94,9 @@ func WsSendToGroup(ctx context.Context, groupId uint32, senderId uint32, message
} }
var msg = map[string]any{ var msg = map[string]any{
"from": "group", "type": WsMessageToUserFrom(Group_),
"group": group.Id, "from": group.Id,
"sender": sender.Name, "sender": sender.Id,
"content": message, "content": message,
} }
sendMessageCloseIfTimeout(groupUser, &msg) sendMessageCloseIfTimeout(groupUser, &msg)
@@ -118,7 +113,7 @@ func handleUnauthenticatedMessage(ctx context.Context, user *User, userMessage *
token, ok := (*userMessage)["token"].(string) token, ok := (*userMessage)["token"].(string)
if !ok { if !ok {
var msg = map[string]any{ var msg = map[string]any{
"from": "server", "type": WsMessageToUserFrom(Server_),
"error": "no token in message", "error": "no token in message",
} }
sendMessageCloseIfTimeout(user, &msg) sendMessageCloseIfTimeout(user, &msg)
@@ -128,7 +123,7 @@ func handleUnauthenticatedMessage(ctx context.Context, user *User, userMessage *
userId, err := TokenValidateGetId(token) userId, err := TokenValidateGetId(token)
if err != nil { if err != nil {
var msg = map[string]any{ var msg = map[string]any{
"from": "server", "type": WsMessageToUserFrom(Server_),
"error": "invalid token", "error": "invalid token",
} }
sendMessageCloseIfTimeout(user, &msg) sendMessageCloseIfTimeout(user, &msg)
@@ -137,28 +132,12 @@ func handleUnauthenticatedMessage(ctx context.Context, user *User, userMessage *
userFromCache, err := CacheGetUserById(userId) userFromCache, err := CacheGetUserById(userId)
if err != nil { if err != nil {
dbUser := &User{Id: userId} var msg = map[string]any{
err = DbUserSetByIdWithoutGroupsConnections(ctx, dbUser) "type": WsMessageToUserFrom(Server_),
if err != nil { "error": "user not found",
var msg = map[string]any{
"from": "server",
"error": "invalid user data",
}
sendMessageCloseIfTimeout(user, &msg)
return false
} }
err = DbUserSetGroups(ctx, dbUser) sendMessageCloseIfTimeout(user, &msg)
if err != nil { return false
var msg = map[string]any{
"from": "server",
"error": "invalid user data",
}
sendMessageCloseIfTimeout(user, &msg)
return false
}
dbUser.WsConn = user.WsConn
CacheSaveUser(dbUser)
userFromCache = dbUser
} }
userFromCache.WsConn = user.WsConn userFromCache.WsConn = user.WsConn
@@ -169,10 +148,19 @@ func handleUnauthenticatedMessage(ctx context.Context, user *User, userMessage *
if err != nil { if err != nil {
dbGroup := &Group{Id: groupId} dbGroup := &Group{Id: groupId}
err = DbGroupSetById(ctx, dbGroup) err = DbGroupGetById(ctx, dbGroup)
if err != nil { if err != nil {
var msg = map[string]any{ var msg = map[string]any{
"from": "server", "type": "server",
"error": "invalid user data",
}
sendMessageCloseIfTimeout(user, &msg)
return false
}
err = DbGroupGetMembers(ctx, dbGroup)
if err != nil {
var msg = map[string]any{
"type": "server",
"error": "invalid user data", "error": "invalid user data",
} }
sendMessageCloseIfTimeout(user, &msg) sendMessageCloseIfTimeout(user, &msg)