diff --git a/main.go b/main.go index e0443a0..3db5556 100644 --- a/main.go +++ b/main.go @@ -32,7 +32,9 @@ func main() { http.HandleFunc("/new/connection", withCORS(httpRequest.HandleUserNewConnection)) http.HandleFunc("/new/token", withCORS(httpRequest.HandleUserNewToken)) http.HandleFunc("/new/file", withCORS(httpRequest.HandleAttachmentFileUpload)) - http.HandleFunc("/mod/user/appearence", withCORS(httpRequest.HandleUserModifyAppearance)) + http.HandleFunc("/mod/user/appearence", withCORS(httpRequest.HandleUserModProfile)) + http.HandleFunc("/mod/user/avatar", withCORS(httpRequest.HandleUserModAvatar)) + http.HandleFunc("/mod/user/profilebg", withCORS(httpRequest.HandleUserModProfileBg)) http.HandleFunc("/mod/user/about", withCORS(httpRequest.HandleUserModProfile)) http.HandleFunc("/mod/connection/accept", withCORS(httpRequest.HandleUserElevateConnection)) diff --git a/packages/globals/globals.go b/packages/globals/globals.go index 5a9d428..a0e1d92 100644 --- a/packages/globals/globals.go +++ b/packages/globals/globals.go @@ -7,6 +7,8 @@ const ( FileStorageBucketName string = "communicator" MaxPostBytes uint32 = 4 << 10 MaxPostWithFileBytes uint32 = 1 << 30 + MaxPostWithAvatar uint = 1 << 20 + MaxPostWithProfileBg uint = 4 << 20 FileProcessingPartSize uint64 = 12 << 20 FileProcessingThreads uint = 3 FileDownloadLinkTtl time.Duration = 24 * time.Hour diff --git a/packages/httpRequest/attachmentFile.go b/packages/httpRequest/attachmentFile.go index 769344a..73b1639 100644 --- a/packages/httpRequest/attachmentFile.go +++ b/packages/httpRequest/attachmentFile.go @@ -10,7 +10,7 @@ import ( ) func HandleAttachmentFileUpload(response http.ResponseWriter, request *http.Request) { - if !postValidCheckWithResponseOnFail(&response, request, true) { + if !postValidCheckWithResponseOnFail(&response, request, postFile) { return } ctx := request.Context() @@ -56,7 +56,7 @@ func HandleAttachmentFileUpload(response http.ResponseWriter, request *http.Requ } func HandleAttachmentFileDownload(response http.ResponseWriter, request *http.Request) { - if !postValidCheckWithResponseOnFail(&response, request, false) { + if !postValidCheckWithResponseOnFail(&response, request, postNormal) { return } ctx := request.Context() diff --git a/packages/httpRequest/connectionsAndDms.go b/packages/httpRequest/connectionsAndDms.go index 641c7a2..9f02bc3 100644 --- a/packages/httpRequest/connectionsAndDms.go +++ b/packages/httpRequest/connectionsAndDms.go @@ -21,7 +21,7 @@ import ( ) func HandleDm(response http.ResponseWriter, request *http.Request) { - if !postValidCheckWithResponseOnFail(&response, request, false) { + if !postValidCheckWithResponseOnFail(&response, request, postNormal) { return } @@ -89,7 +89,7 @@ func HandleDm(response http.ResponseWriter, request *http.Request) { } func HandleUserGetConnectionMessages(response http.ResponseWriter, request *http.Request) { - if !postValidCheckWithResponseOnFail(&response, request, false) { + if !postValidCheckWithResponseOnFail(&response, request, postNormal) { return } ctx := request.Context() @@ -158,7 +158,7 @@ func HandleUserGetConnectionMessages(response http.ResponseWriter, request *http } func HandleUserNewConnection(response http.ResponseWriter, request *http.Request) { - if !postValidCheckWithResponseOnFail(&response, request, false) { + if !postValidCheckWithResponseOnFail(&response, request, postNormal) { return } ctx := request.Context() @@ -217,7 +217,7 @@ func HandleUserNewConnection(response http.ResponseWriter, request *http.Request } func HandleUserDeleteConnection(response http.ResponseWriter, request *http.Request) { - if !postValidCheckWithResponseOnFail(&response, request, false) { + if !postValidCheckWithResponseOnFail(&response, request, postNormal) { return } ctx := request.Context() @@ -269,7 +269,7 @@ func HandleUserDeleteConnection(response http.ResponseWriter, request *http.Requ } func HandleUserElevateConnection(response http.ResponseWriter, request *http.Request) { - if !postValidCheckWithResponseOnFail(&response, request, false) { + if !postValidCheckWithResponseOnFail(&response, request, postNormal) { return } ctx := request.Context() @@ -329,7 +329,7 @@ func HandleUserElevateConnection(response http.ResponseWriter, request *http.Req } func HandleUserGetConnections(response http.ResponseWriter, request *http.Request) { - if !postValidCheckWithResponseOnFail(&response, request, false) { + if !postValidCheckWithResponseOnFail(&response, request, postNormal) { return } ctx := request.Context() diff --git a/packages/httpRequest/helper.go b/packages/httpRequest/helper.go index d1fe383..063aa00 100644 --- a/packages/httpRequest/helper.go +++ b/packages/httpRequest/helper.go @@ -1,21 +1,63 @@ package httpRequest import ( + "io" "net/http" + "strings" "go-socket/packages/globals" ) -func postValidCheckWithResponseOnFail(response *http.ResponseWriter, request *http.Request, withFile bool) bool { +type postType uint8 + +const ( + postNormal postType = iota + postFile + postAvatar + postProfileBg +) + +func postValidCheckWithResponseOnFail(response *http.ResponseWriter, request *http.Request, pt postType) bool { if request.Method != http.MethodPost { http.Error(*response, "POST only", http.StatusMethodNotAllowed) return false } - if withFile && request.ContentLength > int64(globals.MaxPostWithFileBytes) || - !withFile && request.ContentLength > int64(globals.MaxPostBytes) { + + var maxSize int64 + switch pt { + case postFile: + maxSize = int64(globals.MaxPostWithFileBytes) + case postAvatar: + maxSize = int64(globals.MaxPostWithAvatar) + case postProfileBg: + maxSize = int64(globals.MaxPostWithProfileBg) + default: + maxSize = int64(globals.MaxPostBytes) + } + + if request.ContentLength > maxSize { http.Error(*response, "Request too large", http.StatusRequestEntityTooLarge) return false } return true } + +func isImage(r io.Reader) (bool, string, error) { + buf := make([]byte, 512) + n, err := io.ReadFull(r, buf) + if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF { + return false, "", err + } + + contentType := http.DetectContentType(buf[:n]) + isImage := strings.HasPrefix(contentType, "image/") && contentType != "image/svg+xml" + + if seeker, ok := r.(io.Seeker); ok { + if _, err := seeker.Seek(0, io.SeekStart); err != nil { + return isImage, contentType, err + } + } + + return isImage, contentType, nil +} diff --git a/packages/httpRequest/user.go b/packages/httpRequest/user.go index 914be8d..8938f33 100644 --- a/packages/httpRequest/user.go +++ b/packages/httpRequest/user.go @@ -2,11 +2,12 @@ package httpRequest import ( json2 "encoding/json" + "net/http" + "time" + "go-socket/packages/convertions" "go-socket/packages/globals" "go-socket/packages/minio" - "net/http" - "time" "go-socket/packages/cache" "go-socket/packages/passwords" @@ -18,7 +19,7 @@ import ( ) func HandleUserNewToken(response http.ResponseWriter, request *http.Request) { - if !postValidCheckWithResponseOnFail(&response, request, false) { + if !postValidCheckWithResponseOnFail(&response, request, postNormal) { return } @@ -77,7 +78,7 @@ func HandleUserNewToken(response http.ResponseWriter, request *http.Request) { } func HandleUserNew(response http.ResponseWriter, request *http.Request) { - if !postValidCheckWithResponseOnFail(&response, request, false) { + if !postValidCheckWithResponseOnFail(&response, request, postNormal) { return } @@ -118,7 +119,7 @@ func HandleUserNew(response http.ResponseWriter, request *http.Request) { } func HandleUserDelete(response http.ResponseWriter, request *http.Request) { - if !postValidCheckWithResponseOnFail(&response, request, false) { + if !postValidCheckWithResponseOnFail(&response, request, postNormal) { return } ctx := request.Context() @@ -140,7 +141,7 @@ func HandleUserDelete(response http.ResponseWriter, request *http.Request) { } func HandleUserModProfile(response http.ResponseWriter, request *http.Request) { - if !postValidCheckWithResponseOnFail(&response, request, false) { + if !postValidCheckWithResponseOnFail(&response, request, postNormal) { return } @@ -190,8 +191,7 @@ func HandleUserModProfile(response http.ResponseWriter, request *http.Request) { } func HandleUserModAvatar(response http.ResponseWriter, request *http.Request) { - - if !postValidCheckWithResponseOnFail(&response, request, true) { + if !postValidCheckWithResponseOnFail(&response, request, postAvatar) { return } ctx := request.Context() @@ -202,7 +202,7 @@ func HandleUserModAvatar(response http.ResponseWriter, request *http.Request) { return } - request.Body = http.MaxBytesReader(response, request.Body, int64(globals.MaxPostWithFileBytes)) + request.Body = http.MaxBytesReader(response, request.Body, int64(globals.MaxPostWithAvatar)) if err = request.ParseMultipartForm(int64(globals.MaxPostBytes)); err != nil { http.Error(response, "invalid multipart form", http.StatusBadRequest) @@ -221,9 +221,21 @@ func HandleUserModAvatar(response http.ResponseWriter, request *http.Request) { } defer file.Close() - contentType := header.Header.Get("Content-Type") - key := minio.GetKey(conn.Id, contentType, minio.File) + isImg, contentType, err := isImage(file) + if err != nil || !isImg { + http.Error(response, "invalid file", http.StatusBadRequest) + return + } + if user.Avatar != "" { + err = minio.Delete(ctx, string(minio.UserAvatarPrefix)+user.Avatar) + if err != nil { + http.Error(response, "internal server error", http.StatusInternalServerError) + return + } + } + + key := minio.GetKey(conn.Id, contentType, minio.File) if err = minio.Upload(ctx, key, file, header.Size, contentType, map[string]string{ "originalName": header.Filename, "uploaderId": user.Id.String(), @@ -231,4 +243,77 @@ func HandleUserModAvatar(response http.ResponseWriter, request *http.Request) { http.Error(response, "upload failed", http.StatusInternalServerError) return } + + user.Avatar = key[len(minio.UserAvatarPrefix):] + err = postgresql.UserUpdateProfile(ctx, user, types.UserProfileUpdateList{Avatar: true}) + if err != nil { + http.Error(response, "internal server error", http.StatusInternalServerError) + return + } + + response.WriteHeader(http.StatusAccepted) +} + +func HandleUserModProfileBg(response http.ResponseWriter, request *http.Request) { + if !postValidCheckWithResponseOnFail(&response, request, postProfileBg) { + return + } + ctx := request.Context() + + user, err := getUserByToken(ctx, request.Header.Get("token")) + if err != nil { + http.Error(response, "invalid token", http.StatusUnauthorized) + return + } + + request.Body = http.MaxBytesReader(response, request.Body, int64(globals.MaxPostWithProfileBg)) + + if err = request.ParseMultipartForm(int64(globals.MaxPostBytes)); err != nil { + http.Error(response, "invalid multipart form", http.StatusBadRequest) + return + } + + conn, ok := getConnectionWithResponseOnFail(&response, request, user) + if !ok { + return + } + + file, header, err := request.FormFile("file") + if err != nil { + http.Error(response, "missing file", http.StatusBadRequest) + return + } + defer file.Close() + + isImg, contentType, err := isImage(file) + if err != nil || !isImg { + http.Error(response, "invalid file", http.StatusBadRequest) + return + } + + if user.ProfileBg != "" { + err = minio.Delete(ctx, string(minio.UserProfileBgPrefix)+user.ProfileBg) + if err != nil { + http.Error(response, "internal server error", http.StatusInternalServerError) + return + } + } + + key := minio.GetKey(conn.Id, contentType, minio.UserProfileBg) + if err = minio.Upload(ctx, key, file, header.Size, contentType, map[string]string{ + "originalName": header.Filename, + "uploaderId": user.Id.String(), + }); err != nil { + http.Error(response, "upload failed", http.StatusInternalServerError) + return + } + + user.ProfileBg = key[len(minio.UserProfileBgPrefix):] + err = postgresql.UserUpdateProfile(ctx, user, types.UserProfileUpdateList{ProfileBg: true}) + if err != nil { + http.Error(response, "internal server error", http.StatusInternalServerError) + return + } + + response.WriteHeader(http.StatusAccepted) } diff --git a/packages/minio/minio.go b/packages/minio/minio.go index 86440c7..1a44e5c 100644 --- a/packages/minio/minio.go +++ b/packages/minio/minio.go @@ -25,6 +25,14 @@ const ( UserProfileBg ) +type DataTypePrefix string + +const ( + FilePrefix DataTypePrefix = "upload/" + UserAvatarPrefix DataTypePrefix = "userAvatar/" + UserProfileBgPrefix DataTypePrefix = "userProfileBg/" +) + func GetKey(connectionId uuid.UUID, mimeType string, uploadType DataType) string { extensions, err := mime.ExtensionsByType(mimeType) if err != nil || len(extensions) == 0 { @@ -34,11 +42,11 @@ func GetKey(connectionId uuid.UUID, mimeType string, uploadType DataType) string key := connectionId.String() + "/" + strconv.FormatInt(time.Now().UnixMilli(), 10) + extensions[0] if uploadType == UserAvatar { - return "userAvatar/" + key + return string(UserAvatarPrefix) + key } else if uploadType == UserProfileBg { - return "userProfileBg/" + key + return string(UserProfileBgPrefix) + key } - return "upload/" + key + return string(FilePrefix) + key } func Init(ctx context.Context) { @@ -96,3 +104,8 @@ func DoesExist(ctx context.Context, key string) bool { _, err := minClient.StatObject(ctx, globals.FileStorageBucketName, key, minio.StatObjectOptions{}) return err == nil } + +func Delete(ctx context.Context, key string) error { + err := minClient.RemoveObject(ctx, globals.FileStorageBucketName, key, minio.RemoveObjectOptions{}) + return err +} diff --git a/packages/postgresql/postgresql.go b/packages/postgresql/postgresql.go index d9f0d36..f0b59a8 100644 --- a/packages/postgresql/postgresql.go +++ b/packages/postgresql/postgresql.go @@ -36,7 +36,7 @@ func Init(ctx context.Context) { pronouns TEXT DEFAULT NULL, description TEXT DEFAULT NULL, avatar TEXT DEFAULT NULL, - profileBg TEXT DEFAULT NULL, + profile_bg TEXT DEFAULT NULL, rgba BIGINT NOT NULL DEFAULT 0 CHECK (rgba BETWEEN 0 AND 4294967295), created_at TIMESTAMP NOT NULL DEFAULT NOW() ) @@ -93,8 +93,8 @@ func UserDelete(ctx context.Context, id uuid.UUID) error { func UserGetStandardInfoByName(ctx context.Context, user *types.User) error { var rgba int64 err := dbConn.QueryRow(ctx, ` - SELECT id, name, pass_hash, COALESCE(pronouns, ''), rgba, created_at FROM users WHERE name = $1 - `, user.Name).Scan(&user.Id, &user.Name, &user.PasswordHash, &user.Pronouns, &rgba, &user.CreatedAt) + SELECT id, name, pass_hash, COALESCE(pronouns, ''), rgba, created_at, COALESCE(avatar, ''), COALESCE(profile_bg, '') FROM users WHERE name = $1 + `, user.Name).Scan(&user.Id, &user.Name, &user.PasswordHash, &user.Pronouns, &rgba, &user.CreatedAt, &user.Avatar, &user.ProfileBg) if err == nil { user.Color = convertions.Uint32ToRgba(uint32(rgba)) } @@ -104,8 +104,8 @@ func UserGetStandardInfoByName(ctx context.Context, user *types.User) error { func UserGetById(ctx context.Context, user *types.User) error { var rgba int64 err := dbConn.QueryRow(ctx, ` - SELECT name, pass_hash, COALESCE(pronouns, ''), rgba, created_at FROM users WHERE id = $1 - `, user.Id).Scan(&user.Name, &user.PasswordHash, &user.Pronouns, &rgba, &user.CreatedAt) + SELECT name, pass_hash, COALESCE(pronouns, ''), rgba, created_at, COALESCE(avatar, ''), COALESCE(profile_bg, '') FROM users WHERE id = $1 + `, user.Id).Scan(&user.Name, &user.PasswordHash, &user.Pronouns, &rgba, &user.CreatedAt, &user.Avatar, &user.ProfileBg) if err == nil { user.Color = convertions.Uint32ToRgba(uint32(rgba)) } @@ -132,6 +132,16 @@ func UserUpdateProfile(ctx context.Context, user *types.User, updateList types.U args = append(args, convertions.RgbaToUint32(user.Color)) argIdx++ } + if updateList.Avatar { + setClauses = append(setClauses, fmt.Sprintf("avatar = $%d", argIdx)) + args = append(args, user.Avatar) + argIdx++ + } + if updateList.ProfileBg { + setClauses = append(setClauses, fmt.Sprintf("profile_bg = $%d", argIdx)) + args = append(args, user.ProfileBg) + argIdx++ + } if len(setClauses) == 0 { return nil diff --git a/packages/types/types.go b/packages/types/types.go index 0b6b913..f5a8f88 100644 --- a/packages/types/types.go +++ b/packages/types/types.go @@ -43,6 +43,8 @@ type UserProfileUpdateList struct { Pronouns bool Description bool Color bool + Avatar bool + ProfileBg bool } type Connection struct {