package main import ( "context" "log" "net/http" "sync" "time" "github.com/coder/websocket" "github.com/coder/websocket/wsjson" ) type wsServer struct { OnOpen func(ctx context.Context, conn *websocket.Conn) OnClose func(ctx context.Context, conn *websocket.Conn, err error) OnMessage func(ctx context.Context, conn *websocket.Conn, msg map[string]any) } var ( unauthenticatedConnections []*websocket.Conn authenticatedConnections []AuthConnection mu sync.Mutex ) func (s *wsServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{ InsecureSkipVerify: true, }) if err != nil { log.Println("accept error:", err) return } defer conn.CloseNow() ctx, cancel := context.WithCancel(context.Background()) defer cancel() if s.OnOpen != nil { s.OnOpen(ctx, conn) } var readErr error for { var msg map[string]any if readErr = wsjson.Read(ctx, conn, &msg); readErr != nil { break } if s.OnMessage != nil { s.OnMessage(ctx, conn, msg) } } cancel() // cancel before OnClose so any in-flight queries are canceled first if s.OnClose != nil { s.OnClose(ctx, conn, readErr) } conn.Close(websocket.StatusNormalClosure, "done") } func removeConnectionCache(conn *websocket.Conn) { mu.Lock() defer mu.Unlock() if getConnectionDataIfAuth(conn) != nil { for i, c := range authenticatedConnections { if c.connection == conn { authenticatedConnections[i] = authenticatedConnections[len(authenticatedConnections)-1] authenticatedConnections = authenticatedConnections[:len(authenticatedConnections)-1] return } } } else { for i, c := range unauthenticatedConnections { if c == conn { unauthenticatedConnections[i] = unauthenticatedConnections[len(unauthenticatedConnections)-1] unauthenticatedConnections = unauthenticatedConnections[:len(unauthenticatedConnections)-1] return } } } } func getConnectionDataIfAuth(conn *websocket.Conn) *AuthConnection { mu.Lock() defer mu.Unlock() for _, c := range authenticatedConnections { if c.connection == conn { return &c } } return nil } func sendAndCloseIfFails(conn *websocket.Conn, message map[string]any) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() if err := wsjson.Write(ctx, conn, message); err != nil { conn.Close(websocket.StatusGoingAway, "Write error") } } func sendToAllExceptAndCloseIfFails(conn *websocket.Conn, message map[string]any) { _, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() for _, aConn := range authenticatedConnections { if aConn.connection != conn { sendAndCloseIfFails(aConn.connection, message) } } } func handleUnauthenticatedMessage(conn *websocket.Conn, msg map[string]any) { token := msg["token"].(string) user, err := GetUserFromToken(token) if err != nil { log.Println("invalid or expired token:", err) err := conn.Close(websocket.StatusPolicyViolation, "invalid token") if err != nil { return } return } mu.Lock() authenticatedConnections = append(authenticatedConnections, AuthConnection{connection: conn, user: user}) mu.Unlock() sendAndCloseIfFails(conn, map[string]any{ "authAs": user.Name, }) } func handleAuthenticatedMessage(conn *websocket.Conn, msg map[string]any) { message := msg["message"].(string) if message == "" { sendAndCloseIfFails(conn, map[string]any{ "error": "no message", }) return } auth := getConnectionDataIfAuth(conn) if auth == nil { sendAndCloseIfFails(conn, map[string]any{ "error": "no auth", }) return } sendToAllExceptAndCloseIfFails(conn, map[string]any{ "username": auth.user.Name, "message": message, }) }