diff --git a/go-socket b/go-socket index fb753af..9b84457 100755 Binary files a/go-socket and b/go-socket differ diff --git a/main.go b/main.go index dbfbd80..9c9eb79 100644 --- a/main.go +++ b/main.go @@ -4,135 +4,10 @@ import ( "context" "log" "net/http" - "sync" - "time" "github.com/coder/websocket" - "github.com/coder/websocket/wsjson" ) -type wsServer struct { - OnOpen func(ctx context.Context, conn *websocket.Conn) - OnClose func(ctx context.Context, conn *websocket.Conn, err error) - OnMessage func(ctx context.Context, conn *websocket.Conn, msg map[string]any) -} - -var ( - unauthenticatedConnections []*websocket.Conn - authenticatedConnections []AuthConnection - mu sync.Mutex -) - -func removeConnectionCache(conn *websocket.Conn) { - mu.Lock() - defer mu.Unlock() - if getConnectionDataIfAuth(conn) != nil { - for i, c := range authenticatedConnections { - if c.connection == conn { - authenticatedConnections[i] = authenticatedConnections[len(authenticatedConnections)-1] - authenticatedConnections = authenticatedConnections[:len(authenticatedConnections)-1] - return - } - } - } else { - for i, c := range unauthenticatedConnections { - if c == conn { - unauthenticatedConnections[i] = unauthenticatedConnections[len(unauthenticatedConnections)-1] - unauthenticatedConnections = unauthenticatedConnections[:len(unauthenticatedConnections)-1] - return - } - } - } -} - -func getConnectionDataIfAuth(conn *websocket.Conn) *AuthConnection { - mu.Lock() - defer mu.Unlock() - for _, c := range authenticatedConnections { - if c.connection == conn { - return &c - } - } - return nil - -} - -func sendAndCloseIfFails(conn *websocket.Conn, message map[string]any) { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - if err := wsjson.Write(ctx, conn, message); err != nil { - conn.Close(websocket.StatusGoingAway, "Write error") - } -} - -func (s *wsServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { - conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{ - InsecureSkipVerify: true, - }) - if err != nil { - log.Println("accept error:", err) - return - } - defer conn.CloseNow() - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - if s.OnOpen != nil { - s.OnOpen(ctx, conn) - } - - var readErr error - for { - var msg map[string]any - if readErr = wsjson.Read(ctx, conn, &msg); readErr != nil { - break - } - if s.OnMessage != nil { - s.OnMessage(ctx, conn, msg) - } - } - - cancel() // cancel before OnClose so any in-flight queries are canceled first - - if s.OnClose != nil { - s.OnClose(ctx, conn, readErr) - } - - conn.Close(websocket.StatusNormalClosure, "done") -} - -func handleUnauthenticatedMessage(ctx context.Context, conn *websocket.Conn, msg map[string]any) { - token := msg["token"].(string) - subject, err := GetSubject(token) - if err != nil { - log.Println("invalid or expired token:", err) - conn.Close(websocket.StatusPolicyViolation, "invalid token") - return - } - user, err := GetUserData(ctx, subject) - if err != nil { - conn.Close(websocket.StatusPolicyViolation, "invalid token") - return - } - mu.Lock() - authenticatedConnections = append(authenticatedConnections, AuthConnection{connection: conn, user: user}) - mu.Unlock() - sendAndCloseIfFails(conn, map[string]any{ - "authAs": user.Name, - }) -} - -func handleAuthenticatedMessage(conn *websocket.Conn, msg map[string]any) { - message := msg["message"].(string) - if message == "" { - sendAndCloseIfFails(conn, map[string]any{ - "error": "no message", - }) - } - -} - func main() { InitDatabase(context.Background()) srv := &wsServer{ diff --git a/wsServer.go b/wsServer.go new file mode 100644 index 0000000..f1c93ff --- /dev/null +++ b/wsServer.go @@ -0,0 +1,132 @@ +package main + +import ( + "context" + "log" + "net/http" + "sync" + "time" + + "github.com/coder/websocket" + "github.com/coder/websocket/wsjson" +) + +type wsServer struct { + OnOpen func(ctx context.Context, conn *websocket.Conn) + OnClose func(ctx context.Context, conn *websocket.Conn, err error) + OnMessage func(ctx context.Context, conn *websocket.Conn, msg map[string]any) +} + +var ( + unauthenticatedConnections []*websocket.Conn + authenticatedConnections []AuthConnection + mu sync.Mutex +) + +func (s *wsServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { + conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{ + InsecureSkipVerify: true, + }) + if err != nil { + log.Println("accept error:", err) + return + } + defer conn.CloseNow() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + if s.OnOpen != nil { + s.OnOpen(ctx, conn) + } + + var readErr error + for { + var msg map[string]any + if readErr = wsjson.Read(ctx, conn, &msg); readErr != nil { + break + } + if s.OnMessage != nil { + s.OnMessage(ctx, conn, msg) + } + } + + cancel() // cancel before OnClose so any in-flight queries are canceled first + + if s.OnClose != nil { + s.OnClose(ctx, conn, readErr) + } + + conn.Close(websocket.StatusNormalClosure, "done") +} + +func removeConnectionCache(conn *websocket.Conn) { + mu.Lock() + defer mu.Unlock() + if getConnectionDataIfAuth(conn) != nil { + for i, c := range authenticatedConnections { + if c.connection == conn { + authenticatedConnections[i] = authenticatedConnections[len(authenticatedConnections)-1] + authenticatedConnections = authenticatedConnections[:len(authenticatedConnections)-1] + return + } + } + } else { + for i, c := range unauthenticatedConnections { + if c == conn { + unauthenticatedConnections[i] = unauthenticatedConnections[len(unauthenticatedConnections)-1] + unauthenticatedConnections = unauthenticatedConnections[:len(unauthenticatedConnections)-1] + return + } + } + } +} + +func getConnectionDataIfAuth(conn *websocket.Conn) *AuthConnection { + mu.Lock() + defer mu.Unlock() + for _, c := range authenticatedConnections { + if c.connection == conn { + return &c + } + } + return nil +} + +func sendAndCloseIfFails(conn *websocket.Conn, message map[string]any) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := wsjson.Write(ctx, conn, message); err != nil { + conn.Close(websocket.StatusGoingAway, "Write error") + } +} + +func handleUnauthenticatedMessage(ctx context.Context, conn *websocket.Conn, msg map[string]any) { + token := msg["token"].(string) + subject, err := GetSubject(token) + if err != nil { + log.Println("invalid or expired token:", err) + conn.Close(websocket.StatusPolicyViolation, "invalid token") + return + } + user, err := GetUserData(ctx, subject) + if err != nil { + conn.Close(websocket.StatusPolicyViolation, "invalid token") + return + } + mu.Lock() + authenticatedConnections = append(authenticatedConnections, AuthConnection{connection: conn, user: user}) + mu.Unlock() + sendAndCloseIfFails(conn, map[string]any{ + "authAs": user.Name, + }) +} + +func handleAuthenticatedMessage(conn *websocket.Conn, msg map[string]any) { + message := msg["message"].(string) + if message == "" { + sendAndCloseIfFails(conn, map[string]any{ + "error": "no message", + }) + } +}