change client to user

This commit is contained in:
2026-04-03 19:33:26 +02:00
parent 8d97e28dde
commit bd1168aef7
8 changed files with 212 additions and 213 deletions
+18 -18
View File
@@ -6,54 +6,54 @@ import (
) )
var ( var (
CacheClients = make(map[uint32]*Client) CacheUsers = make(map[uint32]*User)
mu sync.RWMutex mu sync.RWMutex
Groups = make(map[uint32]*Group) Groups = make(map[uint32]*Group)
) )
func CacheGetClientById(id uint32) (*Client, error) { func CacheGetUserById(id uint32) (*User, error) {
mu.RLock() mu.RLock()
defer mu.RUnlock() defer mu.RUnlock()
client, ok := CacheClients[id] user, ok := CacheUsers[id]
if !ok { if !ok {
return nil, fmt.Errorf("client %d not found", id) return nil, fmt.Errorf("user %d not found", id)
} }
return client, nil return user, nil
} }
func CacheGetIdByName(name string) (uint32, error) { func CacheGetIdByName(name string) (uint32, error) {
client, err := CacheGetClientByName(name) user, err := CacheGetUserByName(name)
if err != nil { if err != nil {
return 0, err return 0, err
} }
return client.Id, nil return user.Id, nil
} }
func CacheGetClientByName(name string) (*Client, error) { func CacheGetUserByName(name string) (*User, error) {
mu.RLock() mu.RLock()
defer mu.RUnlock() defer mu.RUnlock()
for _, client := range CacheClients { for _, user := range CacheUsers {
if client.Name == name { if user.Name == name {
return client, nil return user, nil
} }
} }
return nil, fmt.Errorf("client %s not found", name) return nil, fmt.Errorf("user %s not found", name)
} }
func CacheSaveClient(client *Client) { func CacheSaveUser(user *User) {
mu.Lock() mu.Lock()
defer mu.Unlock() defer mu.Unlock()
CacheClients[client.Id] = client CacheUsers[user.Id] = user
} }
func CacheDeleteClient(id uint32) { func CacheDeleteUser(id uint32) {
mu.Lock() mu.Lock()
defer mu.Unlock() defer mu.Unlock()
delete(CacheClients, id) delete(CacheUsers, id)
} }
func CacheSaveGroup(group *Group) { func CacheSaveGroup(group *Group) {
+46 -46
View File
@@ -63,57 +63,57 @@ func DbInit(ctx context.Context) {
} }
} }
// DbSaveClientWithoutGroups saves client in db without groups and sets its id // DbSaveUserWithoutGroups saves user in db without groups and sets its id
// return: error if not successful // return: error if not successful
func DbSaveClientWithoutGroups(ctx context.Context, client *Client) error { func DbSaveUserWithoutGroups(ctx context.Context, user *User) error {
err := dbConn.QueryRow(ctx, ` err := dbConn.QueryRow(ctx, `
INSERT INTO clients (name, pass_hash, pronouns, color_red, color_green, color_blue, created_at) INSERT INTO clients (name, pass_hash, pronouns, color_red, color_green, color_blue, created_at)
VALUES ($1, $2, $3, $4, $5, $6, $7) VALUES ($1, $2, $3, $4, $5, $6, $7)
RETURNING id RETURNING id
`, client.Name, client.PasswordHash, client.Pronouns, client.Color[0], client.Color[1], client.Color[2], client.CreatedAt). `, user.Name, user.PasswordHash, user.Pronouns, user.Color[0], user.Color[1], user.Color[2], user.CreatedAt).
Scan(&client.Id) Scan(&user.Id)
return err return err
} }
// DbSetClientByName sets all fields of given struct with database's data using name // DbSetUserByName sets all fields of given struct with database's data using name
// return: error if not successful // return: error if not successful
func DbSetClientByName(ctx context.Context, client *Client) error { func DbSetUserByName(ctx context.Context, user *User) error {
err := dbConn.QueryRow(ctx, ` err := dbConn.QueryRow(ctx, `
SELECT id, name, pass_hash, pronouns, color_red, color_green, color_blue, created_at FROM clients WHERE name = $1 SELECT id, name, pass_hash, pronouns, color_red, color_green, color_blue, created_at FROM clients WHERE name = $1
`, client.Name).Scan(&client.Id, &client.Name, &client.PasswordHash, &client.Pronouns, &client.Color[0], &client.Color[1], &client.Color[2], &client.CreatedAt) `, user.Name).Scan(&user.Id, &user.Name, &user.PasswordHash, &user.Pronouns, &user.Color[0], &user.Color[1], &user.Color[2], &user.CreatedAt)
if err != nil { if err != nil {
return err return err
} }
return DbSetClientGroups(ctx, client) return DbSetUserGroups(ctx, user)
} }
// DbSetClientByIdWithoutGroups sets all fields of given struct with database's data using id, excluding groups // DbSetUserByIdWithoutGroups sets all fields of given struct with database's data using id, excluding groups
// return: error if not successful // return: error if not successful
func DbSetClientByIdWithoutGroups(ctx context.Context, client *Client) error { func DbSetUserByIdWithoutGroups(ctx context.Context, user *User) error {
err := dbConn.QueryRow(ctx, ` err := dbConn.QueryRow(ctx, `
SELECT name, pass_hash, pronouns, color_red, color_green, color_blue, created_at FROM clients WHERE id = $1 SELECT name, pass_hash, pronouns, color_red, color_green, color_blue, created_at FROM clients WHERE id = $1
`, client.Id).Scan(&client.Name, &client.PasswordHash, &client.Pronouns, &client.Color[0], &client.Color[1], &client.Color[2], &client.CreatedAt) `, user.Id).Scan(&user.Name, &user.PasswordHash, &user.Pronouns, &user.Color[0], &user.Color[1], &user.Color[2], &user.CreatedAt)
return err return err
} }
// DbSetClientById sets all fields of given struct with database's data using id, including groups // DbSetUserById sets all fields of given struct with database's data using id, including groups
// return: error if not successful // return: error if not successful
func DbSetClientById(ctx context.Context, client *Client) error { func DbSetUserById(ctx context.Context, user *User) error {
err := DbSetClientByIdWithoutGroups(ctx, client) err := DbSetUserByIdWithoutGroups(ctx, user)
if err != nil { if err != nil {
return err return err
} }
return DbSetClientGroups(ctx, client) return DbSetUserGroups(ctx, user)
} }
// DbSaveGroupWithoutClients saves group in db and sets its id, also adds the owner as first member // DbSaveGroupWithoutUsers saves group in db and sets its id, also adds the owner as first member
// return: error if not successful // return: error if not successful
func DbSaveGroupWithoutClients(ctx context.Context, group *Group) error { func DbSaveGroupWithoutUsers(ctx context.Context, group *Group) error {
err := dbConn.QueryRow(ctx, ` err := dbConn.QueryRow(ctx, `
INSERT INTO chat_groups (name, creator_id, owner_id, enable_client_colors, color_red, color_green, color_blue, created_at) INSERT INTO chat_groups (name, creator_id, owner_id, enable_client_colors, color_red, color_green, color_blue, created_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
RETURNING id RETURNING id
`, group.Name, group.CreatorId, group.OwnerId, group.EnableClientColors, group.Color[0], group.Color[1], group.Color[2], group.CreatedAt). `, group.Name, group.CreatorId, group.OwnerId, group.EnableUserColors, group.Color[0], group.Color[1], group.Color[2], group.CreatedAt).
Scan(&group.Id) Scan(&group.Id)
if err != nil { if err != nil {
return err return err
@@ -134,12 +134,12 @@ func DbDeleteGroup(ctx context.Context, group *Group) error {
return err return err
} }
// DbSetGroupByIdWithoutClients sets all fields of given struct with database's data using id, populates Clients map with member ids but not their data // DbSetGroupByIdWithoutUsers sets all fields of given struct with database's data using id, populates Users map with member ids but not their data
// return: error if not successful // return: error if not successful
func DbSetGroupByIdWithoutClients(ctx context.Context, group *Group) error { func DbSetGroupByIdWithoutUsers(ctx context.Context, group *Group) error {
err := dbConn.QueryRow(ctx, ` err := dbConn.QueryRow(ctx, `
SELECT name, creator_id, owner_id, enable_client_colors, color_red, color_green, color_blue, created_at FROM chat_groups WHERE id = $1 SELECT name, creator_id, owner_id, enable_client_colors, color_red, color_green, color_blue, created_at FROM chat_groups WHERE id = $1
`, group.Id).Scan(&group.Name, &group.CreatorId, &group.OwnerId, &group.EnableClientColors, &group.Color[0], &group.Color[1], &group.Color[2], &group.CreatedAt) `, group.Id).Scan(&group.Name, &group.CreatorId, &group.OwnerId, &group.EnableUserColors, &group.Color[0], &group.Color[1], &group.Color[2], &group.CreatedAt)
if err != nil { if err != nil {
return err return err
} }
@@ -152,30 +152,30 @@ func DbSetGroupByIdWithoutClients(ctx context.Context, group *Group) error {
} }
defer rows.Close() defer rows.Close()
group.Clients = make(map[uint32]struct{}) group.Users = make(map[uint32]struct{})
for rows.Next() { for rows.Next() {
var userId uint32 var userId uint32
if err := rows.Scan(&userId); err != nil { if err := rows.Scan(&userId); err != nil {
return err return err
} }
group.Clients[userId] = struct{}{} group.Users[userId] = struct{}{}
} }
return rows.Err() return rows.Err()
} }
// DbSetGroupById sets all fields of given struct with database's data using id, including full member client data // DbSetGroupById sets all fields of given struct with database's data using id, including full member user data
// return: error if not successful // return: error if not successful
func DbSetGroupById(ctx context.Context, group *Group) error { func DbSetGroupById(ctx context.Context, group *Group) error {
err := DbSetGroupByIdWithoutClients(ctx, group) err := DbSetGroupByIdWithoutUsers(ctx, group)
if err != nil { if err != nil {
return err return err
} }
return DbSetGroupMemberClients(ctx, group) return DbSetGroupMemberUsers(ctx, group)
} }
// DbSetGroupMemberClients populates group's Clients map with ids of all members from database // DbSetGroupMemberUsers populates group's Users map with ids of all members from database
// return: error if not successful // return: error if not successful
func DbSetGroupMemberClients(ctx context.Context, group *Group) error { func DbSetGroupMemberUsers(ctx context.Context, group *Group) error {
rows, err := dbConn.Query(ctx, ` rows, err := dbConn.Query(ctx, `
SELECT user_id FROM chat_group_members WHERE group_id = $1 SELECT user_id FROM chat_group_members WHERE group_id = $1
`, group.Id) `, group.Id)
@@ -184,32 +184,32 @@ func DbSetGroupMemberClients(ctx context.Context, group *Group) error {
} }
defer rows.Close() defer rows.Close()
group.Clients = make(map[uint32]struct{}) group.Users = make(map[uint32]struct{})
for rows.Next() { for rows.Next() {
var userId uint32 var userId uint32
if err := rows.Scan(&userId); err != nil { if err := rows.Scan(&userId); err != nil {
return err return err
} }
group.Clients[userId] = struct{}{} group.Users[userId] = struct{}{}
} }
return rows.Err() return rows.Err()
} }
// DbAddClientsToGroup adds given clients to group in db, silently ignores already existing members // DbAddUsersToGroup adds given users to group in db, silently ignores already existing members
// return: error if not successful // return: error if not successful
func DbAddClientsToGroup(ctx context.Context, groupId uint32, clientIds *[MaxClientsInGroup]uint32) error { func DbAddUsersToGroup(ctx context.Context, groupId uint32, userIds *[MaxUsersInGroup]uint32) error {
batch := &pgx.Batch{} batch := &pgx.Batch{}
now := time.Now() now := time.Now()
var count int var count int
for _, cid := range clientIds { for _, uid := range userIds {
if cid == 0 { if uid == 0 {
continue continue
} }
batch.Queue(` batch.Queue(`
INSERT INTO chat_group_members (group_id, user_id, joined_at) INSERT INTO chat_group_members (group_id, user_id, joined_at)
VALUES ($1, $2, $3) VALUES ($1, $2, $3)
ON CONFLICT DO NOTHING ON CONFLICT DO NOTHING
`, groupId, cid, now) `, groupId, uid, now)
count++ count++
} }
br := dbConn.SendBatch(ctx, batch) br := dbConn.SendBatch(ctx, batch)
@@ -222,18 +222,18 @@ func DbAddClientsToGroup(ctx context.Context, groupId uint32, clientIds *[MaxCli
return nil return nil
} }
// DbRemoveClientsFromGroup removes given clients from group in db, silently ignores not existing members // DbRemoveUsersFromGroup removes given users from group in db, silently ignores not existing members
// return: deleted clients count, error if not successful // return: deleted users count, error if not successful
func DbRemoveClientsFromGroup(ctx context.Context, groupId uint32, clientIds *[MaxClientsInGroup]uint32) (int, error) { func DbRemoveUsersFromGroup(ctx context.Context, groupId uint32, userIds *[MaxUsersInGroup]uint32) (int, error) {
batch := &pgx.Batch{} batch := &pgx.Batch{}
var count int var count int
for _, cid := range clientIds { for _, uid := range userIds {
if cid == 0 { if uid == 0 {
continue continue
} }
batch.Queue(` batch.Queue(`
DELETE FROM chat_group_members WHERE group_id = $1 AND user_id = $2 DELETE FROM chat_group_members WHERE group_id = $1 AND user_id = $2
`, groupId, cid) `, groupId, uid)
count++ count++
} }
br := dbConn.SendBatch(ctx, batch) br := dbConn.SendBatch(ctx, batch)
@@ -249,24 +249,24 @@ func DbRemoveClientsFromGroup(ctx context.Context, groupId uint32, clientIds *[M
return deleted, nil return deleted, nil
} }
// DbSetClientGroups populates client's Groups map with ids of all groups the client belongs to from database // DbSetUserGroups populates user's Groups map with ids of all groups the user belongs to from database
// return: error if not successful // return: error if not successful
func DbSetClientGroups(ctx context.Context, client *Client) error { func DbSetUserGroups(ctx context.Context, user *User) error {
rows, err := dbConn.Query(ctx, ` rows, err := dbConn.Query(ctx, `
SELECT group_id FROM chat_group_members WHERE user_id = $1 SELECT group_id FROM chat_group_members WHERE user_id = $1
`, client.Id) `, user.Id)
if err != nil { if err != nil {
return err return err
} }
defer rows.Close() defer rows.Close()
client.Groups = make(map[uint32]struct{}) user.Groups = make(map[uint32]struct{})
for rows.Next() { for rows.Next() {
var groupId uint32 var groupId uint32
if err := rows.Scan(&groupId); err != nil { if err := rows.Scan(&groupId); err != nil {
return err return err
} }
client.Groups[groupId] = struct{}{} user.Groups[groupId] = struct{}{}
} }
return rows.Err() return rows.Err()
} }
+2 -2
View File
@@ -1,6 +1,6 @@
package main package main
const ( const (
MaxGroupsForClient uint32 = 8 MaxGroupsForUser uint32 = 8
MaxClientsInGroup uint32 = 12 MaxUsersInGroup uint32 = 12
) )
+78 -79
View File
@@ -22,23 +22,23 @@ func isMethodAllowed(response *http.ResponseWriter, request *http.Request) bool
return true return true
} }
func getClient(ctx context.Context, token string) (*Client, error) { func getUser(ctx context.Context, token string) (*User, error) {
clientId, err := TokenValidateGetId(token) userId, err := TokenValidateGetId(token)
if err != nil { if err != nil {
return nil, err return nil, err
} }
client, err := CacheGetClientById(clientId) user, err := CacheGetUserById(userId)
if err != nil { if err != nil {
client = &Client{Id: clientId} user = &User{Id: userId}
err = DbSetClientById(ctx, client) err = DbSetUserById(ctx, user)
if err != nil { if err != nil {
return nil, err return nil, err
} }
CacheSaveClient(client) CacheSaveUser(user)
} }
return client, nil return user, nil
} }
func getGroup(ctx context.Context, groupId uint32) (*Group, error) { func getGroup(ctx context.Context, groupId uint32) (*Group, error) {
@@ -54,15 +54,15 @@ func getGroup(ctx context.Context, groupId uint32) (*Group, error) {
return group, nil return group, nil
} }
func isOwner(client *Client, group *Group) bool { func isOwner(user *User, group *Group) bool {
if group.OwnerId == client.Id { if group.OwnerId == user.Id {
return true return true
} }
return false return false
} }
func getIfOwnerClientAndGroup(ctx context.Context, response *http.ResponseWriter, request *http.Request) (*Client, *Group, error) { func getIfOwnerUserAndGroup(ctx context.Context, response *http.ResponseWriter, request *http.Request) (*User, *Group, error) {
client, err := getClient(ctx, request.FormValue("token")) user, err := getUser(ctx, request.FormValue("token"))
if err != nil { if err != nil {
http.Error(*response, "invalid token", http.StatusUnauthorized) http.Error(*response, "invalid token", http.StatusUnauthorized)
return nil, nil, err return nil, nil, err
@@ -80,14 +80,14 @@ func getIfOwnerClientAndGroup(ctx context.Context, response *http.ResponseWriter
return nil, nil, err return nil, nil, err
} }
if !isOwner(client, group) { if !isOwner(user, group) {
http.Error(*response, "no such group", http.StatusUnauthorized) http.Error(*response, "no such group", http.StatusUnauthorized)
return nil, nil, err return nil, nil, err
} }
return client, group, nil return user, group, nil
} }
func HttpHandleNewClient(response http.ResponseWriter, request *http.Request) { func HttpHandleNewUser(response http.ResponseWriter, request *http.Request) {
if !isMethodAllowed(&response, request) { if !isMethodAllowed(&response, request) {
return return
} }
@@ -116,7 +116,7 @@ func HttpHandleNewClient(response http.ResponseWriter, request *http.Request) {
return return
} }
newClient := &Client{ newUser := &User{
Name: username, Name: username,
PasswordHash: hashedPassword, PasswordHash: hashedPassword,
Color: color, Color: color,
@@ -125,14 +125,13 @@ func HttpHandleNewClient(response http.ResponseWriter, request *http.Request) {
ctx := request.Context() ctx := request.Context()
err = DbSaveClientWithoutGroups(ctx, newClient) err = DbSaveUserWithoutGroups(ctx, newUser)
if err != nil { if err != nil {
http.Error(response, "name taken", http.StatusUnauthorized) http.Error(response, "name taken", http.StatusUnauthorized)
return return
} }
response.WriteHeader(http.StatusCreated) response.WriteHeader(http.StatusCreated)
response.Write([]byte("created"))
} }
func HttpHandleNewToken(response http.ResponseWriter, request *http.Request) { func HttpHandleNewToken(response http.ResponseWriter, request *http.Request) {
@@ -154,30 +153,30 @@ func HttpHandleNewToken(response http.ResponseWriter, request *http.Request) {
} }
var ( var (
client *Client user *User
err error err error
ctx = request.Context() ctx = request.Context()
) )
client, err = CacheGetClientByName(username) user, err = CacheGetUserByName(username)
if err != nil { if err != nil {
client = &Client{Name: username} user = &User{Name: username}
err := DbSetClientByName(ctx, client) err := DbSetUserByName(ctx, user)
if err != nil { if err != nil {
http.Error(response, "bad login1", http.StatusUnauthorized) http.Error(response, "bad login1", http.StatusUnauthorized)
return return
} }
CacheSaveClient(client) CacheSaveUser(user)
} }
err = bcrypt.CompareHashAndPassword([]byte(client.PasswordHash), []byte(password)) err = bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(password))
if err != nil { if err != nil {
http.Error(response, "bad login2", http.StatusUnauthorized) http.Error(response, "bad login2", http.StatusUnauthorized)
return return
} }
token, err := TokenCreate(client.Id) token, err := TokenCreate(user.Id)
if err != nil { if err != nil {
http.Error(response, "internal server error", http.StatusInternalServerError) http.Error(response, "internal server error", http.StatusInternalServerError)
return return
@@ -194,7 +193,7 @@ func HttpHandeGroupCreate(response http.ResponseWriter, request *http.Request) {
ctx := request.Context() ctx := request.Context()
client, err := getClient(ctx, request.FormValue("token")) user, err := getUser(ctx, request.FormValue("token"))
if err != nil { if err != nil {
http.Error(response, "invalid token", http.StatusUnauthorized) http.Error(response, "invalid token", http.StatusUnauthorized)
return return
@@ -211,18 +210,18 @@ func HttpHandeGroupCreate(response http.ResponseWriter, request *http.Request) {
group := Group{ group := Group{
Name: name, Name: name,
CreatedAt: time.Now(), CreatedAt: time.Now(),
OwnerId: client.Id, OwnerId: user.Id,
CreatorId: client.Id, CreatorId: user.Id,
Color: color, Color: color,
Clients: map[uint32]struct{}{client.Id: {}}, Users: map[uint32]struct{}{user.Id: {}},
} }
enableClientColors := request.FormValue("enableClientColors") enableUserColors := request.FormValue("enableUserColors")
if enableClientColors == "1" { if enableUserColors == "1" {
group.EnableClientColors = true group.EnableUserColors = true
} }
err = DbSaveGroupWithoutClients(ctx, &group) err = DbSaveGroupWithoutUsers(ctx, &group)
if err != nil { if err != nil {
http.Error(response, err.Error(), http.StatusInternalServerError) http.Error(response, err.Error(), http.StatusInternalServerError)
return return
@@ -237,7 +236,7 @@ func HttpHandleGroupRemove(response http.ResponseWriter, request *http.Request)
} }
ctx := request.Context() ctx := request.Context()
_, group, err := getIfOwnerClientAndGroup(ctx, &response, request) _, group, err := getIfOwnerUserAndGroup(ctx, &response, request)
if err != nil { if err != nil {
return return
} }
@@ -251,35 +250,35 @@ func HttpHandleGroupRemove(response http.ResponseWriter, request *http.Request)
response.WriteHeader(http.StatusAccepted) response.WriteHeader(http.StatusAccepted)
} }
func HttpHandleGroupAddClient(response http.ResponseWriter, request *http.Request) { func HttpHandleGroupAddUser(response http.ResponseWriter, request *http.Request) {
if !isMethodAllowed(&response, request) { if !isMethodAllowed(&response, request) {
return return
} }
ctx := request.Context() ctx := request.Context()
_, group, err := getIfOwnerClientAndGroup(ctx, &response, request) _, group, err := getIfOwnerUserAndGroup(ctx, &response, request)
if err != nil { if err != nil {
return return
} }
clientsString := request.FormValue("clients") usersString := request.FormValue("users")
var remainingUsersCount = int(MaxClientsInGroup) - len(group.Clients) var remainingUsersCount = int(MaxUsersInGroup) - len(group.Users)
if remainingUsersCount < 1 { if remainingUsersCount < 1 {
http.Error(response, "max users", http.StatusUnauthorized) http.Error(response, "max users", http.StatusUnauthorized)
return return
} }
clientsStringSlice := strings.SplitN(clientsString, ",", remainingUsersCount+1) usersStringSlice := strings.SplitN(usersString, ",", remainingUsersCount+1)
if len(clientsStringSlice) == 0 { if len(usersStringSlice) == 0 {
http.Error(response, "no users to add", http.StatusBadRequest) http.Error(response, "no users to add", http.StatusBadRequest)
return return
} }
var ids [MaxClientsInGroup]uint32 var ids [MaxUsersInGroup]uint32
var idx uint32 = 0 var idx uint32 = 0
for _, s := range clientsStringSlice { for _, s := range usersStringSlice {
if idx >= MaxClientsInGroup { if idx >= MaxUsersInGroup {
break break
} }
id, err := ConvertStringUint32(strings.TrimSpace(s)) id, err := ConvertStringUint32(strings.TrimSpace(s))
@@ -294,39 +293,39 @@ func HttpHandleGroupAddClient(response http.ResponseWriter, request *http.Reques
return return
} }
err = DbAddClientsToGroup(ctx, group.Id, &ids) err = DbAddUsersToGroup(ctx, group.Id, &ids)
if err != nil { if err != nil {
http.Error(response, "internal server error", http.StatusInternalServerError) http.Error(response, "internal server error", http.StatusInternalServerError)
return return
} }
for i := uint32(0); i < idx; i++ { for i := uint32(0); i < idx; i++ {
group.Clients[ids[i]] = struct{}{} group.Users[ids[i]] = struct{}{}
} }
response.WriteHeader(http.StatusAccepted) response.WriteHeader(http.StatusAccepted)
} }
func HttpHandleGroupRemoveClient(response http.ResponseWriter, request *http.Request) { func HttpHandleGroupRemoveUser(response http.ResponseWriter, request *http.Request) {
if !isMethodAllowed(&response, request) { if !isMethodAllowed(&response, request) {
return return
} }
ctx := request.Context() ctx := request.Context()
_, group, err := getIfOwnerClientAndGroup(ctx, &response, request) _, group, err := getIfOwnerUserAndGroup(ctx, &response, request)
if err != nil { if err != nil {
return return
} }
clientsString := request.FormValue("clients") usersString := request.FormValue("users")
clientsStringSlice := strings.SplitN(clientsString, ",", int(MaxClientsInGroup)+1) usersStringSlice := strings.SplitN(usersString, ",", int(MaxUsersInGroup)+1)
var ids [MaxClientsInGroup]uint32 var ids [MaxUsersInGroup]uint32
var idx uint32 = 0 var idx uint32 = 0
for _, s := range clientsStringSlice { for _, s := range usersStringSlice {
if idx >= MaxClientsInGroup { if idx >= MaxUsersInGroup {
break break
} }
id, err := ConvertStringUint32(strings.TrimSpace(s)) id, err := ConvertStringUint32(strings.TrimSpace(s))
@@ -341,14 +340,14 @@ func HttpHandleGroupRemoveClient(response http.ResponseWriter, request *http.Req
return return
} }
count, err := DbRemoveClientsFromGroup(ctx, group.Id, &ids) count, err := DbRemoveUsersFromGroup(ctx, group.Id, &ids)
if err != nil { if err != nil {
http.Error(response, "internal server error", http.StatusInternalServerError) http.Error(response, "internal server error", http.StatusInternalServerError)
return return
} }
for i := uint32(0); i < idx; i++ { for i := uint32(0); i < idx; i++ {
delete(group.Clients, ids[i]) delete(group.Users, ids[i])
} }
response.WriteHeader(http.StatusAccepted) response.WriteHeader(http.StatusAccepted)
@@ -361,7 +360,7 @@ func HttpHandleGroupChangeColor(response http.ResponseWriter, request *http.Requ
} }
ctx := request.Context() ctx := request.Context()
_, group, err := getIfOwnerClientAndGroup(ctx, &response, request) _, group, err := getIfOwnerUserAndGroup(ctx, &response, request)
if err != nil { if err != nil {
return return
} }
@@ -389,28 +388,28 @@ func HttpHandleGroupChangeOwner(response http.ResponseWriter, request *http.Requ
} }
ctx := request.Context() ctx := request.Context()
client, group, err := getIfOwnerClientAndGroup(ctx, &response, request) user, group, err := getIfOwnerUserAndGroup(ctx, &response, request)
if err != nil { if err != nil {
return return
} }
newOwnerName := request.FormValue("newOwner") newOwnerName := request.FormValue("newOwner")
newOwner, err := CacheGetClientByName(newOwnerName) newOwner, err := CacheGetUserByName(newOwnerName)
if err != nil { if err != nil {
newOwner = &Client{Name: newOwnerName} newOwner = &User{Name: newOwnerName}
err = DbSetClientByName(ctx, newOwner) err = DbSetUserByName(ctx, newOwner)
if err != nil { if err != nil {
http.Error(response, "client not in group", http.StatusBadRequest) http.Error(response, "user not in group", http.StatusBadRequest)
return return
} }
CacheSaveClient(client) CacheSaveUser(user)
} }
_, ok := group.Clients[newOwner.Id] _, ok := group.Users[newOwner.Id]
if !ok { if !ok {
http.Error(response, "client not in group", http.StatusBadRequest) http.Error(response, "user not in group", http.StatusBadRequest)
return return
} }
@@ -430,7 +429,7 @@ func HttpHandleNewMessage(response http.ResponseWriter, request *http.Request) {
} }
ctx := request.Context() ctx := request.Context()
client, err := getClient(ctx, request.FormValue("token")) user, err := getUser(ctx, request.FormValue("token"))
if err != nil { if err != nil {
http.Error(response, "invalid token", http.StatusUnauthorized) http.Error(response, "invalid token", http.StatusUnauthorized)
return return
@@ -453,7 +452,7 @@ func HttpHandleNewMessage(response http.ResponseWriter, request *http.Request) {
return return
} }
err = WsSendToGroup(ctx, targetId, client.Id, content) err = WsSendToGroup(ctx, targetId, user.Id, content)
if err != nil { if err != nil {
http.Error(response, err.Error(), http.StatusBadRequest) http.Error(response, err.Error(), http.StatusBadRequest)
return return
@@ -468,28 +467,28 @@ func HttpHandleGroupsGetWithoutMembers(response http.ResponseWriter, request *ht
ctx := request.Context() ctx := request.Context()
client, err := getClient(ctx, request.FormValue("token")) user, err := getUser(ctx, request.FormValue("token"))
if err != nil { if err != nil {
http.Error(response, "invalid token", http.StatusUnauthorized) http.Error(response, "invalid token", http.StatusUnauthorized)
return return
} }
groups := make([]GroupNoMembers, 0, len(client.Groups)) groups := make([]GroupNoMembers, 0, len(user.Groups))
for groupId := range client.Groups { for groupId := range user.Groups {
group, err := getGroup(ctx, groupId) group, err := getGroup(ctx, groupId)
if err != nil { if err != nil {
continue continue
} }
groups = append(groups, GroupNoMembers{ groups = append(groups, GroupNoMembers{
Id: groupId, Id: groupId,
Name: group.Name, Name: group.Name,
CreatedAt: group.CreatedAt, CreatedAt: group.CreatedAt,
CreatorId: group.CreatorId, CreatorId: group.CreatorId,
OwnerId: group.OwnerId, OwnerId: group.OwnerId,
Color: group.Color, Color: group.Color,
EnableClientsColors: group.EnableClientColors, EnableUsersColors: group.EnableUserColors,
}) })
} }
@@ -509,7 +508,7 @@ func HttpHandleGroupMembersGet(response http.ResponseWriter, request *http.Reque
} }
ctx := request.Context() ctx := request.Context()
client, err := getClient(ctx, request.FormValue("token")) user, err := getUser(ctx, request.FormValue("token"))
if err != nil { if err != nil {
http.Error(response, "invalid token", http.StatusUnauthorized) http.Error(response, "invalid token", http.StatusUnauthorized)
return return
@@ -522,7 +521,7 @@ func HttpHandleGroupMembersGet(response http.ResponseWriter, request *http.Reque
return return
} }
_, ok := client.Groups[groupId] _, ok := user.Groups[groupId]
if !ok { if !ok {
http.Error(response, "no such group", http.StatusUnauthorized) http.Error(response, "no such group", http.StatusUnauthorized)
return return
@@ -534,7 +533,7 @@ func HttpHandleGroupMembersGet(response http.ResponseWriter, request *http.Reque
return return
} }
groupMembers := slices.Collect(maps.Keys(group.Clients)) groupMembers := slices.Collect(maps.Keys(group.Users))
json, err := json2.Marshal(groupMembers) json, err := json2.Marshal(groupMembers)
if err != nil { if err != nil {
+3 -3
View File
@@ -17,12 +17,12 @@ func main() {
ctx := context.Background() ctx := context.Background()
DbInit(ctx) DbInit(ctx)
http.HandleFunc("/new/client", withCORS(HttpHandleNewClient)) http.HandleFunc("/new/user", withCORS(HttpHandleNewUser))
http.HandleFunc("/new/token", withCORS(HttpHandleNewToken)) http.HandleFunc("/new/token", withCORS(HttpHandleNewToken))
http.HandleFunc("/new/group", withCORS(HttpHandeGroupCreate)) http.HandleFunc("/new/group", withCORS(HttpHandeGroupCreate))
http.HandleFunc("/new/message", withCORS(HttpHandleNewMessage)) http.HandleFunc("/new/message", withCORS(HttpHandleNewMessage))
http.HandleFunc("/mod/group/addclients", withCORS(HttpHandleGroupAddClient)) http.HandleFunc("/mod/group/addusers", withCORS(HttpHandleGroupAddUser))
http.HandleFunc("/mod/group/removeclients", withCORS(HttpHandleGroupRemoveClient)) http.HandleFunc("/mod/group/removeusers", withCORS(HttpHandleGroupRemoveUser))
http.HandleFunc("/mod/group/color", withCORS(HttpHandleGroupChangeColor)) http.HandleFunc("/mod/group/color", withCORS(HttpHandleGroupChangeColor))
http.HandleFunc("/mod/group/owner", withCORS(HttpHandleGroupChangeOwner)) http.HandleFunc("/mod/group/owner", withCORS(HttpHandleGroupChangeOwner))
http.HandleFunc("/get/groups", withCORS(HttpHandleGroupsGetWithoutMembers)) http.HandleFunc("/get/groups", withCORS(HttpHandleGroupsGetWithoutMembers))
+16 -16
View File
@@ -6,7 +6,7 @@ import (
"github.com/coder/websocket" "github.com/coder/websocket"
) )
type Client struct { type User struct {
Name string Name string
Pronouns string Pronouns string
PasswordHash string PasswordHash string
@@ -18,22 +18,22 @@ type Client struct {
} }
type Group struct { type Group struct {
Name string Name string
CreatedAt time.Time CreatedAt time.Time
Id uint32 Id uint32
CreatorId uint32 CreatorId uint32
OwnerId uint32 OwnerId uint32
Clients map[uint32]struct{} Users map[uint32]struct{}
Color [3]uint8 Color [3]uint8
EnableClientColors bool EnableUserColors bool
} }
type GroupNoMembers struct { type GroupNoMembers struct {
Name string `json:"name"` Name string `json:"name"`
CreatedAt time.Time `json:"createdAt"` CreatedAt time.Time `json:"createdAt"`
Id uint32 `json:"id"` Id uint32 `json:"id"`
CreatorId uint32 `json:"creatorId"` CreatorId uint32 `json:"creatorId"`
OwnerId uint32 `json:"ownerId"` OwnerId uint32 `json:"ownerId"`
Color [3]uint8 `json:"color"` Color [3]uint8 `json:"color"`
EnableClientsColors bool `json:"enableClientsColors"` EnableUsersColors bool `json:"enableUsersColors"`
} }
+2 -2
View File
@@ -11,10 +11,10 @@ import (
const tokenSecret = "tmp" // TODO delete in production const tokenSecret = "tmp" // TODO delete in production
const tokenExpiration = time.Hour const tokenExpiration = time.Hour
func TokenCreate(clientId uint32) (string, error) { func TokenCreate(userId uint32) (string, error) {
now := time.Now() now := time.Now()
signedToken, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ signedToken, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{
Subject: strconv.FormatUint(uint64(clientId), 10), Subject: strconv.FormatUint(uint64(userId), 10),
IssuedAt: jwt.NewNumericDate(now), IssuedAt: jwt.NewNumericDate(now),
ExpiresAt: jwt.NewNumericDate(now.Add(tokenExpiration)), ExpiresAt: jwt.NewNumericDate(now.Add(tokenExpiration)),
}).SignedString([]byte(tokenSecret)) }).SignedString([]byte(tokenSecret))
+47 -47
View File
@@ -23,26 +23,26 @@ func ServeWsConnection(responseWriter http.ResponseWriter, request *http.Request
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
var client = Client{WsConn: connection} var user = User{WsConn: connection}
var isAuthenticated bool var isAuthenticated bool
var ignoreCache bool var ignoreCache bool
defer closeConnection(&client, ignoreCache) defer closeConnection(&user, ignoreCache)
for { for {
var clientMessage map[string]any var userMessage map[string]any
err = wsjson.Read(ctx, connection, &clientMessage) err = wsjson.Read(ctx, connection, &userMessage)
if err != nil { if err != nil {
log.Printf("read error: %v", err) log.Printf("read error: %v", err)
return return
} }
if len(clientMessage) > 0 { if len(userMessage) > 0 {
if isAuthenticated { if isAuthenticated {
if !handleAuthenticatedMessage(&client, &clientMessage) { if !handleAuthenticatedMessage(&user, &userMessage) {
return return
} }
} else { } else {
if !handleUnauthenticatedMessage(ctx, &client, &clientMessage) { if !handleUnauthenticatedMessage(ctx, &user, &userMessage) {
ignoreCache = true ignoreCache = true
return return
} }
@@ -52,18 +52,18 @@ func ServeWsConnection(responseWriter http.ResponseWriter, request *http.Request
} }
} }
func sendMessageCloseIfTimeout(client *Client, message *map[string]any) { func sendMessageCloseIfTimeout(user *User, message *map[string]any) {
if client.WsConn == nil { if user.WsConn == nil {
return return
} }
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel() defer cancel()
err := wsjson.Write(ctx, client.WsConn, message) err := wsjson.Write(ctx, user.WsConn, message)
if err != nil { if err != nil {
if errors.Is(err, context.DeadlineExceeded) { if errors.Is(err, context.DeadlineExceeded) {
closeConnection(client, false) closeConnection(user, false)
} }
log.Printf("write error: %v", err) log.Printf("write error: %v", err)
} }
@@ -72,8 +72,8 @@ func sendMessageCloseIfTimeout(client *Client, message *map[string]any) {
func sendToAllMessageCloseIfTimeout(message *map[string]any) { func sendToAllMessageCloseIfTimeout(message *map[string]any) {
mu.RLock() mu.RLock()
defer mu.RUnlock() defer mu.RUnlock()
for _, client := range CacheClients { for _, user := range CacheUsers {
sendMessageCloseIfTimeout(client, message) sendMessageCloseIfTimeout(user, message)
} }
} }
@@ -83,88 +83,88 @@ func WsSendToGroup(ctx context.Context, groupId uint32, senderId uint32, message
return errors.New("group invalid") return errors.New("group invalid")
} }
client, err := CacheGetClientById(senderId) sender, err := CacheGetUserById(senderId)
if err != nil { if err != nil {
client = &Client{Id: senderId} sender = &User{Id: senderId}
err = DbSetClientById(ctx, client) err = DbSetUserById(ctx, sender)
if err != nil { if err != nil {
return errors.New("non existing sender") return errors.New("non existing sender")
} }
} }
for groupClientId := range group.Clients { for groupUserId := range group.Users {
groupClient, err := CacheGetClientById(groupClientId) groupUser, err := CacheGetUserById(groupUserId)
if err != nil || groupClient.Id == client.Id { if err != nil || groupUser.Id == sender.Id {
continue continue
} }
var msg = map[string]any{ var msg = map[string]any{
"from": "group", "from": "group",
"group": group.Id, "group": group.Id,
"sender": client.Name, "sender": sender.Name,
"content": message, "content": message,
} }
sendMessageCloseIfTimeout(groupClient, &msg) sendMessageCloseIfTimeout(groupUser, &msg)
} }
return nil return nil
} }
func handleAuthenticatedMessage(client *Client, clientMessage *map[string]any) bool { func handleAuthenticatedMessage(user *User, userMessage *map[string]any) bool {
sendMessageCloseIfTimeout(client, clientMessage) sendMessageCloseIfTimeout(user, userMessage)
return true return true
} }
func handleUnauthenticatedMessage(ctx context.Context, client *Client, clientMessage *map[string]any) bool { func handleUnauthenticatedMessage(ctx context.Context, user *User, userMessage *map[string]any) bool {
token, ok := (*clientMessage)["token"].(string) token, ok := (*userMessage)["token"].(string)
if !ok { if !ok {
var msg = map[string]any{ var msg = map[string]any{
"from": "server", "from": "server",
"error": "no token in message", "error": "no token in message",
} }
sendMessageCloseIfTimeout(client, &msg) sendMessageCloseIfTimeout(user, &msg)
return false return false
} }
clientId, err := TokenValidateGetId(token) userId, err := TokenValidateGetId(token)
if err != nil { if err != nil {
var msg = map[string]any{ var msg = map[string]any{
"from": "server", "from": "server",
"error": "invalid token", "error": "invalid token",
} }
sendMessageCloseIfTimeout(client, &msg) sendMessageCloseIfTimeout(user, &msg)
return false return false
} }
clientFromCache, err := CacheGetClientById(clientId) userFromCache, err := CacheGetUserById(userId)
if err != nil { if err != nil {
dbClient := &Client{Id: clientId} dbUser := &User{Id: userId}
err = DbSetClientByIdWithoutGroups(ctx, dbClient) err = DbSetUserByIdWithoutGroups(ctx, dbUser)
if err != nil { if err != nil {
var msg = map[string]any{ var msg = map[string]any{
"from": "server", "from": "server",
"error": "invalid client data", "error": "invalid user data",
} }
sendMessageCloseIfTimeout(client, &msg) sendMessageCloseIfTimeout(user, &msg)
return false return false
} }
err = DbSetClientGroups(ctx, dbClient) err = DbSetUserGroups(ctx, dbUser)
if err != nil { if err != nil {
var msg = map[string]any{ var msg = map[string]any{
"from": "server", "from": "server",
"error": "invalid client data", "error": "invalid user data",
} }
sendMessageCloseIfTimeout(client, &msg) sendMessageCloseIfTimeout(user, &msg)
return false return false
} }
dbClient.WsConn = client.WsConn dbUser.WsConn = user.WsConn
CacheSaveClient(dbClient) CacheSaveUser(dbUser)
clientFromCache = dbClient userFromCache = dbUser
} }
clientFromCache.WsConn = client.WsConn userFromCache.WsConn = user.WsConn
*client = *clientFromCache *user = *userFromCache
for groupId, _ := range clientFromCache.Groups { for groupId, _ := range userFromCache.Groups {
_, err = CacheGetGroup(groupId) _, err = CacheGetGroup(groupId)
if err != nil { if err != nil {
dbGroup := &Group{Id: groupId} dbGroup := &Group{Id: groupId}
@@ -173,9 +173,9 @@ func handleUnauthenticatedMessage(ctx context.Context, client *Client, clientMes
if err != nil { if err != nil {
var msg = map[string]any{ var msg = map[string]any{
"from": "server", "from": "server",
"error": "invalid client data", "error": "invalid user data",
} }
sendMessageCloseIfTimeout(client, &msg) sendMessageCloseIfTimeout(user, &msg)
return false return false
} }
CacheSaveGroup(dbGroup) CacheSaveGroup(dbGroup)
@@ -184,9 +184,9 @@ func handleUnauthenticatedMessage(ctx context.Context, client *Client, clientMes
return true return true
} }
func closeConnection(client *Client, ignoreCache bool) { func closeConnection(user *User, ignoreCache bool) {
if !ignoreCache { if !ignoreCache {
CacheDeleteClient(client.Id) CacheDeleteUser(user.Id)
} }
client.WsConn.CloseNow() user.WsConn.CloseNow()
} }