diff --git a/Enums/ConnectionState/ConnectionState.go b/Enums/ConnectionState/ConnectionState.go new file mode 100644 index 0000000..f0bee24 --- /dev/null +++ b/Enums/ConnectionState/ConnectionState.go @@ -0,0 +1,10 @@ +package ConnectionState + +type ConnectionState uint8 + +const ( + Stranger ConnectionState = iota + GroupFellow + Friend + GroupFriend +) diff --git a/Enums/ConnectionType/ConnectionType.go b/Enums/ConnectionType/ConnectionType.go deleted file mode 100644 index ba5fceb..0000000 --- a/Enums/ConnectionType/ConnectionType.go +++ /dev/null @@ -1,10 +0,0 @@ -package ConnectionType - -type ConnectionType uint8 - -const ( - Stranger ConnectionType = iota - GroupFellow - Friend - GroupFriend -) diff --git a/database.go b/database.go index 8563c87..ab59847 100644 --- a/database.go +++ b/database.go @@ -42,8 +42,8 @@ func DbInit(ctx context.Context) { id UUID PRIMARY KEY DEFAULT gen_random_uuid(), requestor_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE, recipient_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE, - state TINYINT NOT NULL DEFAULT 0 - created_at TIMESTAMP NOT NULL DEFAULT NOW(), + state SMALLINT NOT NULL DEFAULT 0, + created_at TIMESTAMP NOT NULL DEFAULT NOW() ) `) if err != nil { @@ -173,10 +173,10 @@ func DbConnectionSave(ctx context.Context, conn *Connection) error { `, conn.RequestorId, conn.RecipientId, conn.State, conn.CreatedAt).Scan(&conn.Id) } -func DbConnectionGetBelongingToUser(ctx context.Context, user *User) error { +func DbConnectionsGetBelongingToUser(ctx context.Context, user *User) error { rows, err := dbConn.Query(ctx, ` SELECT id, requestor_id, recipient_id, state, created_at - FROM connections + FROM user_connections WHERE requestor_id = $1 OR recipient_id = $1 `, user.Id) if err != nil { @@ -195,10 +195,16 @@ func DbConnectionGetBelongingToUser(ctx context.Context, user *User) error { } user.Connections[conn.Id] = conn } - return rows.Err() } +func DbConnectionSet(ctx context.Context, conn *Connection) error { + _, err := dbConn.Exec(ctx, ` + UPDATE user_connections SET state = $1 WHERE id = $2 + `, conn.State, conn.Id) + return err +} + func DbGroupSave(ctx context.Context, group *Group) error { err := dbConn.QueryRow(ctx, ` INSERT INTO chat_groups (name, creator_id, owner_id, enable_client_colors, color_red, color_green, color_blue, created_at) diff --git a/http.go b/http.go index 0521aaa..f37104b 100644 --- a/http.go +++ b/http.go @@ -23,12 +23,7 @@ func isMethodAllowed(response *http.ResponseWriter, request *http.Request) bool return true } -func getUser(ctx context.Context, token string) (*User, error) { - userId, err := TokenValidateGetId(token) - if err != nil { - return nil, err - } - +func getUserById(ctx context.Context, userId uint32) (*User, error) { user, err := CacheGetUserById(userId) if err != nil { user = &User{Id: userId} @@ -38,7 +33,7 @@ func getUser(ctx context.Context, token string) (*User, error) { if err = DbUserGetGroups(ctx, user); err != nil { return nil, err } - if err = DbUserGetConnections(ctx, user); err != nil { + if err = DbConnectionsGetBelongingToUser(ctx, user); err != nil { return nil, err } CacheSaveUser(user) @@ -47,6 +42,14 @@ func getUser(ctx context.Context, token string) (*User, error) { return user, nil } +func getUserByToken(ctx context.Context, token string) (*User, error) { + userId, err := TokenValidateGetId(token) + if err != nil { + return nil, err + } + return getUserById(ctx, userId) +} + func getGroup(ctx context.Context, groupId uint32) (*Group, error) { group, err := CacheGetGroup(groupId) if err != nil { @@ -70,7 +73,7 @@ func isOwner(user *User, group *Group) bool { } func getIfOwnerUserAndGroup(ctx context.Context, response *http.ResponseWriter, request *http.Request) (*User, *Group, error) { - user, err := getUser(ctx, request.FormValue("token")) + user, err := getUserByToken(ctx, request.FormValue("token")) if err != nil { http.Error(*response, "invalid token", http.StatusUnauthorized) return nil, nil, err @@ -143,6 +146,9 @@ func HttpHandleUserNew(response http.ResponseWriter, request *http.Request) { } func HttpHandleUserDelete(response http.ResponseWriter, request *http.Request) { + if !isMethodAllowed(&response, request) { + return + } ctx := request.Context() userId, err := TokenValidateGetId(request.FormValue("token")) @@ -163,8 +169,12 @@ func HttpHandleUserDelete(response http.ResponseWriter, request *http.Request) { // HttpHandleUserModifyAppearance currently just color func HttpHandleUserModifyAppearance(response http.ResponseWriter, request *http.Request) { + if !isMethodAllowed(&response, request) { + return + } + ctx := request.Context() - user, err := getUser(ctx, request.FormValue("token")) + user, err := getUserByToken(ctx, request.FormValue("token")) if err != nil { http.Error(response, "invalid token", http.StatusUnauthorized) return @@ -186,8 +196,12 @@ func HttpHandleUserModifyAppearance(response http.ResponseWriter, request *http. // HttpHandleUserModifyAbout currently just pronouns func HttpHandleUserModifyAbout(response http.ResponseWriter, request *http.Request) { + if !isMethodAllowed(&response, request) { + return + } + ctx := request.Context() - user, err := getUser(ctx, request.FormValue("token")) + user, err := getUserByToken(ctx, request.FormValue("token")) if err != nil { http.Error(response, "invalid token", http.StatusUnauthorized) return @@ -209,47 +223,17 @@ func HttpHandleUserModifyAbout(response http.ResponseWriter, request *http.Reque } func HttpHandleUserMessage(response http.ResponseWriter, request *http.Request) { + if !isMethodAllowed(&response, request) { + return + } + ctx := request.Context() - user, err := getUser(ctx, request.FormValue("token")) + user, err := getUserByToken(ctx, request.FormValue("token")) if err != nil { http.Error(response, "invalid token", http.StatusUnauthorized) return } - targetId, err := ConvertStringUint32(request.FormValue("recipientid")) - if err != nil { - http.Error(response, "invalid recipient id", http.StatusBadRequest) - return - } - - target, err := CacheGetUserById(targetId) - if err != nil { - target = &User{Id: targetId} - err = DbUserGetById(ctx, target) - if err != nil { - http.Error(response, "invalid recipient id", http.StatusBadRequest) - return - } - err = DbUserGetConnections(ctx, target) - if err != nil { - http.Error(response, "invalid recipient id", http.StatusBadRequest) - return - } - } - - if user.Connections[target.Id] == nil || !user.Connections[targetId].IsAccepted { - http.Error(response, "invalid recipient id", http.StatusBadRequest) - return - } - - message := request.FormValue("message") - if message == "" { - http.Error(response, "empty message", http.StatusBadRequest) - return - } - - WsSendToUser(user, target, message) - response.WriteHeader(http.StatusAccepted) } func HttpHandleUserNewConnection(response http.ResponseWriter, request *http.Request) { @@ -258,184 +242,49 @@ func HttpHandleUserNewConnection(response http.ResponseWriter, request *http.Req } ctx := request.Context() - - user, err := getUser(ctx, request.FormValue("token")) + user, err := getUserByToken(ctx, request.FormValue("token")) if err != nil { http.Error(response, "invalid token", http.StatusUnauthorized) return } - - targetId, err := ConvertStringUint32(request.FormValue("recipientid")) - if err != nil { - http.Error(response, "invalid recipient id", http.StatusBadRequest) - return - } - - if user.Id == targetId { - http.Error(response, "invalid recipient id", http.StatusBadRequest) - return - } - - target, err := CacheGetUserById(targetId) - if err != nil { - target = &User{Id: targetId} - err = DbUserGetById(ctx, target) - if err != nil { - http.Error(response, "invalid recipient id", http.StatusBadRequest) - return - } - err = DbUserGetConnections(ctx, target) - if err != nil { - http.Error(response, "invalid recipient id", http.StatusBadRequest) - return - } - } - if user.Connections[target.Id] != nil { - http.Error(response, "already sent/connected", http.StatusConflict) - return - } - - timeNow := time.Now() - - err = DbConnectionSave(ctx, timeNow, user.Id, targetId, false) - if err != nil { - http.Error(response, "internal server error", http.StatusInternalServerError) - return - } - - user.Connections[target.Id] = &Connection{ - CreatedAt: timeNow, - With: targetId, - IsFromUser: true, - IsAccepted: false, - } - if target.Connections == nil { - target.Connections = make(map[uint32]*Connection) - } - target.Connections[user.Id] = &Connection{ - CreatedAt: timeNow, - With: user.Id, - IsFromUser: false, - IsAccepted: false, - } - - response.WriteHeader(http.StatusCreated) } func HttpHandleUserDeleteConnection(response http.ResponseWriter, request *http.Request) { + if !isMethodAllowed(&response, request) { + return + } ctx := request.Context() - user, err := getUser(ctx, request.FormValue("token")) + user, err := getUserByToken(ctx, request.FormValue("token")) if err != nil { http.Error(response, "invalid token", http.StatusUnauthorized) return } - - targetId, err := ConvertStringUint32(request.FormValue("connectedid")) - if err != nil { - http.Error(response, "invalid recipient id", http.StatusBadRequest) - return - } - - target, err := CacheGetUserById(targetId) - if err != nil { - target = &User{Id: targetId} - err = DbUserGetById(ctx, target) - if err != nil { - http.Error(response, "invalid recipient id", http.StatusBadRequest) - return - } - err = DbUserGetConnections(ctx, target) - if err != nil { - http.Error(response, "invalid recipient id", http.StatusBadRequest) - return - } - } - if user.Connections[targetId] == nil { - http.Error(response, "invalid recipient id", http.StatusBadRequest) - return - } - - if user.Connections[targetId].IsFromUser { - err = DbConnectionDelete(ctx, user.Id, targetId) - } else { - err = DbConnectionDelete(ctx, targetId, user.Id) - } - if err != nil { - http.Error(response, "internal server error", http.StatusInternalServerError) - return - } - - delete(user.Connections, targetId) - delete(target.Connections, user.Id) - - response.WriteHeader(http.StatusAccepted) } func HttpHandleUserAcceptConnection(response http.ResponseWriter, request *http.Request) { + if !isMethodAllowed(&response, request) { + return + } ctx := request.Context() - user, err := getUser(ctx, request.FormValue("token")) + user, err := getUserByToken(ctx, request.FormValue("token")) if err != nil { http.Error(response, "invalid token", http.StatusUnauthorized) return } - - targetId, err := ConvertStringUint32(request.FormValue("connectedid")) - if err != nil || user.Connections[targetId] == nil { - http.Error(response, "invalid recipient id", http.StatusBadRequest) - return - } - - target, err := CacheGetUserById(targetId) - if err != nil { - target = &User{Id: targetId} - err = DbUserGetById(ctx, target) - if err != nil { - http.Error(response, "invalid recipient id", http.StatusBadRequest) - return - } - err = DbUserGetConnections(ctx, target) - if err != nil { - http.Error(response, "invalid recipient id", http.StatusBadRequest) - return - } - } - if target.Connections[user.Id] == nil { - http.Error(response, "invalid recipient id", http.StatusBadRequest) - return - } - if user.Connections[targetId].IsFromUser { - http.Error(response, "cant accept own request", http.StatusConflict) - return - } - - user.Connections[targetId].IsAccepted = true - target.Connections[user.Id].IsAccepted = true - - err = DbConnectionAccept(ctx, targetId, user.Id) - if err != nil { - http.Error(response, "internal server error", http.StatusInternalServerError) - return - } - response.WriteHeader(http.StatusAccepted) } func HttpHandleUserGetConnections(response http.ResponseWriter, request *http.Request) { + if !isMethodAllowed(&response, request) { + return + } ctx := request.Context() - user, err := getUser(ctx, request.FormValue("token")) + user, err := getUserByToken(ctx, request.FormValue("token")) if err != nil { http.Error(response, "invalid token", http.StatusUnauthorized) return } - - json, err := json2.Marshal(user.Connections) - if err != nil { - http.Error(response, "internal server error", http.StatusInternalServerError) - return - } - response.WriteHeader(http.StatusAccepted) - response.Write(json) } func HttpHandleTokenNew(response http.ResponseWriter, request *http.Request) { @@ -505,7 +354,7 @@ func HttpHandeGroupCreate(response http.ResponseWriter, request *http.Request) { ctx := request.Context() - user, err := getUser(ctx, request.FormValue("token")) + user, err := getUserByToken(ctx, request.FormValue("token")) if err != nil { http.Error(response, "invalid token", http.StatusUnauthorized) return @@ -746,7 +595,7 @@ func HttpHandleGroupMessage(response http.ResponseWriter, request *http.Request) ctx := request.Context() - user, err := getUser(ctx, request.FormValue("token")) + user, err := getUserByToken(ctx, request.FormValue("token")) if err != nil { http.Error(response, "invalid token", http.StatusBadRequest) return @@ -791,7 +640,7 @@ func HttpHandleGroupsGetWithoutMembers(response http.ResponseWriter, request *ht ctx := request.Context() - user, err := getUser(ctx, request.FormValue("token")) + user, err := getUserByToken(ctx, request.FormValue("token")) if err != nil { http.Error(response, "invalid token", http.StatusUnauthorized) return @@ -822,7 +671,7 @@ func HttpHandleGroupMembersGet(response http.ResponseWriter, request *http.Reque } ctx := request.Context() - user, err := getUser(ctx, request.FormValue("token")) + user, err := getUserByToken(ctx, request.FormValue("token")) if err != nil { http.Error(response, "invalid token", http.StatusUnauthorized) return diff --git a/structs.go b/structs.go index a4722b8..0e72d7c 100644 --- a/structs.go +++ b/structs.go @@ -1,6 +1,7 @@ package main import ( + "go-socket/Enums/ConnectionState" "time" "github.com/coder/websocket" @@ -25,7 +26,7 @@ type Connection struct { MessagesBuf [MaxDirectMsgCache]*Message `json:"-"` RequestorId uint32 `json:"requestorId"` RecipientId uint32 `json:"recipientId"` - State uint8 `json:"state"` + State ConnectionState.ConnectionState `json:"state"` } type Message struct {