diff --git a/cache.go b/cache.go index 013cd6c..64615ab 100644 --- a/cache.go +++ b/cache.go @@ -1,92 +1,37 @@ package main import ( - "context" - "errors" + "fmt" "sync" - - "golang.org/x/crypto/bcrypt" ) var ( - mu sync.RWMutex - clients = make(map[uint32]*Client) - chatGroups = make(map[uint32]ChatGroup) + CacheClients = make(map[uint32]*Client) + mu sync.RWMutex + Groups = make(map[uint32]*Group) ) -func CreateGroup(ctx context.Context, chatGroup *ChatGroup) { - mu.Lock() - defer mu.Unlock() - chatGroups[chatGroup.Id] = *chatGroup -} +func CacheGetClientById(id uint32) (*Client, error) { + mu.RLock() + defer mu.RUnlock() -func DeleteGroup(ctx context.Context, chatGroup *ChatGroup) { - mu.Lock() - defer mu.Unlock() - delete(chatGroups, chatGroup.Id) -} - -func CreateClient(ctx context.Context, client *Client) error { - mu.Lock() - defer mu.Unlock() - clients[client.Id] = client - - hashed, err := bcrypt.GenerateFromPassword([]byte(client.PasswordHash), bcrypt.DefaultCost) - if err != nil { - return err - } - - client.PasswordHash = string(hashed) - - //err := DbSaveClientWithoutGroups(ctx, client) - //if err != nil { - // return err - //} - return nil -} - -func CheckPassword(ctx context.Context, id uint32, password string) error { - client, err := GetClientFromId(id) - if err != nil { - return err - } - - err = bcrypt.CompareHashAndPassword([]byte(client.PasswordHash), []byte(password)) - if err != nil { - return err - } - return nil -} - -func GetIdFromClientName(ctx context.Context, name string) (uint32, error) { - for _, client := range clients { - if client.Name == name { - return client.Id, nil - } - } - return 0, errors.New("client not found") -} - -func GetClientFromId(id uint32) (*Client, error) { - client, ok := clients[id] + client, ok := CacheClients[id] if !ok { - return nil, errors.New("no such user") + return nil, fmt.Errorf("client %d not found", id) } return client, nil } -func ConnectClientToGroups(ctx context.Context, client *Client) { +func CacheSetClient(id uint32, client *Client) { mu.Lock() defer mu.Unlock() - for _, groupIn := range client.Groups { - chatGroups[groupIn.Id].Members[client.Id] = client - } + + CacheClients[id] = client } -func DisconnectClientFromGroups(ctx context.Context, client *Client) { +func CacheDeleteClient(id uint32) { mu.Lock() defer mu.Unlock() - for _, groupIn := range client.Groups { - delete(chatGroups[groupIn.Id].Members, client.Id) - } + + delete(CacheClients, id) } diff --git a/database.go b/database.go index ba0f51c..264b630 100644 --- a/database.go +++ b/database.go @@ -3,66 +3,77 @@ package main import ( "context" - "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" ) -var dbConn *pgx.Conn +var dbConn *pgxpool.Pool -func InitDatabase(ctx context.Context) { - conn, err := pgx.Connect(ctx, "postgres://master:secret@localhost:5432") +func DbInit(ctx context.Context) { + var err error + dbConn, err = pgxpool.New(ctx, "postgres://master:secret@localhost:5432") // TODO change to env in production if err != nil { panic(err) } - _, err = conn.Exec(ctx, ` - CREATE TABLE IF NOT EXISTS users ( + _, err = dbConn.Exec(ctx, ` + CREATE TABLE IF NOT EXISTS client ( id SERIAL PRIMARY KEY, name VARCHAR(20) UNIQUE NOT NULL, pass_hash VARCHAR(60) NOT NULL, pronouns VARCHAR(15) DEFAULT NULL, - color VARCHAR(3) DEFAULT NULL - ); - CREATE TABLE IF NOT EXISTS chat_groups ( - id SERIAL PRIMARY KEY, - name VARCHAR(48) NOT NULL, - creator_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE, - owner_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE, - enable_user_colors BOOLEAN NOT NULL DEFAULT true, - group_color VARCHAR(3), + color_red SMALLINT DEFAULT NULL, + color_green SMALLINT DEFAULT NULL, + color_blue SMALLINT DEFAULT NULL, created_at TIMESTAMP NOT NULL DEFAULT NOW() - ); - CREATE TABLE IF NOT EXISTS chat_group_members ( - group_id INTEGER NOT NULL REFERENCES chat_groups(id) ON DELETE CASCADE, - user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE, - joined_at TIMESTAMP NOT NULL DEFAULT NOW(), - PRIMARY KEY (group_id, user_id) - ); + ) `) if err != nil { panic(err) } - dbConn = conn + _, err = dbConn.Exec(ctx, ` + CREATE TABLE IF NOT EXISTS chat_groups ( + id SERIAL PRIMARY KEY, + name VARCHAR(48) NOT NULL, + creator_id INTEGER NOT NULL REFERENCES client(id) ON DELETE CASCADE, + owner_id INTEGER NOT NULL REFERENCES client(id) ON DELETE CASCADE, + enable_user_colors BOOLEAN NOT NULL DEFAULT true, + color_red SMALLINT DEFAULT NULL, + color_green SMALLINT DEFAULT NULL, + color_blue SMALLINT DEFAULT NULL, + created_at TIMESTAMP NOT NULL DEFAULT NOW() + ) + `) + if err != nil { + panic(err) + } + + _, err = dbConn.Exec(ctx, ` + CREATE TABLE IF NOT EXISTS chat_group_members ( + group_id INTEGER NOT NULL REFERENCES chat_groups(id) ON DELETE CASCADE, + user_id INTEGER NOT NULL REFERENCES client(id) ON DELETE CASCADE, + joined_at TIMESTAMP NOT NULL DEFAULT NOW(), + PRIMARY KEY (group_id, user_id) + ) + `) + if err != nil { + panic(err) + } } func DbSaveClientWithoutGroups(ctx context.Context, client *Client) error { - var id uint64 - var err error - - err = dbConn.QueryRow(ctx, ` - INSERT INTO users (name, pass_hash, pronouns) - VALUES ($1, $2, $3) + err := dbConn.QueryRow(ctx, ` + INSERT INTO clients (name, pass_hash, pronouns, color_red, color_green, color_blue, created_at) + VALUES ($1, $2, $3, $4, $5, $6, $7) RETURNING id - `, client.Name, client.PasswordHash, client.Pronouns).Scan(&id) + `, client.Name, client.PasswordHash, client.Pronouns, client.Color[0], client.Color[1], client.Color[2], client.CreatedAt). + Scan(&client.Id) return err } -func DbGetClient(ctx context.Context, id uint32, client *Client) error { +func DbGetClientByName(ctx context.Context, client *Client) error { err := dbConn.QueryRow(ctx, ` - SELECT name, pass_hash, pronouns, color WHERE id = $1 - `, id).Scan(client.Name, client.PasswordHash, client.Pronouns, client.Color) - if err != nil { - return err - } - return nil + SELECT name, pass_hash, color_red, color_green, color_blue, created_at FROM clients WHERE name = $1 + `, client.Name).Scan(&client.Name, &client.PasswordHash, client.Pronouns, client.Color[0], client.Color[1], client.Color[2], client.CreatedAt) + return err } diff --git a/globals.go b/globals.go new file mode 100644 index 0000000..2786ad0 --- /dev/null +++ b/globals.go @@ -0,0 +1,6 @@ +package main + +const ( + MaxGroupsForClient uint8 = 8 + MaxClientsInGroup uint8 = 12 +) diff --git a/go-socket b/go-socket index ee778cd..2704ae6 100755 Binary files a/go-socket and b/go-socket differ diff --git a/http.go b/http.go index c86de30..35b2697 100644 --- a/http.go +++ b/http.go @@ -1,6 +1,12 @@ package main -import "net/http" +import ( + "fmt" + "net/http" + "strconv" + "strings" + "time" +) func isMethodAllowed(response *http.ResponseWriter, request *http.Request) bool { if request.Method != http.MethodPost { @@ -10,7 +16,23 @@ func isMethodAllowed(response *http.ResponseWriter, request *http.Request) bool return true } -func RegisterHandler(response http.ResponseWriter, request *http.Request) { +func parseRgb(str string) ([3]uint8, error) { + parts := strings.SplitN(str, ",", 4) + if len(parts) != 3 { + return [3]uint8{}, fmt.Errorf("invalid rgb") + } + var rgb [3]uint8 + for i, p := range parts { + n, err := strconv.ParseUint(strings.TrimSpace(p), 10, 8) + if err != nil { + return [3]uint8{}, fmt.Errorf("invalid component %d: %w", i, err) + } + rgb[i] = uint8(n) + } + return rgb, nil +} + +func HttpHandleNewUser(response http.ResponseWriter, request *http.Request) { if !isMethodAllowed(&response, request) { return } @@ -29,21 +51,32 @@ func RegisterHandler(response http.ResponseWriter, request *http.Request) { return } - newClient := Client{ + color, err := parseRgb(request.FormValue("color")) + if err != nil { + http.Error(response, "bad color", http.StatusBadRequest) + } + + hashedPassword, err := PasswordHash(password) + if err != nil { + http.Error(response, "internal server error", http.StatusInternalServerError) + return + } + + newClient := &Client{ Name: username, - PasswordHash: password, - Color: [3]uint8{120, 120, 120}, //"xxx" + PasswordHash: hashedPassword, + Color: color, + CreatedAt: time.Now(), } - err := CreateClient(ctx, &newClient) + err = DbSaveClientWithoutGroups(ctx, newClient) if err != nil { - http.Error(response, "taken", http.StatusBadRequest) + http.Error(response, "name taken", http.StatusInternalServerError) return } - response.Write([]byte("registered")) } -func LoginHandler(response http.ResponseWriter, request *http.Request) { +func HttpHandleLogin(response http.ResponseWriter, request *http.Request) { if !isMethodAllowed(&response, request) { return } @@ -62,25 +95,6 @@ func LoginHandler(response http.ResponseWriter, request *http.Request) { return } - var client Client + _, err := CacheGetClientById() - id, err := GetIdFromClientName(ctx, username) - if err != nil { - http.Error(response, "bad login", http.StatusBadRequest) - return - } - - err = DbGetClient(ctx) - if err != nil { - http.Error(response, "bad login", http.StatusBadRequest) - return - } - - token, err := GetToken(id) - if err != nil { - http.Error(response, "Internal error", http.StatusInternalServerError) - return - } - - response.Write([]byte(token)) } diff --git a/main.go b/main.go index 94cccc7..7905807 100644 --- a/main.go +++ b/main.go @@ -1,19 +1,5 @@ package main -import ( - "context" - "log" - "net/http" -) - func main() { - ctx := context.Background() - InitDatabase(ctx) - http.HandleFunc("/ws", ServeWsConnection) - http.HandleFunc("/new/client", RegisterHandler) - http.HandleFunc("/new/token", LoginHandler) - - log.Println("listening on :8080") - log.Fatal(http.ListenAndServe(":8080", nil)) } diff --git a/password.go b/password.go new file mode 100644 index 0000000..afee84f --- /dev/null +++ b/password.go @@ -0,0 +1,13 @@ +package main + +import "golang.org/x/crypto/bcrypt" + +func PasswordHash(password string) (string, error) { + bytes, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + return string(bytes), err +} + +func PasswordCheckAgainstHash(password, hash string) bool { + err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(password)) + return err == nil +} diff --git a/struct.go b/struct.go deleted file mode 100644 index db48d37..0000000 --- a/struct.go +++ /dev/null @@ -1,21 +0,0 @@ -package main - -import "github.com/coder/websocket" - -type Client struct { - PasswordHash string - Name string - Pronouns string - Groups [12]*ChatGroup - Connection *websocket.Conn - Id uint32 - Color [3]byte - IsAuthenticated bool -} - -type ChatGroup struct { - Name string - Id uint32 - Members map[uint32]*Client - Color [3]byte -} diff --git a/structs.go b/structs.go new file mode 100644 index 0000000..18c0d1e --- /dev/null +++ b/structs.go @@ -0,0 +1,29 @@ +package main + +import ( + "time" + + "github.com/coder/websocket" +) + +type Client struct { + Name string + Pronouns string + PasswordHash string + CreatedAt time.Time + WsConn *websocket.Conn + Id uint32 + Groups [MaxGroupsForClient]uint32 + Color [3]uint8 +} + +type Group struct { + Name string + CreatedAt time.Time + Id uint32 + CreatorId uint32 + OwnerId uint32 + Clients [MaxClientsInGroup]uint32 + Color [3]uint8 + EnableUserColors bool +} diff --git a/tokens.go b/tokens.go index fc2415b..65d60ef 100644 --- a/tokens.go +++ b/tokens.go @@ -1,44 +1,56 @@ package main import ( + "fmt" "strconv" "time" "github.com/golang-jwt/jwt/v5" ) -var secretKey = []byte("replace-with-env-variable") +const tokenSecret = "tmp" // TODO delete in production +const tokenExpiration = time.Hour -func GetToken(clientId uint32) (string, error) { - token := jwt.NewWithClaims(jwt.SigningMethodHS256, - jwt.RegisteredClaims{ - Subject: strconv.Itoa(int(clientId)), - ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)), - IssuedAt: jwt.NewNumericDate(time.Now()), - }, - ) - return token.SignedString(secretKey) +func TokenCreate(clientId uint32) (string, error) { + now := time.Now() + signedToken, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ + Subject: strconv.FormatUint(uint64(clientId), 10), + IssuedAt: jwt.NewNumericDate(now), + ExpiresAt: jwt.NewNumericDate(now.Add(tokenExpiration)), + }).SignedString([]byte(tokenSecret)) + return signedToken, err } -func GetClientIdFromToken(token string) (uint32, error) { - parsed, err := jwt.ParseWithClaims(token, &jwt.RegisteredClaims{}, func(t *jwt.Token) (any, error) { +func TokenValidateGetId(tokenString string) (uint32, error) { + token, err := jwt.Parse(tokenString, func(t *jwt.Token) (interface{}, error) { if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok { - return nil, jwt.ErrSignatureInvalid + return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"]) } - return secretKey, nil + return []byte(tokenSecret), nil }) if err != nil { return 0, err } - claims, ok := parsed.Claims.(*jwt.RegisteredClaims) - if !ok || !parsed.Valid { - return 0, jwt.ErrTokenInvalidClaims + claims, ok := token.Claims.(jwt.MapClaims) + if !ok || !token.Valid { + return 0, fmt.Errorf("invalid token") } - id, err := strconv.ParseUint(claims.Subject, 10, 32) - if err != nil { - return 0, err + exp, ok := claims["exp"].(float64) + if !ok || time.Now().Unix() > int64(exp) { + return 0, fmt.Errorf("token expired") } + + sub, ok := claims["sub"].(string) + if !ok { + return 0, fmt.Errorf("invalid subject claim") + } + + id, err := strconv.ParseUint(sub, 10, 32) + if err != nil { + return 0, fmt.Errorf("invalid subject claim") + } + return uint32(id), nil } diff --git a/wsServer.go b/wsServer.go index 15c96fe..a8e207e 100644 --- a/wsServer.go +++ b/wsServer.go @@ -23,85 +23,95 @@ func ServeWsConnection(responseWriter http.ResponseWriter, request *http.Request ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) defer cancel() - client := Client{} + var client = Client{WsConn: connection} + var isAuthenticated bool + defer closeConnection(&client) for { var clientMessage map[string]any err := wsjson.Read(ctx, connection, &clientMessage) if err != nil { - log.Printf("read error: %clientMessage", err) + log.Printf("read error: %v", err) return } if len(clientMessage) > 0 { - if client.IsAuthenticated { - handleAuthenticatedMessage(connection, &client, &clientMessage) - } else { - if !handleUnauthenticatedMessage(connection, &client, &clientMessage) { - closeConnection(connection) + if isAuthenticated { + if !handleAuthenticatedMessage(&client, &clientMessage) { return } + } else { + if !handleUnauthenticatedMessage(&client, &clientMessage) { + return + } + isAuthenticated = true } } } } -func sendMessageCloseIfTimeout(conn *websocket.Conn, message *map[string]any) { +func sendMessageCloseIfTimeout(client *Client, message *map[string]any) { ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) defer cancel() - err := wsjson.Write(ctx, conn, message) + err := wsjson.Write(ctx, client.WsConn, message) if err != nil { if errors.Is(err, context.DeadlineExceeded) { - closeConnection(conn) + closeConnection(client) } log.Printf("write error: %v", err) } } -func handleUnauthenticatedMessage(conn *websocket.Conn, client *Client, message *map[string]any) bool { - token, ok := (*message)["token"].(string) - if !ok { - var errmsg = map[string]any{ - "type": WsServerResponse(BadMessage), - "message": "token required", - } - sendMessageCloseIfTimeout(conn, &errmsg) - return false +func sendToAllMessageCloseIfTimeout(message *map[string]any) { + mu.RLock() + defer mu.RUnlock() + for _, client := range CacheClients { + sendMessageCloseIfTimeout(client, message) } +} - clientId, err := GetClientIdFromToken(token) - if err != nil { - var errmsg = map[string]any{ - "type": WsServerResponse(InvalidCredentials), - "message": "bad token", - } - sendMessageCloseIfTimeout(conn, &errmsg) - return false - } - - client, err = GetClientFromId(clientId) - if err != nil { - var errmsg = map[string]any{ - "type": WsServerResponse(InvalidCredentials), - "message": "bad token", - } - sendMessageCloseIfTimeout(conn, &errmsg) - return false - } - - client.IsAuthenticated = true +func handleAuthenticatedMessage(client *Client, clientMessage *map[string]any) bool { + sendToAllMessageCloseIfTimeout(clientMessage) return true } -func handleAuthenticatedMessage(conn *websocket.Conn, client *Client, message *map[string]any) { - for _, sendTo := range clients { - if sendTo.IsAuthenticated && sendTo.Id != client.Id { - sendMessageCloseIfTimeout(conn, message) +func handleUnauthenticatedMessage(client *Client, clientMessage *map[string]any) bool { + token, ok := (*clientMessage)["token"].(string) + if !ok { + var msg = map[string]any{ + "from": "server", + "error": "no token in message", } + sendMessageCloseIfTimeout(client, &msg) + return false } + + clientId, err := TokenValidateGetId(token) + if err != nil { + var msg = map[string]any{ + "from": "server", + "error": "invalid token", + } + sendMessageCloseIfTimeout(client, &msg) + return false + } + + clientFromCache, err := CacheGetClientById(clientId) + if err != nil { + var msg = map[string]any{ + "from": "server", + "error": "invalid token", + } + sendMessageCloseIfTimeout(client, &msg) + return false + } + + *client = *clientFromCache + return true } -func closeConnection(conn *websocket.Conn) { - conn.Close(websocket.StatusNormalClosure, "closing connection") +func closeConnection(client *Client) { + CacheDeleteClient(client.Id) + client.WsConn.CloseNow() }