diff --git a/go-socket b/go-socket index 45e67b3..fb753af 100755 Binary files a/go-socket and b/go-socket differ diff --git a/main.go b/main.go index 355bf4a..dbfbd80 100644 --- a/main.go +++ b/main.go @@ -12,21 +12,29 @@ import ( ) type wsServer struct { - OnOpen func(conn *websocket.Conn) - OnClose func(conn *websocket.Conn, err error) - OnMessage func(conn *websocket.Conn, msg map[string]any) + OnOpen func(ctx context.Context, conn *websocket.Conn) + OnClose func(ctx context.Context, conn *websocket.Conn, err error) + OnMessage func(ctx context.Context, conn *websocket.Conn, msg map[string]any) } var ( unauthenticatedConnections []*websocket.Conn - authenticatedConnections []*websocket.Conn + authenticatedConnections []AuthConnection mu sync.Mutex ) -func removeConnection(conn *websocket.Conn) { +func removeConnectionCache(conn *websocket.Conn) { mu.Lock() defer mu.Unlock() - if isConnectionAuthenticated(conn) { + if getConnectionDataIfAuth(conn) != nil { + for i, c := range authenticatedConnections { + if c.connection == conn { + authenticatedConnections[i] = authenticatedConnections[len(authenticatedConnections)-1] + authenticatedConnections = authenticatedConnections[:len(authenticatedConnections)-1] + return + } + } + } else { for i, c := range unauthenticatedConnections { if c == conn { unauthenticatedConnections[i] = unauthenticatedConnections[len(unauthenticatedConnections)-1] @@ -37,15 +45,24 @@ func removeConnection(conn *websocket.Conn) { } } -func isConnectionAuthenticated(conn *websocket.Conn) bool { +func getConnectionDataIfAuth(conn *websocket.Conn) *AuthConnection { mu.Lock() defer mu.Unlock() - for _, c := range unauthenticatedConnections { - if c == conn { - return true + for _, c := range authenticatedConnections { + if c.connection == conn { + return &c } } - return false + return nil + +} + +func sendAndCloseIfFails(conn *websocket.Conn, message map[string]any) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := wsjson.Write(ctx, conn, message); err != nil { + conn.Close(websocket.StatusGoingAway, "Write error") + } } func (s *wsServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { @@ -58,11 +75,12 @@ func (s *wsServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { } defer conn.CloseNow() - if s.OnOpen != nil { - s.OnOpen(conn) - } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() - ctx := r.Context() + if s.OnOpen != nil { + s.OnOpen(ctx, conn) + } var readErr error for { @@ -71,35 +89,72 @@ func (s *wsServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { break } if s.OnMessage != nil { - s.OnMessage(conn, msg) + s.OnMessage(ctx, conn, msg) } } + cancel() // cancel before OnClose so any in-flight queries are canceled first + if s.OnClose != nil { - s.OnClose(conn, readErr) + s.OnClose(ctx, conn, readErr) } conn.Close(websocket.StatusNormalClosure, "done") } +func handleUnauthenticatedMessage(ctx context.Context, conn *websocket.Conn, msg map[string]any) { + token := msg["token"].(string) + subject, err := GetSubject(token) + if err != nil { + log.Println("invalid or expired token:", err) + conn.Close(websocket.StatusPolicyViolation, "invalid token") + return + } + user, err := GetUserData(ctx, subject) + if err != nil { + conn.Close(websocket.StatusPolicyViolation, "invalid token") + return + } + mu.Lock() + authenticatedConnections = append(authenticatedConnections, AuthConnection{connection: conn, user: user}) + mu.Unlock() + sendAndCloseIfFails(conn, map[string]any{ + "authAs": user.Name, + }) +} + +func handleAuthenticatedMessage(conn *websocket.Conn, msg map[string]any) { + message := msg["message"].(string) + if message == "" { + sendAndCloseIfFails(conn, map[string]any{ + "error": "no message", + }) + } + +} + func main() { InitDatabase(context.Background()) srv := &wsServer{ - OnOpen: func(conn *websocket.Conn) { + OnOpen: func(ctx context.Context, conn *websocket.Conn) { log.Println("client connected") - mu.Lock() - unauthenticatedConnections = append(unauthenticatedConnections, conn) - mu.Unlock() + if getConnectionDataIfAuth(conn) != nil { + mu.Lock() + unauthenticatedConnections = append(unauthenticatedConnections, conn) + mu.Unlock() + } }, - OnClose: func(conn *websocket.Conn, err error) { + OnClose: func(ctx context.Context, conn *websocket.Conn, err error) { log.Println("client disconnected:", err) + removeConnectionCache(conn) }, - OnMessage: func(conn *websocket.Conn, msg map[string]any) { + OnMessage: func(ctx context.Context, conn *websocket.Conn, msg map[string]any) { log.Printf("received: %v\n", msg) - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - if err := wsjson.Write(ctx, conn, msg); err != nil { - removeConnection(conn) + authConnOrNil := getConnectionDataIfAuth(conn) + if authConnOrNil == nil { + handleUnauthenticatedMessage(ctx, conn, msg) + } else { + handleAuthenticatedMessage(conn, msg) } }, } diff --git a/structures.go b/structures.go index d48c1b0..89dc55c 100644 --- a/structures.go +++ b/structures.go @@ -10,7 +10,7 @@ type User struct { IsPasswordHashed bool } -type authConnection struct { +type AuthConnection struct { connection *websocket.Conn user User }