diff --git a/cache.go b/cache.go index ebe7aad..1a3efdc 100644 --- a/cache.go +++ b/cache.go @@ -4,6 +4,7 @@ import "sync" var ( mu sync.RWMutex + Clients = make(map[uint32]Client) ChatGroups = make(map[uint32]ChatGroup) ClientsMap = make(map[uint32]map[uint32]*Client) ) @@ -22,7 +23,7 @@ func RemoveGroupFromCache(chatGroup *ChatGroup) { delete(ClientsMap, chatGroup.Id) } -func AddUserToCache(client *Client) { +func AddAuthenticatedClientToCache(client *Client) { mu.Lock() defer mu.Unlock() for _, groupIn := range client.Groups { @@ -41,3 +42,7 @@ func RemoveClientFromCache(client *Client) { } } } + +func GetClientData(uint32 *Client) { + +} diff --git a/database.go b/database.go new file mode 100644 index 0000000..8a97881 --- /dev/null +++ b/database.go @@ -0,0 +1,47 @@ +package main + +import ( + "context" + + //"golang.org/x/crypto/bcrypt" + + "github.com/jackc/pgx/v5" +) + +var dbConn *pgx.Conn + +func InitDatabase(ctx context.Context) { + conn, err := pgx.Connect(ctx, "postgres://master:secret@localhost:5432") + if err != nil { + panic(err) + } + + _, err = conn.Exec(ctx, ` + CREATE TABLE IF NOT EXISTS users ( + id SERIAL PRIMARY KEY, + name VARCHAR(20) UNIQUE NOT NULL, + pass_hash VARCHAR(60) NOT NULL, + color VARCHAR(3) DEFAULT NULL + ); + CREATE TABLE IF NOT EXISTS chat_groups ( + id SERIAL PRIMARY KEY, + name VARCHAR(48) NOT NULL, + creator_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE, + owner_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE, + enable_user_colors BOOLEAN NOT NULL DEFAULT true, + group_color VARCHAR(3), + created_at TIMESTAMP NOT NULL DEFAULT NOW() + ); + CREATE TABLE IF NOT EXISTS chat_group_members ( + group_id INTEGER NOT NULL REFERENCES chat_groups(id) ON DELETE CASCADE, + user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE, + joined_at TIMESTAMP NOT NULL DEFAULT NOW(), + PRIMARY KEY (group_id, user_id) + ); + `) + if err != nil { + panic(err) + } + + dbConn = conn +} diff --git a/go.mod b/go.mod index d513f92..2902dbf 100644 --- a/go.mod +++ b/go.mod @@ -2,6 +2,15 @@ module go-socket go 1.26 -require github.com/coder/websocket v1.8.14 +require ( + github.com/coder/websocket v1.8.14 + github.com/jackc/pgx/v5 v5.8.0 + golang.org/x/crypto v0.49.0 +) -require github.com/golang-jwt/jwt/v5 v5.3.1 // indirect +require ( + github.com/golang-jwt/jwt/v5 v5.3.1 + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect + golang.org/x/text v0.35.0 // indirect +) diff --git a/struct.go b/struct.go index 0678304..c748b30 100644 --- a/struct.go +++ b/struct.go @@ -3,13 +3,14 @@ package main import "github.com/coder/websocket" type Client struct { - Password string - Name string - Pronouns string - Groups [12]*ChatGroup - Connection *websocket.Conn - Id uint32 - Color [3]byte + Password string + Name string + Pronouns string + Groups [12]*ChatGroup + Connection *websocket.Conn + Id uint32 + Color [3]byte + IsAuthenticated bool } type ChatGroup struct { diff --git a/tokens.go b/tokens.go index 5c13a6c..ac134bd 100644 --- a/tokens.go +++ b/tokens.go @@ -1,8 +1,6 @@ package main import ( - "errors" - "fmt" "strconv" "time" @@ -11,51 +9,45 @@ import ( var secretKey = []byte("replace-with-env-variable") -type UserClaims struct { - Name string `json:"name"` - Color [3]byte `json:"color"` - jwt.RegisteredClaims -} - func GetToken(client *Client) (string, error) { token := jwt.NewWithClaims(jwt.SigningMethodHS256, - UserClaims{ - Name: client.Name, - Color: client.Color, - RegisteredClaims: jwt.RegisteredClaims{ - Subject: strconv.Itoa(int(client.Id)), - ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)), - IssuedAt: jwt.NewNumericDate(time.Now()), - }, + jwt.RegisteredClaims{ + Subject: strconv.Itoa(int(client.Id)), + ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)), + IssuedAt: jwt.NewNumericDate(time.Now()), }, ) return token.SignedString(secretKey) } -func SetClientFromToken(client *Client, tokenString string) error { - token, err := jwt.ParseWithClaims(tokenString, &UserClaims{}, func(t *jwt.Token) (interface{}, error) { +func GetDataFromToken(token *string) (uint32, error) { + parsed, err := jwt.ParseWithClaims(*token, &jwt.RegisteredClaims{}, func(t *jwt.Token) (any, error) { if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok { - return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"]) + return nil, jwt.ErrSignatureInvalid } return secretKey, nil }) if err != nil { - return err + return 0, err } - claims, ok := token.Claims.(*UserClaims) - if !ok || !token.Valid { - return errors.New("invalid token") + claims, ok := parsed.Claims.(*jwt.RegisteredClaims) + if !ok || !parsed.Valid { + return 0, jwt.ErrTokenInvalidClaims } id, err := strconv.ParseUint(claims.Subject, 10, 32) if err != nil { - return fmt.Errorf("invalid subject: %w", err) + return 0, err } + return uint32(id), nil +} - client.Id = uint32(id) - client.Name = claims.Name - client.Color = claims.Color - +func SetClientFromToken(client *Client, token string) error { + id, err := GetDataFromToken(&token) + if err != nil { + return err + } + client.Id = id return nil } diff --git a/wsServer.go b/wsServer.go index 223ba54..112a6eb 100644 --- a/wsServer.go +++ b/wsServer.go @@ -21,13 +21,9 @@ func ServeConnection(responseWriter http.ResponseWriter, request *http.Request) ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) defer cancel() - var ( - isAuthenticated = false - continueConnection = true - client = Client{} - ) + client := Client{} - for continueConnection { + for { var clientMessage map[string]any err := wsjson.Read(ctx, connection, &clientMessage) if err != nil { @@ -36,10 +32,13 @@ func ServeConnection(responseWriter http.ResponseWriter, request *http.Request) } if len(clientMessage) > 0 { - if isAuthenticated { + if client.IsAuthenticated { handleAuthenticatedMessage() } else { - handleUnauthenticatedMessage(connection, &client, &clientMessage, &isAuthenticated, &continueConnection) + if !handleUnauthenticatedMessage(connection, &client, &clientMessage) { + closeConnection(connection) + return + } } } } @@ -58,7 +57,7 @@ func sendMessageCloseIfTimeout(conn *websocket.Conn, message *map[string]any) { } } -func handleUnauthenticatedMessage(conn *websocket.Conn, client *Client, message *map[string]any, isAuthenticated *bool, continueConnection *bool) { +func handleUnauthenticatedMessage(conn *websocket.Conn, client *Client, message *map[string]any) bool { token, ok := (*message)["token"].(string) if !ok { var errmsg = map[string]any{ @@ -66,7 +65,7 @@ func handleUnauthenticatedMessage(conn *websocket.Conn, client *Client, message "message": "token required", } sendMessageCloseIfTimeout(conn, &errmsg) - return + return false } err := SetClientFromToken(client, token) @@ -76,11 +75,10 @@ func handleUnauthenticatedMessage(conn *websocket.Conn, client *Client, message "message": "bad token", } sendMessageCloseIfTimeout(conn, &errmsg) - continueConnection = false - return + return false } - isAuthenticated = true - clientInCache, ok := ClientsMap[] + client.IsAuthenticated = true + return true } func handleAuthenticatedMessage() {