Files
go-socket/main.go
T
2026-03-11 08:51:33 +01:00

111 lines
2.4 KiB
Go

package main
import (
"context"
"log"
"net/http"
"sync"
"time"
"github.com/coder/websocket"
"github.com/coder/websocket/wsjson"
)
// wsServer is an http.Handler that upgrades incoming requests to WebSocket
// connections and reports each connection's lifecycle through optional
// callbacks. A callback left nil is simply skipped.
type wsServer struct {
	// OnOpen runs once, right after the WebSocket handshake succeeds.
	OnOpen func(conn *websocket.Conn)
	// OnClose runs once after the read loop stops; err is the read error
	// that stopped it.
	OnClose func(conn *websocket.Conn, err error)
	// OnMessage runs for every JSON object read from the connection.
	OnMessage func(conn *websocket.Conn, msg map[string]any)
}
// unauthenticatedConnections holds sockets registered on open. Guarded by mu.
// NOTE(review): presumably an auth flow elsewhere moves entries into
// authenticatedConnections; that code is not visible here — confirm.
var unauthenticatedConnections []*websocket.Conn

// authenticatedConnections holds sockets considered authenticated. Guarded by mu.
var authenticatedConnections []*websocket.Conn

// mu guards both connection slices above; handler goroutines touch them
// concurrently, so every read or write of either slice must hold it.
var mu sync.Mutex
// removeConnection unregisters conn from both connection lists. It is a no-op
// for a connection that is not registered, so it is safe to call more than
// once for the same connection.
func removeConnection(conn *websocket.Conn) {
	mu.Lock()
	defer mu.Unlock()
	// Do not call isConnectionAuthenticated here: it locks mu itself and
	// sync.Mutex is not reentrant, so that would deadlock.
	unauthenticatedConnections = removeFromConnList(unauthenticatedConnections, conn)
	authenticatedConnections = removeFromConnList(authenticatedConnections, conn)
}

// removeFromConnList deletes the first occurrence of conn from list with a
// swap-with-last delete (element order is not preserved) and returns the
// shortened slice. The caller must hold mu.
func removeFromConnList(list []*websocket.Conn, conn *websocket.Conn) []*websocket.Conn {
	for i, c := range list {
		if c == conn {
			last := len(list) - 1
			list[i] = list[last]
			// Clear the vacated slot so the backing array does not keep the
			// removed connection reachable.
			list[last] = nil
			return list[:last]
		}
	}
	return list
}

// isConnectionAuthenticated reports whether conn is registered in
// authenticatedConnections. It locks mu, so it must not be called while mu
// is already held.
func isConnectionAuthenticated(conn *websocket.Conn) bool {
	mu.Lock()
	defer mu.Unlock()
	for _, c := range authenticatedConnections {
		if c == conn {
			return true
		}
	}
	return false
}

// ServeHTTP upgrades the request to a WebSocket, calls OnOpen, and then reads
// JSON objects in a loop, handing each one to OnMessage. The first read error
// ends the loop. That covers the peer closing, the request context ending, or
// a message that is not a JSON object. OnClose then receives that error, and
// the connection is closed with StatusNormalClosure.
func (s *wsServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{
		// NOTE(review): InsecureSkipVerify disables Origin verification, so
		// any web page can open a socket here with the visitor's cookies
		// (cross-site WebSocket hijacking). Prefer OriginPatterns once the
		// allowed origins are known.
		InsecureSkipVerify: true,
	})
	if err != nil {
		log.Println("accept error:", err)
		return
	}
	// Backstop so the connection is released on every exit path. Its error
	// after the explicit Close below is discarded.
	defer conn.CloseNow()

	if s.OnOpen != nil {
		s.OnOpen(conn)
	}

	ctx := r.Context()
	var readErr error
	for {
		var msg map[string]any
		if readErr = wsjson.Read(ctx, conn, &msg); readErr != nil {
			break
		}
		if s.OnMessage != nil {
			s.OnMessage(conn, msg)
		}
	}

	if s.OnClose != nil {
		s.OnClose(conn, readErr)
	}
	conn.Close(websocket.StatusNormalClosure, "done")
}

// main initialises the database, registers an echo WebSocket handler at /ws
// that tracks connections in the package-level registries, and serves on
// :8080.
func main() {
	InitDatabase(context.Background())
	srv := &wsServer{
		OnOpen: func(conn *websocket.Conn) {
			log.Println("client connected")
			mu.Lock()
			unauthenticatedConnections = append(unauthenticatedConnections, conn)
			mu.Unlock()
		},
		OnClose: func(conn *websocket.Conn, err error) {
			log.Println("client disconnected:", err)
			// Unregister on every disconnect; otherwise each accepted socket
			// stays in the registry for the life of the process.
			removeConnection(conn)
		},
		OnMessage: func(conn *websocket.Conn, msg map[string]any) {
			log.Printf("received: %v\n", msg)
			ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
			defer cancel()
			if err := wsjson.Write(ctx, conn, msg); err != nil {
				log.Println("write error:", err)
				removeConnection(conn)
			}
		},
	}
	http.Handle("/ws", srv)

	server := &http.Server{
		Addr: ":8080",
		// Bounds how long a client may take to send request headers
		// (Slowloris). ReadTimeout/WriteTimeout are deliberately not set:
		// deadlines they install carry over to hijacked connections and
		// would cut off long-lived WebSocket sessions.
		ReadHeaderTimeout: 10 * time.Second,
	}
	log.Println("server listening on :8080")
	log.Fatal(server.ListenAndServe())
}