package main import ( "context" "errors" "log" "net/http" "time" "github.com/coder/websocket" "github.com/coder/websocket/wsjson" ) func ServeWsConnection(responseWriter http.ResponseWriter, request *http.Request) { connection, err := websocket.Accept(responseWriter, request, &websocket.AcceptOptions{ InsecureSkipVerify: true, }) if err != nil { log.Printf("websocket accept error: %v", err) return } ctx, cancel := context.WithCancel(context.Background()) defer cancel() var user = User{WsConn: connection} var isAuthenticated bool var ignoreCache bool defer closeConnection(&user, ignoreCache) for { var userMessage map[string]any err = wsjson.Read(ctx, connection, &userMessage) if err != nil { log.Printf("read error: %v", err) return } if len(userMessage) > 0 { if isAuthenticated { if !handleAuthenticatedMessage(&user, &userMessage) { return } } else { if !handleUnauthenticatedMessage(ctx, &user, &userMessage) { ignoreCache = true return } isAuthenticated = true } } } } func sendMessageCloseIfTimeout(user *User, message *map[string]any) { if user.WsConn == nil { return } ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) defer cancel() err := wsjson.Write(ctx, user.WsConn, message) if err != nil { if errors.Is(err, context.DeadlineExceeded) { closeConnection(user, false) } log.Printf("write error: %v", err) } } func sendToAllMessageCloseIfTimeout(message *map[string]any) { mu.RLock() defer mu.RUnlock() for _, user := range CacheUsers { sendMessageCloseIfTimeout(user, message) } } func WsSendToUser(from *User, to *User, message string) { var msg = map[string]any{ "type": WsMessageToUserFrom(DirectMessage_), "from": from.Id, "content": message, } sendMessageCloseIfTimeout(from, &msg) } func WsSendToGroup(group *Group, sender *User, message string) error { for groupUserId := range group.Users { groupUser, err := CacheGetUserById(groupUserId) if err != nil || groupUser.Id == sender.Id { continue } var msg = map[string]any{ "type": WsMessageToUserFrom(Group_), "from": group.Id, "sender": sender.Id, "content": message, } sendMessageCloseIfTimeout(groupUser, &msg) } return nil } func handleAuthenticatedMessage(user *User, userMessage *map[string]any) bool { sendMessageCloseIfTimeout(user, userMessage) return true } func handleUnauthenticatedMessage(ctx context.Context, user *User, userMessage *map[string]any) bool { token, ok := (*userMessage)["token"].(string) if !ok { var msg = map[string]any{ "type": WsMessageToUserFrom(Server_), "error": "no token in message", } sendMessageCloseIfTimeout(user, &msg) return false } userId, err := TokenValidateGetId(token) if err != nil { var msg = map[string]any{ "type": WsMessageToUserFrom(Server_), "error": "invalid token", } sendMessageCloseIfTimeout(user, &msg) return false } userFromCache, err := CacheGetUserById(userId) if err != nil { var msg = map[string]any{ "type": WsMessageToUserFrom(Server_), "error": "user not found", } sendMessageCloseIfTimeout(user, &msg) return false } userFromCache.WsConn = user.WsConn *user = *userFromCache for groupId, _ := range userFromCache.Groups { _, err = CacheGetGroup(groupId) if err != nil { dbGroup := &Group{Id: groupId} err = DbGroupGetById(ctx, dbGroup) if err != nil { var msg = map[string]any{ "type": "server", "error": "invalid user data", } sendMessageCloseIfTimeout(user, &msg) return false } err = DbGroupGetMembers(ctx, dbGroup) if err != nil { var msg = map[string]any{ "type": "server", "error": "invalid user data", } sendMessageCloseIfTimeout(user, &msg) return false } CacheSaveGroup(dbGroup) } } return true } func closeConnection(user *User, ignoreCache bool) { if !ignoreCache { CacheDeleteUser(user.Id) } user.WsConn.CloseNow() }