diff --git a/database.go b/database.go index 3229985..3a43ea5 100644 --- a/database.go +++ b/database.go @@ -63,6 +63,8 @@ func DbInit(ctx context.Context) { } } +// DbSaveClientWithoutGroups saves client in db without groups and sets its id +// return: error if not successful func DbSaveClientWithoutGroups(ctx context.Context, client *Client) error { err := dbConn.QueryRow(ctx, ` INSERT INTO clients (name, pass_hash, pronouns, color_red, color_green, color_blue, created_at) @@ -73,6 +75,8 @@ func DbSaveClientWithoutGroups(ctx context.Context, client *Client) error { return err } +// DbSetClientByName sets all fields of given struct with database's data using name +// return: error if not successful func DbSetClientByName(ctx context.Context, client *Client) error { err := dbConn.QueryRow(ctx, ` SELECT id, name, pass_hash, pronouns, color_red, color_green, color_blue, created_at FROM clients WHERE name = $1 @@ -83,6 +87,8 @@ func DbSetClientByName(ctx context.Context, client *Client) error { return DbSetClientGroups(ctx, client) } +// DbSetClientByIdWithoutGroups sets all fields of given struct with database's data using id, excluding groups +// return: error if not successful func DbSetClientByIdWithoutGroups(ctx context.Context, client *Client) error { err := dbConn.QueryRow(ctx, ` SELECT name, pass_hash, pronouns, color_red, color_green, color_blue, created_at FROM clients WHERE id = $1 @@ -90,6 +96,8 @@ func DbSetClientByIdWithoutGroups(ctx context.Context, client *Client) error { return err } +// DbSetClientById sets all fields of given struct with database's data using id, including groups +// return: error if not successful func DbSetClientById(ctx context.Context, client *Client) error { err := DbSetClientByIdWithoutGroups(ctx, client) if err != nil { @@ -98,6 +106,8 @@ func DbSetClientById(ctx context.Context, client *Client) error { return DbSetClientGroups(ctx, client) } +// DbSaveGroupWithoutClients saves group in db and sets its id, also adds the owner as first member +// return: error if not successful func DbSaveGroupWithoutClients(ctx context.Context, group *Group) error { err := dbConn.QueryRow(ctx, ` INSERT INTO chat_groups (name, creator_id, owner_id, enable_client_colors, color_red, color_green, color_blue, created_at) @@ -115,6 +125,8 @@ func DbSaveGroupWithoutClients(ctx context.Context, group *Group) error { return err } +// DbSetGroupByIdWithoutClients sets all fields of given struct with database's data using id, populates Clients map with member ids but not their data +// return: error if not successful func DbSetGroupByIdWithoutClients(ctx context.Context, group *Group) error { err := dbConn.QueryRow(ctx, ` SELECT name, creator_id, owner_id, enable_client_colors, color_red, color_green, color_blue, created_at FROM chat_groups WHERE id = $1 @@ -142,6 +154,8 @@ func DbSetGroupByIdWithoutClients(ctx context.Context, group *Group) error { return rows.Err() } +// DbSetGroupById sets all fields of given struct with database's data using id, including full member client data +// return: error if not successful func DbSetGroupById(ctx context.Context, group *Group) error { err := DbSetGroupByIdWithoutClients(ctx, group) if err != nil { @@ -150,6 +164,8 @@ func DbSetGroupById(ctx context.Context, group *Group) error { return DbSetGroupMemberClients(ctx, group) } +// DbSetGroupMemberClients populates group's Clients map with ids of all members from database +// return: error if not successful func DbSetGroupMemberClients(ctx context.Context, group *Group) error { rows, err := dbConn.Query(ctx, ` SELECT user_id FROM chat_group_members WHERE group_id = $1 @@ -170,19 +186,26 @@ func DbSetGroupMemberClients(ctx context.Context, group *Group) error { return rows.Err() } -func DbAddClientsToGroup(ctx context.Context, groupId uint32, clientIds []uint32) error { +// DbAddClientsToGroup adds given clients to group in db, silently ignores already existing members +// return: error if not successful +func DbAddClientsToGroup(ctx context.Context, groupId uint32, clientIds *[MaxClientsInGroup]uint32) error { batch := &pgx.Batch{} now := time.Now() + var count int for _, cid := range clientIds { + if cid == 0 { + continue + } batch.Queue(` INSERT INTO chat_group_members (group_id, user_id, joined_at) VALUES ($1, $2, $3) ON CONFLICT DO NOTHING `, groupId, cid, now) + count++ } br := dbConn.SendBatch(ctx, batch) defer br.Close() - for range clientIds { + for range count { if _, err := br.Exec(); err != nil { return err } @@ -190,6 +213,35 @@ func DbAddClientsToGroup(ctx context.Context, groupId uint32, clientIds []uint32 return nil } +// DbRemoveClientsFromGroup removes given clients from group in db, silently ignores not existing members +// return: deleted clients count, error if not successful +func DbRemoveClientsFromGroup(ctx context.Context, groupId uint32, clientIds *[MaxClientsInGroup]uint32) (int, error) { + batch := &pgx.Batch{} + var count int + for _, cid := range clientIds { + if cid == 0 { + continue + } + batch.Queue(` + DELETE FROM chat_group_members WHERE group_id = $1 AND user_id = $2 + `, groupId, cid) + count++ + } + br := dbConn.SendBatch(ctx, batch) + defer br.Close() + var deleted int + for range count { + tag, err := br.Exec() + if err != nil { + return deleted, err + } + deleted += int(tag.RowsAffected()) + } + return deleted, nil +} + +// DbSetClientGroups populates client's Groups map with ids of all groups the client belongs to from database +// return: error if not successful func DbSetClientGroups(ctx context.Context, client *Client) error { rows, err := dbConn.Query(ctx, ` SELECT group_id FROM chat_group_members WHERE user_id = $1 diff --git a/enums.go b/enums.go index ad5d524..1401a4d 100644 --- a/enums.go +++ b/enums.go @@ -9,7 +9,7 @@ const ( var Colors = map[string][3]uint8{ "red": {255, 0, 0}, - "green": {0, 255, 255}, + "green": {0, 255, 0}, "blue": {0, 0, 255}, "default": {255, 255, 255}, } diff --git a/globals.go b/globals.go index 2786ad0..a0e012d 100644 --- a/globals.go +++ b/globals.go @@ -1,6 +1,6 @@ package main const ( - MaxGroupsForClient uint8 = 8 - MaxClientsInGroup uint8 = 12 + MaxGroupsForClient uint32 = 8 + MaxClientsInGroup uint32 = 12 ) diff --git a/go-socket b/go-socket index a65b77c..8caf477 100755 Binary files a/go-socket and b/go-socket differ diff --git a/http.go b/http.go index b3c9398..1ea053c 100644 --- a/http.go +++ b/http.go @@ -256,38 +256,47 @@ func HttpHandleGroupAddClient(response http.ResponseWriter, request *http.Reques return } - usersToAddString := request.FormValue("users") + clientsString := request.FormValue("clients") var remainingUsersCount = int(MaxClientsInGroup) - len(group.Clients) if remainingUsersCount < 1 { http.Error(response, "max users", http.StatusUnauthorized) return } - userIdStrings := strings.SplitN(usersToAddString, ",", remainingUsersCount+1) - if len(userIdStrings) == 0 { + clientsStringSlice := strings.SplitN(clientsString, ",", remainingUsersCount+1) + if len(clientsStringSlice) == 0 { http.Error(response, "no users to add", http.StatusBadRequest) return } - clientIds := make([]uint32, 0, len(userIdStrings)) - for _, s := range userIdStrings { + var ids [MaxClientsInGroup]uint32 + var idx uint32 = 0 + for _, s := range clientsStringSlice { + if idx >= MaxClientsInGroup { + break + } id, err := ConvertStringUint32(strings.TrimSpace(s)) if err != nil { continue } - clientIds = append(clientIds, id) + ids[idx] = id + idx++ } - if len(clientIds) == 0 { + if idx == 0 { http.Error(response, "no valid users", http.StatusBadRequest) return } - err = DbAddClientsToGroup(ctx, group.Id, clientIds) + err = DbAddClientsToGroup(ctx, group.Id, &ids) if err != nil { http.Error(response, "internal server error", http.StatusInternalServerError) return } + for i := uint32(0); i < idx; i++ { + group.Clients[ids[i]] = struct{}{} + } + response.WriteHeader(http.StatusAccepted) _, err = response.Write([]byte("ok")) if err != nil { @@ -296,6 +305,73 @@ func HttpHandleGroupAddClient(response http.ResponseWriter, request *http.Reques } } +func HttpHandleGroupRemoveClient(response http.ResponseWriter, request *http.Request) { + if !isMethodAllowed(&response, request) { + return + } + + token := request.FormValue("token") + clientId, err := TokenValidateGetId(token) + if err != nil { + http.Error(response, "invalid token", http.StatusUnauthorized) + return + } + + affectedGroupId, err := ConvertStringUint32(request.FormValue("groupid")) + if err != nil { + http.Error(response, "no such group", http.StatusUnauthorized) + return + } + + ctx := request.Context() + + group, err := getGroup(ctx, affectedGroupId) + if err != nil { + http.Error(response, "no such group", http.StatusUnauthorized) + return + } + + if group.OwnerId != clientId { + http.Error(response, "no such group", http.StatusUnauthorized) + return + } + + clientsString := request.FormValue("clients") + + clientsStringSlice := strings.SplitN(clientsString, ",", int(MaxClientsInGroup)+1) + + var ids [MaxClientsInGroup]uint32 + var idx uint32 = 0 + for _, s := range clientsStringSlice { + if idx >= MaxClientsInGroup { + break + } + id, err := ConvertStringUint32(strings.TrimSpace(s)) + if err != nil { + continue + } + ids[idx] = id + idx++ + } + if idx == 0 { + http.Error(response, "no valid users", http.StatusBadRequest) + return + } + + count, err := DbRemoveClientsFromGroup(ctx, group.Id, &ids) + if err != nil { + http.Error(response, "internal server error", http.StatusInternalServerError) + return + } + + for i := uint32(0); i < idx; i++ { + delete(group.Clients, ids[i]) + } + + response.WriteHeader(http.StatusAccepted) + response.Write([]byte(strconv.Itoa(count))) +} + func HttpHandleNewMessage(response http.ResponseWriter, request *http.Request) { if !isMethodAllowed(&response, request) { return @@ -417,6 +493,7 @@ func HttpHandleGroupMembersGet(response http.ResponseWriter, request *http.Reque json, err := json2.Marshal(groupMembers) if err != nil { http.Error(response, "internal server error", http.StatusInternalServerError) + return } response.WriteHeader(http.StatusAccepted) diff --git a/tests/.state b/tests/.state deleted file mode 100644 index cf8e299..0000000 --- a/tests/.state +++ /dev/null @@ -1,5 +0,0 @@ -TOKEN1=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxIiwiZXhwIjoxNzc1MDM3Mzc3LCJpYXQiOjE3NzUwMzM3Nzd9.BIZm-58PtXm13_q5O5M7B7YFjmYZFG0hE615POZ8xhY -USER1_ID=1 -TOKEN2=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIyIiwiZXhwIjoxNzc1MDM3Mzc3LCJpYXQiOjE3NzUwMzM3Nzd9.mUEEqxtbmmjwICEb_y2LhknR_I7Cis-5kSscm6it5bY -USER2_ID=2 -GROUP_ID=2 diff --git a/tests/04_add_user_to_group.sh b/tests/04_add_user_to_group.sh index e2036a6..b55387a 100755 --- a/tests/04_add_user_to_group.sh +++ b/tests/04_add_user_to_group.sh @@ -18,7 +18,7 @@ echo "=== Adding user2 (ID: $USER2_ID) to group $GROUP_ID ===" RESP=$(curl -s -w "\n%{http_code}" -X POST "$BASE_URL/mod/group/addclients" \ -d "token=$TOKEN1" \ -d "groupid=$GROUP_ID" \ - -d "users=$USER2_ID") + -d "clients=$USER2_ID") BODY=$(echo "$RESP" | head -1) CODE=$(echo "$RESP" | tail -1) echo "Response: $BODY (HTTP $CODE)" diff --git a/tests/08_get_group_members.sh b/tests/08_get_group_members.sh index 98ac632..bdecf0f 100755 --- a/tests/08_get_group_members.sh +++ b/tests/08_get_group_members.sh @@ -15,7 +15,7 @@ if [[ -z "$TOKEN1" || -z "$TOKEN2" || -z "$GROUP_ID" ]]; then fi echo "=== Getting members of group $GROUP_ID as user1 (owner) ===" -RESP=$(curl -s -w "\n%{http_code}" -X POST "$BASE_URL/get/groupmembers" \ +RESP=$(curl -s -w "\n%{http_code}" -X POST "$BASE_URL/get/group/members" \ -d "token=$TOKEN1" \ -d "group=$GROUP_ID") BODY=$(echo "$RESP" | head -1) @@ -29,7 +29,7 @@ fi echo "" echo "=== Getting members of group $GROUP_ID as user2 (member) ===" -RESP=$(curl -s -w "\n%{http_code}" -X POST "$BASE_URL/get/groupmembers" \ +RESP=$(curl -s -w "\n%{http_code}" -X POST "$BASE_URL/get/group/members" \ -d "token=$TOKEN2" \ -d "group=$GROUP_ID") BODY=$(echo "$RESP" | head -1) @@ -43,7 +43,7 @@ fi echo "" echo "=== Getting members with invalid token ===" -RESP=$(curl -s -w "\n%{http_code}" -X POST "$BASE_URL/get/groupmembers" \ +RESP=$(curl -s -w "\n%{http_code}" -X POST "$BASE_URL/get/group/members" \ -d "token=invalid_token" \ -d "group=$GROUP_ID") BODY=$(echo "$RESP" | head -1) @@ -57,7 +57,7 @@ fi echo "" echo "=== Getting members with invalid group ID ===" -RESP=$(curl -s -w "\n%{http_code}" -X POST "$BASE_URL/get/groupmembers" \ +RESP=$(curl -s -w "\n%{http_code}" -X POST "$BASE_URL/get/group/members" \ -d "token=$TOKEN1" \ -d "group=abc") BODY=$(echo "$RESP" | head -1) @@ -71,7 +71,7 @@ fi echo "" echo "=== Getting members of non-existent group ===" -RESP=$(curl -s -w "\n%{http_code}" -X POST "$BASE_URL/get/groupmembers" \ +RESP=$(curl -s -w "\n%{http_code}" -X POST "$BASE_URL/get/group/members" \ -d "token=$TOKEN1" \ -d "group=999999") BODY=$(echo "$RESP" | head -1) diff --git a/wsServer.go b/wsServer.go index a33d7fa..76b1ce0 100644 --- a/wsServer.go +++ b/wsServer.go @@ -30,7 +30,7 @@ func ServeWsConnection(responseWriter http.ResponseWriter, request *http.Request defer closeConnection(&client, ignoreCache) for { var clientMessage map[string]any - err := wsjson.Read(ctx, connection, &clientMessage) + err = wsjson.Read(ctx, connection, &clientMessage) if err != nil { log.Printf("read error: %v", err) return @@ -137,7 +137,6 @@ func handleUnauthenticatedMessage(ctx context.Context, client *Client, clientMes clientFromCache, err := CacheGetClientById(clientId) if err != nil { - // Not in cache — load from database dbClient := &Client{Id: clientId} err = DbSetClientByIdWithoutGroups(ctx, dbClient) if err != nil { @@ -162,6 +161,7 @@ func handleUnauthenticatedMessage(ctx context.Context, client *Client, clientMes clientFromCache = dbClient } + clientFromCache.WsConn = client.WsConn *client = *clientFromCache for groupId, _ := range clientFromCache.Groups {