diff --git a/cache.go b/cache.go index 95260ab..ebe7aad 100644 --- a/cache.go +++ b/cache.go @@ -5,7 +5,6 @@ import "sync" var ( mu sync.RWMutex ChatGroups = make(map[uint32]ChatGroup) - Clients = make(map[uint32]Client) ClientsMap = make(map[uint32]map[uint32]*Client) ) diff --git a/enums.go b/enums.go new file mode 100644 index 0000000..f12b525 --- /dev/null +++ b/enums.go @@ -0,0 +1,8 @@ +package main + +type WSServerResponse uint8 + +const ( + BadMessage WSServerResponse = iota + InvalidCredentials +) diff --git a/go.mod b/go.mod index 7f56037..d513f92 100644 --- a/go.mod +++ b/go.mod @@ -3,3 +3,5 @@ module go-socket go 1.26 require github.com/coder/websocket v1.8.14 + +require github.com/golang-jwt/jwt/v5 v5.3.1 // indirect diff --git a/strcut.go b/strcut.go deleted file mode 100644 index c748b30..0000000 --- a/strcut.go +++ /dev/null @@ -1,21 +0,0 @@ -package main - -import "github.com/coder/websocket" - -type Client struct { - Password string - Name string - Pronouns string - Groups [12]*ChatGroup - Connection *websocket.Conn - Id uint32 - Color [3]byte - IsAuthenticated bool -} - -type ChatGroup struct { - Name string - Members [32]*Client - Id uint32 - Color [3]byte -} diff --git a/struct.go b/struct.go new file mode 100644 index 0000000..0678304 --- /dev/null +++ b/struct.go @@ -0,0 +1,20 @@ +package main + +import "github.com/coder/websocket" + +type Client struct { + Password string + Name string + Pronouns string + Groups [12]*ChatGroup + Connection *websocket.Conn + Id uint32 + Color [3]byte +} + +type ChatGroup struct { + Name string + Members [32]*Client + Id uint32 + Color [3]byte +} diff --git a/tokens.go b/tokens.go new file mode 100644 index 0000000..5c13a6c --- /dev/null +++ b/tokens.go @@ -0,0 +1,61 @@ +package main + +import ( + "errors" + "fmt" + "strconv" + "time" + + "github.com/golang-jwt/jwt/v5" +) + +var secretKey = []byte("replace-with-env-variable") + +type UserClaims struct { + Name string `json:"name"` + Color [3]byte `json:"color"` + jwt.RegisteredClaims +} + +func GetToken(client *Client) (string, error) { + token := jwt.NewWithClaims(jwt.SigningMethodHS256, + UserClaims{ + Name: client.Name, + Color: client.Color, + RegisteredClaims: jwt.RegisteredClaims{ + Subject: strconv.Itoa(int(client.Id)), + ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)), + IssuedAt: jwt.NewNumericDate(time.Now()), + }, + }, + ) + return token.SignedString(secretKey) +} + +func SetClientFromToken(client *Client, tokenString string) error { + token, err := jwt.ParseWithClaims(tokenString, &UserClaims{}, func(t *jwt.Token) (interface{}, error) { + if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"]) + } + return secretKey, nil + }) + if err != nil { + return err + } + + claims, ok := token.Claims.(*UserClaims) + if !ok || !token.Valid { + return errors.New("invalid token") + } + + id, err := strconv.ParseUint(claims.Subject, 10, 32) + if err != nil { + return fmt.Errorf("invalid subject: %w", err) + } + + client.Id = uint32(id) + client.Name = claims.Name + client.Color = claims.Color + + return nil +} diff --git a/wsServer.go b/wsServer.go index da666e6..223ba54 100644 --- a/wsServer.go +++ b/wsServer.go @@ -2,6 +2,7 @@ package main import ( "context" + "errors" "log" "net/http" "time" @@ -20,25 +21,72 @@ func ServeConnection(responseWriter http.ResponseWriter, request *http.Request) ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) defer cancel() - for { - var clientMessage any + var ( + isAuthenticated = false + continueConnection = true + client = Client{} + ) + + for continueConnection { + var clientMessage map[string]any err := wsjson.Read(ctx, connection, &clientMessage) if err != nil { log.Printf("read error: %clientMessage", err) return } - log.Printf("received: %clientMessage", clientMessage) - // process and optionally respond - err = wsjson.Write(ctx, connection, map[string]string{"status": "ok"}) - if err != nil { - log.Printf("write error: %clientMessage", err) - return + if len(clientMessage) > 0 { + if isAuthenticated { + handleAuthenticatedMessage() + } else { + handleUnauthenticatedMessage(connection, &client, &clientMessage, &isAuthenticated, &continueConnection) + } } } } -func handleUnauthenticatedMessage() { - +func sendMessageCloseIfTimeout(conn *websocket.Conn, message *map[string]any) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + + err := wsjson.Write(ctx, conn, message) + if err != nil { + if errors.Is(err, context.DeadlineExceeded) { + closeConnection(conn) + } + log.Printf("write error: %v", err) + } +} + +func handleUnauthenticatedMessage(conn *websocket.Conn, client *Client, message *map[string]any, isAuthenticated *bool, continueConnection *bool) { + token, ok := (*message)["token"].(string) + if !ok { + var errmsg = map[string]any{ + "type": WSServerResponse(BadMessage), + "message": "token required", + } + sendMessageCloseIfTimeout(conn, &errmsg) + return + } + + err := SetClientFromToken(client, token) + if err != nil { + var errmsg = map[string]any{ + "type": WSServerResponse(InvalidCredentials), + "message": "bad token", + } + sendMessageCloseIfTimeout(conn, &errmsg) + continueConnection = false + return + } + isAuthenticated = true + clientInCache, ok := ClientsMap[] +} + +func handleAuthenticatedMessage() { + +} + +func closeConnection(conn *websocket.Conn) { + conn.Close(websocket.StatusNormalClosure, "closing connection") } -func handleAuthenticatedMessage() {}