package main import ( "context" "errors" "log" "net/http" "time" "github.com/coder/websocket" "github.com/coder/websocket/wsjson" ) func ServeWsConnection(responseWriter http.ResponseWriter, request *http.Request) { connection, err := websocket.Accept(responseWriter, request, &websocket.AcceptOptions{ InsecureSkipVerify: true, }) if err != nil { log.Printf("websocket accept error: %v", err) return } ctx, cancel := context.WithCancel(context.Background()) defer cancel() var client = Client{WsConn: connection} var isAuthenticated bool var ignoreCache bool defer closeConnection(&client, ignoreCache) for { var clientMessage map[string]any err := wsjson.Read(ctx, connection, &clientMessage) if err != nil { log.Printf("read error: %v", err) return } if len(clientMessage) > 0 { if isAuthenticated { if !handleAuthenticatedMessage(&client, &clientMessage) { return } } else { if !handleUnauthenticatedMessage(ctx, &client, &clientMessage) { ignoreCache = true return } isAuthenticated = true } } } } func sendMessageCloseIfTimeout(client *Client, message *map[string]any) { if client.WsConn == nil { return } ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) defer cancel() err := wsjson.Write(ctx, client.WsConn, message) if err != nil { if errors.Is(err, context.DeadlineExceeded) { closeConnection(client, false) } log.Printf("write error: %v", err) } } func sendToAllMessageCloseIfTimeout(message *map[string]any) { mu.RLock() defer mu.RUnlock() for _, client := range CacheClients { sendMessageCloseIfTimeout(client, message) } } func handleAuthenticatedMessage(client *Client, clientMessage *map[string]any) bool { subject, ok := (*clientMessage)["subject"].(uint32) if !ok { var msg = map[string]any{ "from": "server", "error": "subject invalid", } sendMessageCloseIfTimeout(client, &msg) } content, ok := (*clientMessage)["content"].(string) if !ok { var msg = map[string]any{ "from": "server", "error": "content invalid", } sendMessageCloseIfTimeout(client, &msg) } group, err := CacheGetGroup(subject) if err != nil { var msg = map[string]any{ "from": "server", "error": "subject invalid", } sendMessageCloseIfTimeout(client, &msg) } for groupClientId, _ := range group.Clients { var msg = map[string]any{ "from": "group", "group": group.Id, "sender": client.Name, "content": content, } var groupClient *Client groupClient, err = CacheGetClientById(groupClientId) if err != nil { sendMessageCloseIfTimeout(groupClient, &msg) } } return true } func handleUnauthenticatedMessage(ctx context.Context, client *Client, clientMessage *map[string]any) bool { token, ok := (*clientMessage)["token"].(string) if !ok { var msg = map[string]any{ "from": "server", "error": "no token in message", } sendMessageCloseIfTimeout(client, &msg) return false } clientId, err := TokenValidateGetId(token) if err != nil { var msg = map[string]any{ "from": "server", "error": "invalid token", } sendMessageCloseIfTimeout(client, &msg) return false } clientFromCache, err := CacheGetClientById(clientId) if err != nil { // Not in cache — load from database dbClient := &Client{Id: clientId} err = DbSetClientByIdWithoutGroups(ctx, dbClient) if err != nil { var msg = map[string]any{ "from": "server", "error": "invalid client data", } sendMessageCloseIfTimeout(client, &msg) return false } err = DbSetClientGroups(ctx, dbClient) if err != nil { var msg = map[string]any{ "from": "server", "error": "invalid client data", } sendMessageCloseIfTimeout(client, &msg) return false } dbClient.WsConn = client.WsConn CacheSaveClient(dbClient) clientFromCache = dbClient } *client = *clientFromCache for groupId, _ := range clientFromCache.Groups { _, err = CacheGetGroup(groupId) if err != nil { dbGroup := &Group{Id: groupId} err = DbSetGroupById(ctx, dbGroup) if err != nil { var msg = map[string]any{ "from": "server", "error": "invalid client data", } sendMessageCloseIfTimeout(client, &msg) return false } CacheSaveGroup(dbGroup) } } return true } func closeConnection(client *Client, ignoreCache bool) { if !ignoreCache { CacheDeleteClient(client.Id) } client.WsConn.CloseNow() }