diff --git a/cache.go b/cache.go deleted file mode 100644 index 6bffd65..0000000 --- a/cache.go +++ /dev/null @@ -1,77 +0,0 @@ -package main - -import ( - "context" - "errors" - "sync" -) - -var Groups map[uint64]ChatGroup -var ConnectedClients map[uint64]map[*Client]struct{} - -func InitCache() { - groups, err := GetAllChatGroups(context.Background()) - if err != nil { - panic(err) - } - - for _, group := range groups { - Groups[group.Id] = group - } -} - -func GetGroupById(groupId uint64) (*ChatGroup, error) { - group, ok := Groups[groupId] - if !ok { - return nil, errors.New("group not found") - } - return &group, nil -} - -func AddOrUpdateGroupToCache(mu *sync.Mutex, group ChatGroup) { - mu.Lock() - defer mu.Unlock() - - Groups[group.Id] = group -} - -func RemoveGroupFromCache(mu *sync.Mutex, groupId uint64) { - mu.Lock() - defer mu.Unlock() - - delete(Groups, groupId) -} - -func AddOrUpdateConnectedClientToCache(mu *sync.Mutex, client *Client) { - mu.Lock() - defer mu.Unlock() - - for _, groupId := range client.User.MemberGroupsId { - ConnectedClients[groupId][client] = struct{}{} - } -} - -func RemoveConnectedClientFromCache(mu *sync.Mutex, client *Client) { - mu.Lock() - defer mu.Unlock() - - for _, groupId := range client.User.MemberGroupsId { - delete(ConnectedClients[groupId], client) - } -} - -func IsUserInGivenGroup(mu *sync.Mutex, userId uint64, groupId uint64) bool { - mu.Lock() - defer mu.Unlock() - - group, ok := ConnectedClients[groupId] - if !ok { - return false - } - for client := range group { - if client.User.Id == userId { - return true - } - } - return false -} diff --git a/database.go b/database.go deleted file mode 100644 index 894654b..0000000 --- a/database.go +++ /dev/null @@ -1,216 +0,0 @@ -package main - -import ( - "context" - "errors" - - "golang.org/x/crypto/bcrypt" - - "github.com/jackc/pgx/v5" -) - -var dbConnection *pgx.Conn - -func InitDatabase(ctx context.Context) { - conn, err := pgx.Connect(ctx, "postgres://master:secret@localhost:5432") - if err != nil { - panic(err) - } - - _, err = conn.Exec(ctx, ` - CREATE TABLE IF NOT EXISTS users ( - id SERIAL PRIMARY KEY, - name VARCHAR(20) UNIQUE NOT NULL, - pass_hash VARCHAR(60) NOT NULL, - color VARCHAR(3) DEFAULT NULL - ); - CREATE TABLE IF NOT EXISTS chat_groups ( - id SERIAL PRIMARY KEY, - name VARCHAR(48) NOT NULL, - creator_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE, - owner_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE, - enable_user_colors BOOLEAN NOT NULL DEFAULT true, - group_color VARCHAR(3), - created_at TIMESTAMP NOT NULL DEFAULT NOW() - ); - CREATE TABLE IF NOT EXISTS chat_group_members ( - group_id INTEGER NOT NULL REFERENCES chat_groups(id) ON DELETE CASCADE, - user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE, - joined_at TIMESTAMP NOT NULL DEFAULT NOW(), - PRIMARY KEY (group_id, user_id) - ); - `) - if err != nil { - panic(err) - } - - dbConnection = conn - InitCache() -} - -func AddNewUser(ctx context.Context, user *User) (uint64, error) { - var id uint64 - var err error - - if len(user.Name) == 0 || len(user.Name) > 20 { - return 0, errors.New("username bad length") - } - if user.IsPasswordHashed { - if len(user.Password) != 60 { - return 0, errors.New("password bad length") - } - } else { - var hashed []byte - hashed, err = bcrypt.GenerateFromPassword([]byte(user.Password), bcrypt.DefaultCost) - if err != nil { - return 0, err - } - user.Password = string(hashed) - } - if user.Color == ([3]byte{}) { - user.Color = [3]byte{'x', 'x', 'x'} - } - c := string(user.Color[:]) - err = dbConnection.QueryRow(ctx, ` - INSERT INTO users (name, pass_hash, color) - VALUES ($1, $2, $3) - RETURNING id - `, user.Name, user.Password, c).Scan(&id) - return id, err -} - -func isPassValid(ctx context.Context, id uint64, plainPassword string) bool { - var controlHash string - err := dbConnection.QueryRow(ctx, "SELECT pass_hash FROM users WHERE id = $1", id).Scan(&controlHash) - if err != nil { - return false - } - - return bcrypt.CompareHashAndPassword([]byte(controlHash), []byte(plainPassword)) == nil -} - -func GetAllUsers(ctx context.Context) ([]User, error) { - rows, err := dbConnection.Query(ctx, "SELECT id, name, color FROM users") - if err != nil { - return nil, err - } - defer rows.Close() - - var users []User - for rows.Next() { - var user User - if err := rows.Scan(&user.Id, &user.Name, &user.Color); err != nil { - return nil, err - } - users = append(users, user) - } - return users, rows.Err() -} - -func GetUserDataById(ctx context.Context, id uint64) (User, error) { - var user User - err := dbConnection.QueryRow(ctx, "SELECT id, name, pass_hash, color FROM users WHERE id = $1", id). - Scan(&user.Id, &user.Name, &user.Password, &user.Color) - if err != nil { - return User{}, err - } - user.IsPasswordHashed = true - return user, nil -} -func GetUserDataByName(ctx context.Context, name string) (User, error) { - var user User - err := dbConnection.QueryRow(ctx, "SELECT id, name, pass_hash, color FROM users WHERE name = $1", name). - Scan(&user.Id, &user.Name, &user.Password, &user.Color) - if err != nil { - return User{}, err - } - user.IsPasswordHashed = true - return user, nil -} - -func CreateChatGroupWithoutMembers(ctx context.Context, group *ChatGroup) (uint64, error) { - if len(group.Name) < 1 { - return 0, errors.New("group name too short") - } - if len(group.Name) > 48 { - return 0, errors.New("group name too long") - } - - var id uint64 - err := dbConnection.QueryRow(ctx, `INSERT INTO chat_groups (name, creator_id, owner_id, created_at ) - VALUES ($1, $2, $3, $4) - RETURNING id - `, group.Name, group.CreatorId, group.OwnerId, group.CreatedAt).Scan(&id) - - if err != nil { - return 0, err - } - - AddOrUpdateGroupToCache(&mu, *group) - return id, err -} - -func GetAllChatGroups(ctx context.Context) ([]ChatGroup, error) { - rows, err := dbConnection.Query(ctx, "SELECT id, name, creator_id, owner_id, enable_user_colors, group_color, created_at FROM chat_groups") - if err != nil { - return nil, err - } - defer rows.Close() - - var groups []ChatGroup - for rows.Next() { - var group ChatGroup - if err := rows.Scan(&group.Id, &group.Name, &group.CreatorId, &group.OwnerId, &group.EnableUserColors, &group.Color, &group.CreatedAt); err != nil { - return nil, err - } - groups = append(groups, group) - } - return groups, rows.Err() -} - -func GetChatGroupWithoutMembers(ctx context.Context, id uint64) (ChatGroup, error) { - var group ChatGroup - err := dbConnection.QueryRow(ctx, `SELECT name, creator_id, owner_id, enable_user_colors, group_color, created_at FROM chat_groups WHERE id = $1`, - id).Scan(&group.Name, &group.CreatorId, &group.OwnerId, &group.EnableUserColors, &group.Color, &group.CreatedAt) - return group, err -} - -func GetUserMemberGroupIds(ctx context.Context, userId uint64) ([]uint64, error) { - rows, err := dbConnection.Query(ctx, "SELECT group_id FROM chat_group_members WHERE user_id = $1", userId) - if err != nil { - return nil, err - } - defer rows.Close() - - var groupIds []uint64 - for rows.Next() { - var groupId uint64 - if err := rows.Scan(&groupId); err != nil { - return nil, err - } - groupIds = append(groupIds, groupId) - } - return groupIds, rows.Err() -} - -func GetChatGroupMembers(ctx context.Context, groupId uint64) ([]User, error) { - rows, err := dbConnection.Query(ctx, ` - SELECT usr.id, usr.name, usr.color FROM users usr - JOIN chat_group_members members ON usr.id = members.user_id - WHERE members.group_id = $1 - `, groupId) - if err != nil { - return nil, err - } - defer rows.Close() - - var members []User - for rows.Next() { - var user User - if err := rows.Scan(&user.Id, &user.Name, &user.Color); err != nil { - return nil, err - } - members = append(members, user) - } - return members, rows.Err() -} diff --git a/enums.go b/enums.go deleted file mode 100644 index 88dcf59..0000000 --- a/enums.go +++ /dev/null @@ -1,8 +0,0 @@ -package main - -type serverResponseType struct { - MessageFromUser uint8 - BadRequest uint8 -} - -var ServerResponseType = serverResponseType{0, 1} diff --git a/go.mod b/go.mod index 06a6a5f..7f56037 100644 --- a/go.mod +++ b/go.mod @@ -2,15 +2,4 @@ module go-socket go 1.26 -require ( - github.com/coder/websocket v1.8.14 - github.com/golang-jwt/jwt/v5 v5.3.1 - github.com/jackc/pgx/v5 v5.8.0 - golang.org/x/crypto v0.48.0 -) - -require ( - github.com/jackc/pgpassfile v1.0.0 // indirect - github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect - golang.org/x/text v0.34.0 // indirect -) +require github.com/coder/websocket v1.8.14 diff --git a/go.sum b/go.sum deleted file mode 100644 index 73a59b1..0000000 --- a/go.sum +++ /dev/null @@ -1,32 +0,0 @@ -github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g= -github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg= -github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= -github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY= -github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= -github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= -github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= -github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= -github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= -github.com/jackc/pgx/v5 v5.8.0 h1:TYPDoleBBme0xGSAX3/+NujXXtpZn9HBONkQC7IEZSo= -github.com/jackc/pgx/v5 v5.8.0/go.mod h1:QVeDInX2m9VyzvNeiCJVjCkNFqzsNb43204HshNSZKw= -github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= -github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= -github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= -github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= -github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= -github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= -golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts= -golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos= -golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= -golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= -golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk= -golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= -gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/http.go b/http.go deleted file mode 100644 index eaf355d..0000000 --- a/http.go +++ /dev/null @@ -1,193 +0,0 @@ -package main - -import ( - "log" - "net/http" - "strconv" - "time" - - "golang.org/x/crypto/bcrypt" -) - -func isMethodAllowed(response *http.ResponseWriter, request *http.Request) bool { - if request.Method != http.MethodPost { - http.Error(*response, "POST only", http.StatusMethodNotAllowed) - return false - } - return true -} - -func RegisterHandler(response http.ResponseWriter, request *http.Request) { - if !isMethodAllowed(&response, request) { - return - } - ctx := request.Context() - username := request.FormValue("username") - password := request.FormValue("password") - - if len(username) < 2 { - http.Error(response, "no or short username", http.StatusBadRequest) - return - } - if username == "server" { - http.Error(response, "only server can use such name", http.StatusBadRequest) - return - } - if len(password) < 8 { - http.Error(response, "short or no password", http.StatusBadRequest) - return - } - - if _, err := GetUserDataByName(ctx, username); err == nil { - http.Error(response, "user already exists", http.StatusBadRequest) - return - } - - if _, err := AddNewUser(ctx, &User{ - Name: username, - Password: password, - IsPasswordHashed: false, - }); err != nil { - http.Error(response, "internal server error", http.StatusInternalServerError) - log.Fatal(err) - return - } - - response.WriteHeader(http.StatusCreated) - _, err := response.Write([]byte("registered")) - if err != nil { - http.Error(response, "internal server error", http.StatusInternalServerError) - return - } -} - -func LoginHandler(response http.ResponseWriter, request *http.Request) { - if !isMethodAllowed(&response, request) { - return - } - ctx := request.Context() - username := request.FormValue("username") - password := request.FormValue("password") - - respondBadLogin := func() { - http.Error(response, "bad login", http.StatusUnauthorized) - } - - if len(username) < 2 { - log.Printf("username<2") - respondBadLogin() - return - } - - user, err := GetUserDataByName(ctx, username) - if err != nil { - log.Printf("could not get user: %v", err) - respondBadLogin() - return - } - - if bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)) == nil { - token, err := GetToken(&user) - if err != nil { - respondBadLogin() - return - } - - if _, err = response.Write([]byte(token)); err != nil { - return - } - return - } - log.Printf("bad hash") - respondBadLogin() -} - -func CreateGroupHandler(response http.ResponseWriter, request *http.Request) { - if !isMethodAllowed(&response, request) { - return - } - ctx := request.Context() - - user, err := GetUserFromToken(request.FormValue("token")) - if err != nil { - http.Error(response, "invalid token", http.StatusUnauthorized) - return - } - - groupName := request.FormValue("name") - if len(groupName) < 2 { - http.Error(response, "no or too short group name", http.StatusBadRequest) - return - } - - _, err = CreateChatGroupWithoutMembers(ctx, &ChatGroup{ - Name: groupName, - CreatorId: user.Id, - OwnerId: user.Id, - CreatedAt: time.Now(), - }) - if err != nil { - http.Error(response, "internal server error", http.StatusInternalServerError) - log.Fatal(err) - return - } - response.WriteHeader(http.StatusCreated) -} - -func SendMessageToGroupHandler(response http.ResponseWriter, request *http.Request) { - groupIdString := request.PathValue("groupid") - if groupIdString == "" { - http.Error(response, "no group id", http.StatusBadRequest) - return - } - - var user User - var err error - token := request.FormValue("token") - if token == "" { - http.Error(response, "no token", http.StatusBadRequest) - return - } - if user, err = GetUserFromToken(token); err != nil { - http.Error(response, "invalid token", http.StatusUnauthorized) - return - } - - content := request.FormValue("content") - if content == "" { - http.Error(response, "no content", http.StatusBadRequest) - return - } - - groupId, err := strconv.ParseUint(groupIdString, 10, 64) - if err != nil { - http.Error(response, "no such group", http.StatusBadRequest) - return - } - - groupIds, err := GetUserMemberGroupIds(request.Context(), user.Id) - if err != nil { - http.Error(response, "internal server error", http.StatusInternalServerError) - return - } - isMember := false - for _, id := range groupIds { - if id == groupId { - isMember = true - break - } - } - if isMember { - var message = map[string]any{ - "type": ServerResponseType.MessageFromUser, - "content": content, - "from": user, - "time": time.Now().Unix(), - } - err := sendToGroup(groupId, user.Id, &message) - if err != nil { - http.Error(response, "internal server error", http.StatusInternalServerError) - return - } - } -} diff --git a/clientTest.py b/machine-client/clientTest.py similarity index 100% rename from clientTest.py rename to machine-client/clientTest.py diff --git a/main.go b/main.go index cafce58..7905807 100644 --- a/main.go +++ b/main.go @@ -1,38 +1,5 @@ package main -import ( - "context" - "log" - "net/http" -) - func main() { - InitDatabase(context.Background()) - srv := &wsServer{ - OnOpen: func(c *Client) { - AddOrUpdateConnectedClientToCache(&mu, c) - log.Println("client connected") - }, - OnClose: func(c *Client, err error) { - log.Println("client disconnected:", err) - RemoveConnectedClientFromCache(&mu, c) - }, - OnMessage: func(c *Client, msg map[string]any) { - log.Printf("received: %v\n", msg) - if c.User == nil { - handleUnauthenticatedMessage(c, msg) - } else { - handleAuthenticatedMessage(c, msg) - } - }, - } - http.Handle("/ws", srv) - log.Println("server listening on :8080") - http.HandleFunc("POST /new/account", RegisterHandler) - http.HandleFunc("POST /new/token", LoginHandler) - http.HandleFunc("POST /new/group", CreateGroupHandler) - http.HandleFunc("POST /new/messageto/group/{groupid}", SendMessageToGroupHandler) - - log.Fatal(http.ListenAndServe(":8080", nil)) } diff --git a/structures.go b/structures.go deleted file mode 100644 index 77b4a14..0000000 --- a/structures.go +++ /dev/null @@ -1,31 +0,0 @@ -package main - -import ( - "time" - - "github.com/coder/websocket" -) - -type User struct { - MemberGroupsId []uint64 - Name string - Password string - Color [3]byte - Id uint64 - IsPasswordHashed bool -} -type Client struct { - conn *websocket.Conn - User *User -} - -type ChatGroup struct { - Members []User - CreatedAt time.Time - Name string - Id uint64 - CreatorId uint64 - OwnerId uint64 - Color [3]byte - EnableUserColors bool -} diff --git a/tokens.go b/tokens.go deleted file mode 100644 index 4471f63..0000000 --- a/tokens.go +++ /dev/null @@ -1,63 +0,0 @@ -package main - -import ( - "errors" - "fmt" - "strconv" - "time" - - "github.com/golang-jwt/jwt/v5" -) - -var secretKey = []byte("replace-with-env-variable") - -type UserClaims struct { - Name string `json:"name"` - Color string `json:"color"` - jwt.RegisteredClaims -} - -func GetToken(user *User) (string, error) { - token := jwt.NewWithClaims(jwt.SigningMethodHS256, - UserClaims{ - Name: user.Name, - Color: string(user.Color[:]), - RegisteredClaims: jwt.RegisteredClaims{ - Subject: strconv.Itoa(int(user.Id)), - ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)), - IssuedAt: jwt.NewNumericDate(time.Now()), - }, - }, - ) - return token.SignedString(secretKey) -} - -func GetUserFromToken(tokenString string) (User, error) { - token, err := jwt.ParseWithClaims(tokenString, &UserClaims{}, func(t *jwt.Token) (interface{}, error) { - if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok { - return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"]) - } - return secretKey, nil - }) - if err != nil { - return User{}, err - } - - claims, ok := token.Claims.(*UserClaims) - if !ok || !token.Valid { - return User{}, errors.New("invalid token") - } - - id, err := strconv.ParseUint(claims.Subject, 10, 64) - if err != nil { - return User{}, fmt.Errorf("invalid subject: %w", err) - } - - var color [3]byte - copy(color[:], claims.Color) - return User{ - Id: id, - Name: claims.Name, - Color: color, - }, nil -} diff --git a/wsServer.go b/wsServer.go index cc3c93d..95b27e9 100644 --- a/wsServer.go +++ b/wsServer.go @@ -2,113 +2,41 @@ package main import ( "context" - "errors" "log" "net/http" - "sync" "time" "github.com/coder/websocket" "github.com/coder/websocket/wsjson" ) -type wsServer struct { - OnOpen func(c *Client) - OnClose func(c *Client, err error) - OnMessage func(c *Client, msg map[string]any) -} - -var ( - mu sync.Mutex -) - -func (s *wsServer) ServeHTTP(responseWriter http.ResponseWriter, request *http.Request) { - conn, err := websocket.Accept(responseWriter, request, &websocket.AcceptOptions{ - InsecureSkipVerify: true, - }) +func ServeConnection(responseWriter http.ResponseWriter, request *http.Request) { + connection, err := websocket.Accept(responseWriter, request, nil) if err != nil { - log.Println("accept error:", err) + log.Printf("websocket accept error: %v", err) return } - defer conn.CloseNow() - client := &Client{conn: conn} - - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) defer cancel() - if s.OnOpen != nil { - s.OnOpen(client) - } - - var readErr error for { - var msg map[string]any - if readErr = wsjson.Read(ctx, conn, &msg); readErr != nil { - break + var clientMessage any + err := wsjson.Read(ctx, connection, &clientMessage) + if err != nil { + log.Printf("read error: %clientMessage", err) + return } - if s.OnMessage != nil { - s.OnMessage(client, msg) + log.Printf("received: %clientMessage", clientMessage) + + // process and optionally respond + err = wsjson.Write(ctx, connection, map[string]string{"status": "ok"}) + if err != nil { + log.Printf("write error: %clientMessage", err) + return } } - - cancel() // cancel before OnClose so any in-flight queries are canceled first - - if s.OnClose != nil { - s.OnClose(client, readErr) - } - - conn.Close(websocket.StatusNormalClosure, "done") } -func sendAndCloseIfFails(conn *websocket.Conn, message *map[string]any) { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - if err := wsjson.Write(ctx, conn, message); err != nil { - conn.Close(websocket.StatusGoingAway, "Write error") - } -} -func sendToGroup(id uint64, excludedUserId uint64, message *map[string]any) error { - if _, ok := Groups[id]; !ok { - return errors.New("Group Not Found") - } - - for client := range ConnectedClients[id] { - if client.User.Id != excludedUserId { - sendAndCloseIfFails(client.conn, message) - } - } - return nil -} - -func handleUnauthenticatedMessage(client *Client, msg map[string]any) { - token, ok := msg["token"].(string) - if !ok { - client.conn.Close(websocket.StatusGoingAway, "invalid token") - return - } - - user, err := GetUserFromToken(token) - if err != nil { - client.conn.Close(websocket.StatusPolicyViolation, "invalid token") - return - } - groupIds, err := GetUserMemberGroupIds(context.Background(), user.Id) - if err != nil { - client.conn.Close(websocket.StatusInternalError, "internal error") - return - } - user.MemberGroupsId = groupIds - client.User = &user - m := map[string]any{ - "authAs": user.Name, - } - sendAndCloseIfFails(client.conn, &m) - AddOrUpdateConnectedClientToCache(&mu, client) - log.Println("New User authenticated as: " + user.Name) -} - -func handleAuthenticatedMessage(client *Client, msg map[string]any) { - m := map[string]any{"temporary": "unauthorized"} - sendAndCloseIfFails(client.conn, &m) -} +func handleUnauthenticatedMessage() {} +func handleAuthenticatedMessage() {}