diff --git a/cache.go b/cache.go index baa4cf5..120948a 100644 --- a/cache.go +++ b/cache.go @@ -1,8 +1,11 @@ package main import ( + "context" "errors" "sync" + + "golang.org/x/crypto/bcrypt" ) var ( @@ -11,32 +14,48 @@ var ( chatGroups = make(map[uint32]ChatGroup) ) -func AddGroupToCache(chatGroup *ChatGroup) { +func CreateGroup(ctx context.Context, chatGroup *ChatGroup) { mu.Lock() defer mu.Unlock() chatGroups[chatGroup.Id] = *chatGroup } -func RemoveGroupFromCache(chatGroup *ChatGroup) { +func DeleteGroup(ctx context.Context, chatGroup *ChatGroup) { mu.Lock() defer mu.Unlock() delete(chatGroups, chatGroup.Id) } -func AddClientConnectionsToCache(client *Client) { +func CreateClient(ctx context.Context, client *Client) error { mu.Lock() defer mu.Unlock() - for _, groupIn := range client.Groups { - chatGroups[groupIn.Id].Members[client.Id] = client + err := SaveClientWithoutGroups(ctx, client) + if err != nil { + return err } + return nil } -func RemoveClientConnectionsToCache(client *Client) { - mu.Lock() - defer mu.Unlock() - for _, groupIn := range client.Groups { - delete(chatGroups[groupIn.Id].Members, client.Id) +func CheckPassword(ctx context.Context, id uint32, password string) error { + client, err := GetClientFromId(id) + if err != nil { + return err } + + err = bcrypt.CompareHashAndPassword([]byte(client.PasswordHash), []byte(password)) + if err != nil { + return err + } + return nil +} + +func GetIdFromClientName(ctx context.Context, name string) (uint32, error) { + for _, client := range clients { + if client.Name == name { + return client.Id, nil + } + } + return 0, errors.New("client not found") } func GetClientFromId(id uint32) (*Client, error) { @@ -46,3 +65,19 @@ func GetClientFromId(id uint32) (*Client, error) { } return client, nil } + +func ConnectClientToGroups(ctx context.Context, client *Client) { + mu.Lock() + defer mu.Unlock() + for _, groupIn := range client.Groups { + chatGroups[groupIn.Id].Members[client.Id] = client + } +} + +func DisconnectClientFromGroups(ctx context.Context, client *Client) { + mu.Lock() + defer mu.Unlock() + for _, groupIn := range client.Groups { + delete(chatGroups[groupIn.Id].Members, client.Id) + } +} diff --git a/database.go b/database.go index 8a97881..59c15d6 100644 --- a/database.go +++ b/database.go @@ -3,9 +3,8 @@ package main import ( "context" - //"golang.org/x/crypto/bcrypt" - "github.com/jackc/pgx/v5" + "golang.org/x/crypto/bcrypt" ) var dbConn *pgx.Conn @@ -21,6 +20,7 @@ func InitDatabase(ctx context.Context) { id SERIAL PRIMARY KEY, name VARCHAR(20) UNIQUE NOT NULL, pass_hash VARCHAR(60) NOT NULL, + pronouns VARCHAR(15) DEFAULT NULL, color VARCHAR(3) DEFAULT NULL ); CREATE TABLE IF NOT EXISTS chat_groups ( @@ -45,3 +45,24 @@ func InitDatabase(ctx context.Context) { dbConn = conn } + +func SaveClientWithoutGroups(ctx context.Context, client *Client) error { + var id uint64 + var err error + + var hashed []byte + hashed, err = bcrypt.GenerateFromPassword([]byte(client.PasswordHash), bcrypt.DefaultCost) + if err != nil { + return err + } + + password := string(hashed) + + c := string(client.Color[:]) + err = dbConn.QueryRow(ctx, ` + INSERT INTO users (name, pass_hash, pronouns, color) + VALUES ($1, $2, $3) + RETURNING id + `, client.Name, password, client.Pronouns, c).Scan(&id) + return err +} diff --git a/enums.go b/enums.go index f12b525..aaf8639 100644 --- a/enums.go +++ b/enums.go @@ -1,8 +1,8 @@ package main -type WSServerResponse uint8 +type WsServerResponse uint8 const ( - BadMessage WSServerResponse = iota + BadMessage WsServerResponse = iota InvalidCredentials ) diff --git a/go-socket b/go-socket index 5adad46..021e86d 100755 Binary files a/go-socket and b/go-socket differ diff --git a/http.go b/http.go index a694a25..7702d12 100644 --- a/http.go +++ b/http.go @@ -16,4 +16,65 @@ func RegisterHandler(response http.ResponseWriter, request *http.Request) { } ctx := request.Context() + + username := request.FormValue("username") + if len(username) < 4 { + http.Error(response, "no or short username", http.StatusBadRequest) + return + } + + password := request.FormValue("password") + if len(password) < 8 { + http.Error(response, "no or short password", http.StatusBadRequest) + return + } + + newClient := Client{ + Name: username, + PasswordHash: password, + } + + err := CreateClient(ctx, &newClient) + if err != nil { + http.Error(response, "taken", http.StatusBadRequest) + return + } + response.Write([]byte("registered")) +} + +func LoginHandler(response http.ResponseWriter, request *http.Request) { + if !isMethodAllowed(&response, request) { + return + } + + ctx := request.Context() + + username := request.FormValue("username") + if len(username) < 4 { + http.Error(response, "no or short username", http.StatusBadRequest) + return + } + + password := request.FormValue("password") + if len(password) < 8 { + http.Error(response, "no or short password", http.StatusBadRequest) + return + } + + id, err := GetIdFromClientName(ctx, username) + if err != nil { + http.Error(response, "bad login", http.StatusBadRequest) + } + + err = CheckPassword(ctx, id, password) + if err != nil { + http.Error(response, "bad login", http.StatusBadRequest) + } + + token, err := GetToken(id) + if err != nil { + http.Error(response, "Internal error", http.StatusInternalServerError) + } + + response.Write([]byte(token)) } diff --git a/main.go b/main.go index 7905807..45fa4c3 100644 --- a/main.go +++ b/main.go @@ -1,5 +1,19 @@ package main -func main() { +import ( + "context" + "log" + "net/http" +) +func main() { + ctx := context.Background() + InitDatabase(ctx) + + http.HandleFunc("/ws", ServeWsConnection) + http.HandleFunc("/register", RegisterHandler) + http.HandleFunc("/login", LoginHandler) + + log.Println("listening on :8080") + log.Fatal(http.ListenAndServe(":8080", nil)) } diff --git a/struct.go b/struct.go index 7ea3473..db48d37 100644 --- a/struct.go +++ b/struct.go @@ -3,7 +3,7 @@ package main import "github.com/coder/websocket" type Client struct { - Password string + PasswordHash string Name string Pronouns string Groups [12]*ChatGroup diff --git a/tokens.go b/tokens.go index 8dc4506..fc2415b 100644 --- a/tokens.go +++ b/tokens.go @@ -9,10 +9,10 @@ import ( var secretKey = []byte("replace-with-env-variable") -func GetToken(client *Client) (string, error) { +func GetToken(clientId uint32) (string, error) { token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ - Subject: strconv.Itoa(int(client.Id)), + Subject: strconv.Itoa(int(clientId)), ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)), IssuedAt: jwt.NewNumericDate(time.Now()), }, diff --git a/wsServer.go b/wsServer.go index 9c0f013..f68f924 100644 --- a/wsServer.go +++ b/wsServer.go @@ -11,7 +11,7 @@ import ( "github.com/coder/websocket/wsjson" ) -func ServeConnection(responseWriter http.ResponseWriter, request *http.Request) { +func ServeWsConnection(responseWriter http.ResponseWriter, request *http.Request) { connection, err := websocket.Accept(responseWriter, request, nil) if err != nil { log.Printf("websocket accept error: %v", err) @@ -33,7 +33,7 @@ func ServeConnection(responseWriter http.ResponseWriter, request *http.Request) if len(clientMessage) > 0 { if client.IsAuthenticated { - handleAuthenticatedMessage() + handleAuthenticatedMessage(connection, &client, &clientMessage) } else { if !handleUnauthenticatedMessage(connection, &client, &clientMessage) { closeConnection(connection) @@ -61,7 +61,7 @@ func handleUnauthenticatedMessage(conn *websocket.Conn, client *Client, message token, ok := (*message)["token"].(string) if !ok { var errmsg = map[string]any{ - "type": WSServerResponse(BadMessage), + "type": WsServerResponse(BadMessage), "message": "token required", } sendMessageCloseIfTimeout(conn, &errmsg) @@ -71,7 +71,7 @@ func handleUnauthenticatedMessage(conn *websocket.Conn, client *Client, message clientId, err := GetClientIdFromToken(token) if err != nil { var errmsg = map[string]any{ - "type": WSServerResponse(InvalidCredentials), + "type": WsServerResponse(InvalidCredentials), "message": "bad token", } sendMessageCloseIfTimeout(conn, &errmsg) @@ -81,7 +81,7 @@ func handleUnauthenticatedMessage(conn *websocket.Conn, client *Client, message client, err = GetClientFromId(clientId) if err != nil { var errmsg = map[string]any{ - "type": WSServerResponse(InvalidCredentials), + "type": WsServerResponse(InvalidCredentials), "message": "bad token", } sendMessageCloseIfTimeout(conn, &errmsg) @@ -92,8 +92,12 @@ func handleUnauthenticatedMessage(conn *websocket.Conn, client *Client, message return true } -func handleAuthenticatedMessage() { - +func handleAuthenticatedMessage(conn *websocket.Conn, client *Client, message *map[string]any) { + for _, sendTo := range clients { + if sendTo.IsAuthenticated && sendTo.Id != client.Id { + sendMessageCloseIfTimeout(conn, message) + } + } } func closeConnection(conn *websocket.Conn) {