package main import ( "context" "log" "net/http" "sync" "time" "github.com/coder/websocket" "github.com/coder/websocket/wsjson" ) type wsServer struct { OnOpen func(conn *websocket.Conn) OnClose func(conn *websocket.Conn, err error) OnMessage func(conn *websocket.Conn, msg map[string]any) } var ( unauthenticatedConnections []*websocket.Conn authenticatedConnections []*websocket.Conn mu sync.Mutex ) func removeConnection(conn *websocket.Conn) { mu.Lock() defer mu.Unlock() if isConnectionAuthenticated(conn) { for i, c := range unauthenticatedConnections { if c == conn { unauthenticatedConnections[i] = unauthenticatedConnections[len(unauthenticatedConnections)-1] unauthenticatedConnections = unauthenticatedConnections[:len(unauthenticatedConnections)-1] return } } } } func isConnectionAuthenticated(conn *websocket.Conn) bool { mu.Lock() defer mu.Unlock() for _, c := range unauthenticatedConnections { if c == conn { return true } } return false } func (s *wsServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{ InsecureSkipVerify: true, }) if err != nil { log.Println("accept error:", err) return } defer conn.CloseNow() if s.OnOpen != nil { s.OnOpen(conn) } ctx := r.Context() var readErr error for { var msg map[string]any if readErr = wsjson.Read(ctx, conn, &msg); readErr != nil { break } if s.OnMessage != nil { s.OnMessage(conn, msg) } } if s.OnClose != nil { s.OnClose(conn, readErr) } conn.Close(websocket.StatusNormalClosure, "done") } func main() { InitDatabase(context.Background()) srv := &wsServer{ OnOpen: func(conn *websocket.Conn) { log.Println("client connected") mu.Lock() unauthenticatedConnections = append(unauthenticatedConnections, conn) mu.Unlock() }, OnClose: func(conn *websocket.Conn, err error) { log.Println("client disconnected:", err) }, OnMessage: func(conn *websocket.Conn, msg map[string]any) { log.Printf("received: %v\n", msg) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() if err := wsjson.Write(ctx, conn, msg); err != nil { removeConnection(conn) } }, } http.Handle("/ws", srv) log.Println("server listening on :8080") log.Fatal(http.ListenAndServe(":8080", nil)) }