diff --git a/go-socket b/go-socket index 13b2a7a..ac05f9e 100755 Binary files a/go-socket and b/go-socket differ diff --git a/main.go b/main.go index 5d8b0f6..5167c24 100644 --- a/main.go +++ b/main.go @@ -27,7 +27,7 @@ func main() { log.Printf("received: %v\n", msg) authConnOrNil := getConnectionDataIfAuth(conn) if authConnOrNil == nil { - handleUnauthenticatedMessage(ctx, conn, msg) + handleUnauthenticatedMessage(conn, msg) } else { handleAuthenticatedMessage(conn, msg) } diff --git a/tokens.go b/tokens.go index 4be34f7..c8fea1c 100644 --- a/tokens.go +++ b/tokens.go @@ -12,32 +12,51 @@ import ( var secretKey = []byte("replace-with-env-variable") +type UserClaims struct { + Name string `json:"name"` + Color string `json:"color"` + jwt.RegisteredClaims +} + func GetToken(user *User) (string, error) { token := jwt.NewWithClaims(jwt.SigningMethodHS256, - jwt.RegisteredClaims{ - Subject: strconv.Itoa(int(user.Id)), - ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)), - IssuedAt: jwt.NewNumericDate(time.Now()), + UserClaims{ + Name: user.Name, + Color: user.Color, + RegisteredClaims: jwt.RegisteredClaims{ + Subject: strconv.Itoa(int(user.Id)), + ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)), + IssuedAt: jwt.NewNumericDate(time.Now()), + }, }, ) return token.SignedString(secretKey) } -func GetSubject(tokenString string) (string, error) { - token, err := jwt.ParseWithClaims(tokenString, &jwt.RegisteredClaims{}, func(t *jwt.Token) (interface{}, error) { +func GetUserFromToken(tokenString string) (User, error) { + token, err := jwt.ParseWithClaims(tokenString, &UserClaims{}, func(t *jwt.Token) (interface{}, error) { if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok { return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"]) } return secretKey, nil }) if err != nil { - return "", err + return User{}, err } - claims, ok := token.Claims.(*jwt.RegisteredClaims) + claims, ok := token.Claims.(*UserClaims) if !ok || !token.Valid { - return "", errors.New("invalid token") + return User{}, errors.New("invalid token") } - return claims.Subject, nil + id, err := strconv.ParseUint(claims.Subject, 10, 32) + if err != nil { + return User{}, fmt.Errorf("invalid subject: %w", err) + } + + return User{ + Id: uint32(id), + Name: claims.Name, + Color: claims.Color, + }, nil } diff --git a/wsServer.go b/wsServer.go index 6ed9bc8..610bc06 100644 --- a/wsServer.go +++ b/wsServer.go @@ -4,7 +4,6 @@ import ( "context" "log" "net/http" - "strconv" "sync" "time" @@ -112,30 +111,19 @@ func sendToAllExceptAndCloseIfFails(conn *websocket.Conn, message map[string]any } } -func handleUnauthenticatedMessage(ctx context.Context, conn *websocket.Conn, msg map[string]any) { +func handleUnauthenticatedMessage(conn *websocket.Conn, msg map[string]any) { token := msg["token"].(string) - subject, err := GetSubject(token) + user, err := GetUserFromToken(token) if err != nil { log.Println("invalid or expired token:", err) - conn.Close(websocket.StatusPolicyViolation, "invalid token") - return - } - - var subjectId uint32 - parsed, err := strconv.ParseUint(subject, 10, 32) - subjectId = uint32(parsed) - if err != nil { - conn.Close(websocket.StatusPolicyViolation, "invalid token") - return - } - - user, err := GetUserDataById(ctx, subjectId) - if err != nil { - conn.Close(websocket.StatusPolicyViolation, "invalid token") + err := conn.Close(websocket.StatusPolicyViolation, "invalid token") + if err != nil { + return + } return } mu.Lock() - authenticatedConnections = append(authenticatedConnections, AuthConnection{connection: conn, user: *user}) + authenticatedConnections = append(authenticatedConnections, AuthConnection{connection: conn, user: user}) mu.Unlock() sendAndCloseIfFails(conn, map[string]any{ "authAs": user.Name, @@ -151,8 +139,16 @@ func handleAuthenticatedMessage(conn *websocket.Conn, msg map[string]any) { return } + auth := getConnectionDataIfAuth(conn) + if auth == nil { + sendAndCloseIfFails(conn, map[string]any{ + "error": "no auth", + }) + return + } + sendToAllExceptAndCloseIfFails(conn, map[string]any{ - "username": , - "message": message, + "username": auth.user.Name, + "message": message, }) }