diff --git a/go-socket b/go-socket index 8c74214..76d0cd4 100755 Binary files a/go-socket and b/go-socket differ diff --git a/main.go b/main.go index f036b87..7b07005 100644 --- a/main.go +++ b/main.go @@ -61,6 +61,7 @@ func main() { http.HandleFunc("POST /hub", withCORS(httpRequest.HandleHubCreate)) http.HandleFunc("GET /hub", withCORS(httpRequest.GetHubData)) http.HandleFunc("POST /hub/channel/message", withCORS(httpRequest.HandleHubMessage)) + http.HandleFunc("GET /hub/channel/messages", withCORS(httpRequest.HandleHubChannelGetMessages)) http.HandleFunc("GET /hub/channel", withCORS(httpRequest.GetChannelData)) http.HandleFunc("GET /hubs", withCORS(httpRequest.HandleGetHubs)) http.HandleFunc("GET /hubs/channels", withCORS(httpRequest.HandleGetChannels)) diff --git a/packages/httpRequest/hubs.go b/packages/httpRequest/hubs.go index 9b4ad30..a7052f0 100644 --- a/packages/httpRequest/hubs.go +++ b/packages/httpRequest/hubs.go @@ -3,6 +3,7 @@ package httpRequest import ( "context" "encoding/json" + "go-socket/packages/config" "net/http" "strings" "time" @@ -317,6 +318,81 @@ func HandleHubMessage(response http.ResponseWriter, request *http.Request) { response.WriteHeader(http.StatusCreated) } +func HandleHubChannelGetMessages(response http.ResponseWriter, request *http.Request) { + if !validCheckWithResponseOnFail(response, request, normal) { + return + } + ctx := request.Context() + user, hubUser, hub, err := getHubUserIfValidWithResponseOnFail(ctx, response, request) + if err != nil { + return + } + channel, err := getHubChannelIfValidWithResponseOnFail(ctx, response, hub, hubUser, request.FormValue("channel_id")) + if err != nil { + return + } + + channel.Mu.RLock() + canReadHistory := channel.UsersCachedPermissions[user.Id].CanReadHistory() + channel.Mu.RUnlock() + if !canReadHistory { + http.Error(response, "forbidden", http.StatusForbidden) + return + } + + before, err := convertions.StringToTimestamp(request.URL.Query().Get("before")) + if err != nil { + before = time.Now() + } + + messagesCap, err := convertions.StringToUint32(request.URL.Query().Get("messages")) + if err != nil { + messagesCap = config.MaxDirectMsgCache + } + + buffer, bufferSize := channel.GetSortedMessagesBuff() + + var validBufCount uint32 + for validBufCount < bufferSize && buffer[validBufCount].CreatedAt.Before(before) { + validBufCount++ + } + + var messages []*types.Message + + if validBufCount >= messagesCap { + start := validBufCount - messagesCap + messages = make([]*types.Message, messagesCap) + for i := uint32(0); i < messagesCap; i++ { + messages[i] = buffer[start+i] + } + } else { + remaining := messagesCap - validBufCount + cutoff := before + if validBufCount > 0 { + cutoff = buffer[0].CreatedAt + } + dbMessages, err := postgresql.HubChannelGetMessagesBefore(ctx, cutoff, channel.Id, remaining) + if err != nil { + http.Error(response, "internal server error", http.StatusInternalServerError) + return + } + messages = make([]*types.Message, 0, uint32(len(dbMessages))+validBufCount) + messages = append(messages, dbMessages...) + for i := uint32(0); i < validBufCount; i++ { + messages = append(messages, buffer[i]) + } + } + + data, err := json.Marshal(messages) + if err != nil { + http.Error(response, "internal server error", http.StatusInternalServerError) + return + } + + response.WriteHeader(http.StatusOK) + response.Write(data) +} + func HandleGetChannels(response http.ResponseWriter, request *http.Request) { if !validCheckWithResponseOnFail(response, request, normal) { return diff --git a/packages/postgresql/postgresql.go b/packages/postgresql/postgresql.go index 0a15b50..481fce7 100644 --- a/packages/postgresql/postgresql.go +++ b/packages/postgresql/postgresql.go @@ -578,6 +578,35 @@ func HubChannelMessageGet(ctx context.Context, message *types.Message) error { `, message.Id).Scan(&message.Sender, &message.Receiver, &message.CreatedAt, &message.Content, &message.AttachedFile) } +func HubChannelGetMessagesBefore(ctx context.Context, before time.Time, channelId uuid.UUID, cap uint32) ([]*types.Message, error) { + rows, err := dbConn.Query(ctx, ` + SELECT id, sender_id, receiver_id, created_at, content, attached_file + FROM ( + SELECT id, sender_id, receiver_id, created_at, content, attached_file + FROM hub_channel_messages + WHERE receiver_id = $1 + AND created_at < $2 + ORDER BY created_at DESC + LIMIT $3 + ) sub + ORDER BY created_at ASC + `, channelId, before, cap) + if err != nil { + return nil, err + } + defer rows.Close() + + messages := make([]*types.Message, 0, cap) + for rows.Next() { + msg := &types.Message{} + if err = rows.Scan(&msg.Id, &msg.Sender, &msg.Receiver, &msg.CreatedAt, &msg.Content, &msg.AttachedFile); err != nil { + return nil, err + } + messages = append(messages, msg) + } + return messages, rows.Err() +} + func HubUpdate(ctx context.Context, hub *types.Hub, updateList *types.HubUpdate) error { setClauses := make([]string, 0, 6) args := make([]any, 0, 7) diff --git a/test-client/index.html b/test-client/index.html index f3b1857..b4cbf59 100644 --- a/test-client/index.html +++ b/test-client/index.html @@ -43,6 +43,7 @@ + @@ -51,6 +52,7 @@ + @@ -242,6 +244,16 @@
+ +