package main import ( "context" "log" "net/http" "sync" "time" "github.com/coder/websocket" "github.com/coder/websocket/wsjson" ) type wsServer struct { OnOpen func(ctx context.Context, conn *websocket.Conn) OnClose func(ctx context.Context, conn *websocket.Conn, err error) OnMessage func(ctx context.Context, conn *websocket.Conn, msg map[string]any) } var ( unauthenticatedConnections []*websocket.Conn authenticatedConnections []AuthConnection mu sync.Mutex ) func (s *wsServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{ InsecureSkipVerify: true, }) if err != nil { log.Println("accept error:", err) return } defer conn.CloseNow() ctx, cancel := context.WithCancel(context.Background()) defer cancel() if s.OnOpen != nil { s.OnOpen(ctx, conn) } var readErr error for { var msg map[string]any if readErr = wsjson.Read(ctx, conn, &msg); readErr != nil { break } if s.OnMessage != nil { s.OnMessage(ctx, conn, msg) } } cancel() // cancel before OnClose so any in-flight queries are canceled first if s.OnClose != nil { s.OnClose(ctx, conn, readErr) } conn.Close(websocket.StatusNormalClosure, "done") } func removeConnectionCache(conn *websocket.Conn) { mu.Lock() defer mu.Unlock() if getConnectionDataIfAuth(conn) != nil { for i, c := range authenticatedConnections { if c.connection == conn { authenticatedConnections[i] = authenticatedConnections[len(authenticatedConnections)-1] authenticatedConnections = authenticatedConnections[:len(authenticatedConnections)-1] return } } } else { for i, c := range unauthenticatedConnections { if c == conn { unauthenticatedConnections[i] = unauthenticatedConnections[len(unauthenticatedConnections)-1] unauthenticatedConnections = unauthenticatedConnections[:len(unauthenticatedConnections)-1] return } } } } func getConnectionDataIfAuth(conn *websocket.Conn) *AuthConnection { mu.Lock() defer mu.Unlock() for _, c := range authenticatedConnections { if c.connection == conn { return &c } } return nil } func sendAndCloseIfFails(conn *websocket.Conn, message map[string]any) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() if err := wsjson.Write(ctx, conn, message); err != nil { conn.Close(websocket.StatusGoingAway, "Write error") } } func handleUnauthenticatedMessage(ctx context.Context, conn *websocket.Conn, msg map[string]any) { token := msg["token"].(string) subject, err := GetSubject(token) if err != nil { log.Println("invalid or expired token:", err) conn.Close(websocket.StatusPolicyViolation, "invalid token") return } user, err := GetUserData(ctx, subject) if err != nil { conn.Close(websocket.StatusPolicyViolation, "invalid token") return } mu.Lock() authenticatedConnections = append(authenticatedConnections, AuthConnection{connection: conn, user: user}) mu.Unlock() sendAndCloseIfFails(conn, map[string]any{ "authAs": user.Name, }) } func handleAuthenticatedMessage(conn *websocket.Conn, msg map[string]any) { message := msg["message"].(string) if message == "" { sendAndCloseIfFails(conn, map[string]any{ "error": "no message", }) } }