package main import ( "context" "errors" "log" "net/http" "time" "github.com/coder/websocket" "github.com/coder/websocket/wsjson" ) func ServeWsConnection(responseWriter http.ResponseWriter, request *http.Request) { connection, err := websocket.Accept(responseWriter, request, &websocket.AcceptOptions{ InsecureSkipVerify: true, }) if err != nil { log.Printf("websocket accept error: %v", err) return } ctx, cancel := context.WithCancel(context.Background()) defer cancel() var user = User{WsConn: connection} var isAuthenticated bool var ignoreCache bool defer closeConnection(&user, ignoreCache) for { var userMessage map[string]any err = wsjson.Read(ctx, connection, &userMessage) if err != nil { log.Printf("read error: %v", err) return } if len(userMessage) > 0 { if isAuthenticated { if !handleAuthenticatedMessage(&user, &userMessage) { return } } else { if !handleUnauthenticatedMessage(ctx, &user, &userMessage) { ignoreCache = true return } isAuthenticated = true } } } } func sendMessageCloseIfTimeout(user *User, message *map[string]any) { if user.WsConn == nil { return } ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) defer cancel() err := wsjson.Write(ctx, user.WsConn, message) if err != nil { if errors.Is(err, context.DeadlineExceeded) { closeConnection(user, false) } log.Printf("write error: %v", err) } } func sendToAllMessageCloseIfTimeout(message *map[string]any) { mu.RLock() defer mu.RUnlock() for _, user := range CacheUsers { sendMessageCloseIfTimeout(user, message) } } func WsSendToGroup(ctx context.Context, groupId uint32, senderId uint32, message string) error { group, err := CacheGetGroup(groupId) if err != nil { return errors.New("group invalid") } sender, err := CacheGetUserById(senderId) if err != nil { sender = &User{Id: senderId} err = DbUserSetById(ctx, sender) if err != nil { return errors.New("non existing sender") } } for groupUserId := range group.Users { groupUser, err := CacheGetUserById(groupUserId) if err != nil || groupUser.Id == sender.Id { continue } var msg = map[string]any{ "from": "group", "group": group.Id, "sender": sender.Name, "content": message, } sendMessageCloseIfTimeout(groupUser, &msg) } return nil } func handleAuthenticatedMessage(user *User, userMessage *map[string]any) bool { sendMessageCloseIfTimeout(user, userMessage) return true } func handleUnauthenticatedMessage(ctx context.Context, user *User, userMessage *map[string]any) bool { token, ok := (*userMessage)["token"].(string) if !ok { var msg = map[string]any{ "from": "server", "error": "no token in message", } sendMessageCloseIfTimeout(user, &msg) return false } userId, err := TokenValidateGetId(token) if err != nil { var msg = map[string]any{ "from": "server", "error": "invalid token", } sendMessageCloseIfTimeout(user, &msg) return false } userFromCache, err := CacheGetUserById(userId) if err != nil { dbUser := &User{Id: userId} err = DbUserSetByIdWithoutGroups(ctx, dbUser) if err != nil { var msg = map[string]any{ "from": "server", "error": "invalid user data", } sendMessageCloseIfTimeout(user, &msg) return false } err = DbUserSetGroups(ctx, dbUser) if err != nil { var msg = map[string]any{ "from": "server", "error": "invalid user data", } sendMessageCloseIfTimeout(user, &msg) return false } dbUser.WsConn = user.WsConn CacheSaveUser(dbUser) userFromCache = dbUser } userFromCache.WsConn = user.WsConn *user = *userFromCache for groupId, _ := range userFromCache.Groups { _, err = CacheGetGroup(groupId) if err != nil { dbGroup := &Group{Id: groupId} err = DbGroupSetById(ctx, dbGroup) if err != nil { var msg = map[string]any{ "from": "server", "error": "invalid user data", } sendMessageCloseIfTimeout(user, &msg) return false } CacheSaveGroup(dbGroup) } } return true } func closeConnection(user *User, ignoreCache bool) { if !ignoreCache { CacheDeleteUser(user.Id) } user.WsConn.CloseNow() }