diff --git a/TODO.md b/TODO.md deleted file mode 100644 index 3187230..0000000 --- a/TODO.md +++ /dev/null @@ -1,33 +0,0 @@ -# TODO — Code Logic Errors - -## Critical - -- [ ] **Login: nil pointer dereference** (`http.go:111`) - `CacheGetClientByName` returns `nil` on miss, then `DbSetClientByName` is called with that nil `client` → panic. Should query DB by username directly. - -- [ ] **Login: password never verified** (`http.go:87–131`) - No call to `PasswordVerify`/`bcrypt.CompareHashAndPassword`. Anyone with a valid username can log in. - -## High - -- [ ] **Login: validates `username` length instead of `password`** (`http.go:98`) - `if len(username) < 8` should be `if len(password) < 8`. Password is never length-checked. - -- [ ] **DB: missing `&` in `Scan` for `pronouns`** (`database.go:87`) - `client.Pronouns` should be `&client.Pronouns`. Compare with `DbSetClientById` which does it correctly. - -- [ ] **WS: 30s context kills entire connection** (`wsServer.go:23`) - A single 30s timeout context is shared across all reads in the loop. Should use per-read deadlines or `context.Background()` for the loop. - -## Medium - -- [ ] **NewUser: missing `return` after bad color error** (`http.go:54–56`) - On `parseRgb` error, `http.Error` is called but execution continues with `color = [0,0,0]`. - -- [ ] **WS: unauth disconnect deletes ID=0 from cache** (`wsServer.go:115`) - `closeConnection` calls `CacheDeleteClient(client.Id)` but unauthenticated clients have `Id=0`, wiping whatever sits at key 0. - -## Low - -- [ ] **`CacheSetGroup` is a no-op** (`cache.go:59`) - Function body is empty. The `Groups` cache is never populated, so every `CacheGetGroup` call misses and falls back to DB. diff --git a/cache.go b/cache.go index cd1bbf9..9b4bbd3 100644 --- a/cache.go +++ b/cache.go @@ -42,7 +42,7 @@ func CacheGetClientByName(name string) (*Client, error) { return nil, fmt.Errorf("client %s not found", name) } -func CacheSetClient(client *Client) { +func CacheSaveClient(client *Client) { mu.Lock() defer mu.Unlock() @@ -56,7 +56,12 @@ func CacheDeleteClient(id uint32) { delete(CacheClients, id) } -func CacheSetGroup() {} +func CacheSaveGroup(group *Group) { + mu.Lock() + defer mu.Unlock() + + Groups[group.Id] = group +} func CacheGetGroup(id uint32) (*Group, error) { mu.RLock() diff --git a/database.go b/database.go index f798dbe..6e0cab9 100644 --- a/database.go +++ b/database.go @@ -18,7 +18,7 @@ func DbInit(ctx context.Context) { } _, err = dbConn.Exec(ctx, ` - CREATE TABLE IF NOT EXISTS client ( + CREATE TABLE IF NOT EXISTS clients ( id SERIAL PRIMARY KEY, name VARCHAR(20) UNIQUE NOT NULL, pass_hash VARCHAR(60) NOT NULL, @@ -37,8 +37,8 @@ func DbInit(ctx context.Context) { CREATE TABLE IF NOT EXISTS chat_groups ( id SERIAL PRIMARY KEY, name VARCHAR(48) NOT NULL, - creator_id INTEGER NOT NULL REFERENCES client(id) ON DELETE CASCADE, - owner_id INTEGER NOT NULL REFERENCES client(id) ON DELETE CASCADE, + creator_id INTEGER NOT NULL REFERENCES clients(id) ON DELETE CASCADE, + owner_id INTEGER NOT NULL REFERENCES clients(id) ON DELETE CASCADE, enable_client_colors BOOLEAN NOT NULL DEFAULT true, color_red SMALLINT DEFAULT NULL, color_green SMALLINT DEFAULT NULL, @@ -53,7 +53,7 @@ func DbInit(ctx context.Context) { _, err = dbConn.Exec(ctx, ` CREATE TABLE IF NOT EXISTS chat_group_members ( group_id INTEGER NOT NULL REFERENCES chat_groups(id) ON DELETE CASCADE, - user_id INTEGER NOT NULL REFERENCES client(id) ON DELETE CASCADE, + user_id INTEGER NOT NULL REFERENCES clients(id) ON DELETE CASCADE, joined_at TIMESTAMP NOT NULL DEFAULT NOW(), PRIMARY KEY (group_id, user_id) ) @@ -73,22 +73,17 @@ func DbSaveClientWithoutGroups(ctx context.Context, client *Client) error { return err } -func DbGetIdByClientName(ctx context.Context, name string) (uint32, error) { - var id uint32 - err := dbConn.QueryRow(ctx, ` - SELECT id FROM clients WHERE name = $1 - `, name).Scan(&id) - return id, err -} - func DbSetClientByName(ctx context.Context, client *Client) error { err := dbConn.QueryRow(ctx, ` - SELECT name, pass_hash, color_red, color_green, color_blue, created_at FROM clients WHERE name = $1 - `, client.Name).Scan(&client.Name, &client.PasswordHash, &client.Pronouns, &client.Color[0], &client.Color[1], &client.Color[2], &client.CreatedAt) - return err + SELECT id, name, pass_hash, pronouns, color_red, color_green, color_blue, created_at FROM clients WHERE name = $1 + `, client.Name).Scan(&client.Id, &client.Name, &client.PasswordHash, &client.Pronouns, &client.Color[0], &client.Color[1], &client.Color[2], &client.CreatedAt) + if err != nil { + return err + } + return DbSetClientGroups(ctx, client) } -func DbSetClientById(ctx context.Context, client *Client) error { +func DbSetClientByIdWithoutGroups(ctx context.Context, client *Client) error { err := dbConn.QueryRow(ctx, ` SELECT name, pass_hash, pronouns, color_red, color_green, color_blue, created_at FROM clients WHERE id = $1 `, client.Id).Scan(&client.Name, &client.PasswordHash, &client.Pronouns, &client.Color[0], &client.Color[1], &client.Color[2], &client.CreatedAt) @@ -102,10 +97,17 @@ func DbSaveGroupWithoutClients(ctx context.Context, group *Group) error { RETURNING id `, group.Name, group.CreatorId, group.OwnerId, group.EnableClientColors, group.Color[0], group.Color[1], group.Color[2], group.CreatedAt). Scan(&group.Id) + if err != nil { + return err + } + _, err = dbConn.Exec(ctx, ` + INSERT INTO chat_group_members (group_id, user_id, joined_at) + VALUES ($1, $2, $3) + `, group.Id, group.OwnerId, group.CreatedAt) return err } -func DbSetGroupById(ctx context.Context, group *Group) error { +func DbSetGroupByIdWithoutClients(ctx context.Context, group *Group) error { err := dbConn.QueryRow(ctx, ` SELECT name, creator_id, owner_id, enable_client_colors, color_red, color_green, color_blue, created_at FROM chat_groups WHERE id = $1 `, group.Id).Scan(&group.Name, &group.CreatorId, &group.OwnerId, &group.EnableClientColors, &group.Color[0], &group.Color[1], &group.Color[2], &group.CreatedAt) @@ -132,6 +134,34 @@ func DbSetGroupById(ctx context.Context, group *Group) error { return rows.Err() } +func DbSetGroupById(ctx context.Context, group *Group) error { + err := DbSetGroupByIdWithoutClients(ctx, group) + if err != nil { + return err + } + return DbSetGroupMemberClients(ctx, group) +} + +func DbSetGroupMemberClients(ctx context.Context, group *Group) error { + rows, err := dbConn.Query(ctx, ` + SELECT user_id FROM chat_group_members WHERE group_id = $1 + `, group.Id) + if err != nil { + return err + } + defer rows.Close() + + group.Clients = make(map[uint32]struct{}) + for rows.Next() { + var userId uint32 + if err := rows.Scan(&userId); err != nil { + return err + } + group.Clients[userId] = struct{}{} + } + return rows.Err() +} + func DbAddClientsToGroup(ctx context.Context, groupId uint32, clientIds []uint32) error { batch := &pgx.Batch{} now := time.Now() @@ -151,3 +181,23 @@ func DbAddClientsToGroup(ctx context.Context, groupId uint32, clientIds []uint32 } return nil } + +func DbSetClientGroups(ctx context.Context, client *Client) error { + rows, err := dbConn.Query(ctx, ` + SELECT group_id FROM chat_group_members WHERE user_id = $1 + `, client.Id) + if err != nil { + return err + } + defer rows.Close() + + client.Groups = make(map[uint32]struct{}) + for rows.Next() { + var groupId uint32 + if err := rows.Scan(&groupId); err != nil { + return err + } + client.Groups[groupId] = struct{}{} + } + return rows.Err() +} diff --git a/go-socket b/go-socket index 8d1d71e..80bc2b8 100755 Binary files a/go-socket and b/go-socket differ diff --git a/http.go b/http.go index 7f48e63..67ab799 100644 --- a/http.go +++ b/http.go @@ -117,15 +117,15 @@ func HttpHandleLogin(response http.ResponseWriter, request *http.Request) { err := DbSetClientByName(ctx, client) if err != nil { - http.Error(response, "bad login", http.StatusUnauthorized) + http.Error(response, "bad login1", http.StatusUnauthorized) return } - CacheSetClient(client) + CacheSaveClient(client) } err = bcrypt.CompareHashAndPassword([]byte(client.PasswordHash), []byte(password)) if err != nil { - http.Error(response, "bad login", http.StatusUnauthorized) + http.Error(response, "bad login2", http.StatusUnauthorized) return } @@ -178,7 +178,7 @@ func HttpHandleGroupCreate(response http.ResponseWriter, request *http.Request) if err == nil { client = *cacheClient } else { - err = DbSetClientById(ctx, &client) + err = DbSetClientByIdWithoutGroups(ctx, &client) if err != nil { http.Error(response, "internal server error", http.StatusInternalServerError) return @@ -191,6 +191,7 @@ func HttpHandleGroupCreate(response http.ResponseWriter, request *http.Request) OwnerId: clientId, CreatorId: clientId, Color: color, + Clients: map[uint32]struct{}{clientId: {}}, } enableClientColors := request.FormValue("enableClientColors") @@ -200,14 +201,14 @@ func HttpHandleGroupCreate(response http.ResponseWriter, request *http.Request) err = DbSaveGroupWithoutClients(ctx, &group) if err != nil { - http.Error(response, "internal server error", http.StatusInternalServerError) + http.Error(response, err.Error(), http.StatusInternalServerError) return } groupIdBytes := make([]byte, 4) binary.BigEndian.PutUint32(groupIdBytes, group.Id) response.WriteHeader(http.StatusCreated) - response.Write([]byte(groupIdBytes)) + response.Write(groupIdBytes) } func HttpHandleGroupAddClient(response http.ResponseWriter, request *http.Request) { @@ -236,7 +237,7 @@ func HttpHandleGroupAddClient(response http.ResponseWriter, request *http.Reques if err == nil { group = *groupPtr } else { - err = DbSetGroupById(ctx, &group) + err = DbSetGroupByIdWithoutClients(ctx, &group) if err != nil { http.Error(response, "no such group", http.StatusUnauthorized) return diff --git a/machine-client/index.html b/machine-client/index.html index ea69602..e265772 100644 --- a/machine-client/index.html +++ b/machine-client/index.html @@ -72,13 +72,14 @@ +