diff --git a/cache.go b/cache.go index 4219dfa..6bffd65 100644 --- a/cache.go +++ b/cache.go @@ -6,8 +6,8 @@ import ( "sync" ) -var Groups map[uint32]ChatGroup -var ConnectedClients map[uint32]map[*Client]struct{} +var Groups map[uint64]ChatGroup +var ConnectedClients map[uint64]map[*Client]struct{} func InitCache() { groups, err := GetAllChatGroups(context.Background()) @@ -20,7 +20,7 @@ func InitCache() { } } -func GetGroupById(groupId uint32) (*ChatGroup, error) { +func GetGroupById(groupId uint64) (*ChatGroup, error) { group, ok := Groups[groupId] if !ok { return nil, errors.New("group not found") @@ -31,18 +31,21 @@ func GetGroupById(groupId uint32) (*ChatGroup, error) { func AddOrUpdateGroupToCache(mu *sync.Mutex, group ChatGroup) { mu.Lock() defer mu.Unlock() + Groups[group.Id] = group } -func RemoveGroupFromCache(mu *sync.Mutex, groupId uint32) { +func RemoveGroupFromCache(mu *sync.Mutex, groupId uint64) { mu.Lock() defer mu.Unlock() + delete(Groups, groupId) } func AddOrUpdateConnectedClientToCache(mu *sync.Mutex, client *Client) { mu.Lock() defer mu.Unlock() + for _, groupId := range client.User.MemberGroupsId { ConnectedClients[groupId][client] = struct{}{} } @@ -51,7 +54,24 @@ func AddOrUpdateConnectedClientToCache(mu *sync.Mutex, client *Client) { func RemoveConnectedClientFromCache(mu *sync.Mutex, client *Client) { mu.Lock() defer mu.Unlock() + for _, groupId := range client.User.MemberGroupsId { delete(ConnectedClients[groupId], client) } } + +func IsUserInGivenGroup(mu *sync.Mutex, userId uint64, groupId uint64) bool { + mu.Lock() + defer mu.Unlock() + + group, ok := ConnectedClients[groupId] + if !ok { + return false + } + for client := range group { + if client.User.Id == userId { + return true + } + } + return false +} diff --git a/database.go b/database.go index 777e09e..894654b 100644 --- a/database.go +++ b/database.go @@ -48,8 +48,8 @@ func InitDatabase(ctx context.Context) { InitCache() } -func AddNewUser(ctx context.Context, user *User) (uint32, error) { - var id uint32 +func AddNewUser(ctx context.Context, user *User) (uint64, error) { + var id uint64 var err error if len(user.Name) == 0 || len(user.Name) > 20 { @@ -68,17 +68,18 @@ func AddNewUser(ctx context.Context, user *User) (uint32, error) { user.Password = string(hashed) } if user.Color == ([3]byte{}) { - return 0, errors.New("color invalid") + user.Color = [3]byte{'x', 'x', 'x'} } + c := string(user.Color[:]) err = dbConnection.QueryRow(ctx, ` INSERT INTO users (name, pass_hash, color) VALUES ($1, $2, $3) RETURNING id - `, user.Name, user.Password, user.Color).Scan(&id) + `, user.Name, user.Password, c).Scan(&id) return id, err } -func isPassValid(ctx context.Context, id uint32, plainPassword string) bool { +func isPassValid(ctx context.Context, id uint64, plainPassword string) bool { var controlHash string err := dbConnection.QueryRow(ctx, "SELECT pass_hash FROM users WHERE id = $1", id).Scan(&controlHash) if err != nil { @@ -106,7 +107,7 @@ func GetAllUsers(ctx context.Context) ([]User, error) { return users, rows.Err() } -func GetUserDataById(ctx context.Context, id uint32) (User, error) { +func GetUserDataById(ctx context.Context, id uint64) (User, error) { var user User err := dbConnection.QueryRow(ctx, "SELECT id, name, pass_hash, color FROM users WHERE id = $1", id). Scan(&user.Id, &user.Name, &user.Password, &user.Color) @@ -127,7 +128,7 @@ func GetUserDataByName(ctx context.Context, name string) (User, error) { return user, nil } -func CreateChatGroupWithoutMembers(ctx context.Context, group *ChatGroup) (uint32, error) { +func CreateChatGroupWithoutMembers(ctx context.Context, group *ChatGroup) (uint64, error) { if len(group.Name) < 1 { return 0, errors.New("group name too short") } @@ -135,12 +136,16 @@ func CreateChatGroupWithoutMembers(ctx context.Context, group *ChatGroup) (uint3 return 0, errors.New("group name too long") } - var id uint32 + var id uint64 err := dbConnection.QueryRow(ctx, `INSERT INTO chat_groups (name, creator_id, owner_id, created_at ) VALUES ($1, $2, $3, $4) RETURNING id `, group.Name, group.CreatorId, group.OwnerId, group.CreatedAt).Scan(&id) + if err != nil { + return 0, err + } + AddOrUpdateGroupToCache(&mu, *group) return id, err } @@ -163,14 +168,32 @@ func GetAllChatGroups(ctx context.Context) ([]ChatGroup, error) { return groups, rows.Err() } -func GetChatGroupWithoutMembers(ctx context.Context, id uint32) (ChatGroup, error) { +func GetChatGroupWithoutMembers(ctx context.Context, id uint64) (ChatGroup, error) { var group ChatGroup err := dbConnection.QueryRow(ctx, `SELECT name, creator_id, owner_id, enable_user_colors, group_color, created_at FROM chat_groups WHERE id = $1`, id).Scan(&group.Name, &group.CreatorId, &group.OwnerId, &group.EnableUserColors, &group.Color, &group.CreatedAt) return group, err } -func GetChatGroupMembers(ctx context.Context, groupId uint32) ([]User, error) { +func GetUserMemberGroupIds(ctx context.Context, userId uint64) ([]uint64, error) { + rows, err := dbConnection.Query(ctx, "SELECT group_id FROM chat_group_members WHERE user_id = $1", userId) + if err != nil { + return nil, err + } + defer rows.Close() + + var groupIds []uint64 + for rows.Next() { + var groupId uint64 + if err := rows.Scan(&groupId); err != nil { + return nil, err + } + groupIds = append(groupIds, groupId) + } + return groupIds, rows.Err() +} + +func GetChatGroupMembers(ctx context.Context, groupId uint64) ([]User, error) { rows, err := dbConnection.Query(ctx, ` SELECT usr.id, usr.name, usr.color FROM users usr JOIN chat_group_members members ON usr.id = members.user_id diff --git a/enums.go b/enums.go new file mode 100644 index 0000000..88dcf59 --- /dev/null +++ b/enums.go @@ -0,0 +1,8 @@ +package main + +type serverResponseType struct { + MessageFromUser uint8 + BadRequest uint8 +} + +var ServerResponseType = serverResponseType{0, 1} diff --git a/go-socket b/go-socket index 27a7243..5adad46 100755 Binary files a/go-socket and b/go-socket differ diff --git a/http.go b/http.go index 066694e..eaf355d 100644 --- a/http.go +++ b/http.go @@ -3,6 +3,7 @@ package main import ( "log" "net/http" + "strconv" "time" "golang.org/x/crypto/bcrypt" @@ -30,6 +31,7 @@ func RegisterHandler(response http.ResponseWriter, request *http.Request) { } if username == "server" { http.Error(response, "only server can use such name", http.StatusBadRequest) + return } if len(password) < 8 { http.Error(response, "short or no password", http.StatusBadRequest) @@ -72,12 +74,14 @@ func LoginHandler(response http.ResponseWriter, request *http.Request) { } if len(username) < 2 { + log.Printf("username<2") respondBadLogin() return } user, err := GetUserDataByName(ctx, username) if err != nil { + log.Printf("could not get user: %v", err) respondBadLogin() return } @@ -94,6 +98,7 @@ func LoginHandler(response http.ResponseWriter, request *http.Request) { } return } + log.Printf("bad hash") respondBadLogin() } @@ -129,10 +134,11 @@ func CreateGroupHandler(response http.ResponseWriter, request *http.Request) { response.WriteHeader(http.StatusCreated) } -func SendMessageHandler(response http.ResponseWriter, request *http.Request) { - groupId := request.PathValue("groupid") - if groupId == "" { +func SendMessageToGroupHandler(response http.ResponseWriter, request *http.Request) { + groupIdString := request.PathValue("groupid") + if groupIdString == "" { http.Error(response, "no group id", http.StatusBadRequest) + return } var user User @@ -142,7 +148,7 @@ func SendMessageHandler(response http.ResponseWriter, request *http.Request) { http.Error(response, "no token", http.StatusBadRequest) return } - if user, err = GetUserFromToken(token); err != nil || user == nil { + if user, err = GetUserFromToken(token); err != nil { http.Error(response, "invalid token", http.StatusUnauthorized) return } @@ -153,8 +159,35 @@ func SendMessageHandler(response http.ResponseWriter, request *http.Request) { return } - var isInGroup bool - for _, groupId := range user.MemberGroupsId { + groupId, err := strconv.ParseUint(groupIdString, 10, 64) + if err != nil { + http.Error(response, "no such group", http.StatusBadRequest) + return + } + groupIds, err := GetUserMemberGroupIds(request.Context(), user.Id) + if err != nil { + http.Error(response, "internal server error", http.StatusInternalServerError) + return + } + isMember := false + for _, id := range groupIds { + if id == groupId { + isMember = true + break + } + } + if isMember { + var message = map[string]any{ + "type": ServerResponseType.MessageFromUser, + "content": content, + "from": user, + "time": time.Now().Unix(), + } + err := sendToGroup(groupId, user.Id, &message) + if err != nil { + http.Error(response, "internal server error", http.StatusInternalServerError) + return + } } } diff --git a/main.go b/main.go index a34e5d4..cafce58 100644 --- a/main.go +++ b/main.go @@ -32,7 +32,7 @@ func main() { http.HandleFunc("POST /new/account", RegisterHandler) http.HandleFunc("POST /new/token", LoginHandler) http.HandleFunc("POST /new/group", CreateGroupHandler) - http.HandleFunc("POST /new/messageto/{groupid}", SendMessageHandler) + http.HandleFunc("POST /new/messageto/group/{groupid}", SendMessageToGroupHandler) log.Fatal(http.ListenAndServe(":8080", nil)) } diff --git a/structures.go b/structures.go index 08fde35..77b4a14 100644 --- a/structures.go +++ b/structures.go @@ -7,11 +7,11 @@ import ( ) type User struct { - MemberGroupsId []uint32 + MemberGroupsId []uint64 Name string Password string Color [3]byte - Id uint32 + Id uint64 IsPasswordHashed bool } type Client struct { @@ -23,9 +23,9 @@ type ChatGroup struct { Members []User CreatedAt time.Time Name string - Id uint32 - CreatorId uint32 - OwnerId uint32 + Id uint64 + CreatorId uint64 + OwnerId uint64 Color [3]byte EnableUserColors bool } diff --git a/tokens.go b/tokens.go index 93bb2a0..4471f63 100644 --- a/tokens.go +++ b/tokens.go @@ -48,7 +48,7 @@ func GetUserFromToken(tokenString string) (User, error) { return User{}, errors.New("invalid token") } - id, err := strconv.ParseUint(claims.Subject, 10, 32) + id, err := strconv.ParseUint(claims.Subject, 10, 64) if err != nil { return User{}, fmt.Errorf("invalid subject: %w", err) } @@ -56,7 +56,7 @@ func GetUserFromToken(tokenString string) (User, error) { var color [3]byte copy(color[:], claims.Color) return User{ - Id: uint32(id), + Id: id, Name: claims.Name, Color: color, }, nil diff --git a/wsServer.go b/wsServer.go index 52a360c..cc3c93d 100644 --- a/wsServer.go +++ b/wsServer.go @@ -68,7 +68,7 @@ func sendAndCloseIfFails(conn *websocket.Conn, message *map[string]any) { conn.Close(websocket.StatusGoingAway, "Write error") } } -func sendToGroup(id uint32, excludedUserId uint32, message *map[string]any) error { +func sendToGroup(id uint64, excludedUserId uint64, message *map[string]any) error { if _, ok := Groups[id]; !ok { return errors.New("Group Not Found") } @@ -82,12 +82,23 @@ func sendToGroup(id uint32, excludedUserId uint32, message *map[string]any) erro } func handleUnauthenticatedMessage(client *Client, msg map[string]any) { - token := msg["token"].(string) + token, ok := msg["token"].(string) + if !ok { + client.conn.Close(websocket.StatusGoingAway, "invalid token") + return + } + user, err := GetUserFromToken(token) if err != nil { client.conn.Close(websocket.StatusPolicyViolation, "invalid token") return } + groupIds, err := GetUserMemberGroupIds(context.Background(), user.Id) + if err != nil { + client.conn.Close(websocket.StatusInternalError, "internal error") + return + } + user.MemberGroupsId = groupIds client.User = &user m := map[string]any{ "authAs": user.Name,