diff --git a/database.go b/database.go index cca7171..d99131d 100644 --- a/database.go +++ b/database.go @@ -19,11 +19,26 @@ func InitDatabase(ctx context.Context) { _, err = conn.Exec(ctx, ` CREATE TABLE IF NOT EXISTS users ( - Id SERIAL PRIMARY KEY, - Name VARCHAR(20) UNIQUE NOT NULL, - PassHash VARCHAR(60) NOT NULL, - Color VARCHAR(3) NOT NULL - ) + id SERIAL PRIMARY KEY, + name VARCHAR(20) UNIQUE NOT NULL, + pass_hash VARCHAR(60) NOT NULL, + color VARCHAR(3) NOT NULL + ); + CREATE TABLE IF NOT EXISTS chat_groups ( + id SERIAL PRIMARY KEY, + name VARCHAR(48) NOT NULL, + createor INTEGER NOT NULL REFERANCE users(id) ON DELETE SET NULL, + owner INTEGER NOT NULL REFERANCE user(id) ON DELETE SET NULL + enable_user_colors BOOLEAN NOT NULL DEFAULT true, + group_color VARCHAR(3), + created_at TIMESTAMP NOT NULL DEFAULT NOW() + ); + CREATE TABLE IF NOT EXISTS chat_group_members ( + group_id INTEGER NOT NULL REFERENCES chat_groups(id) ON DELETE CASCADE, + user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE, + joined_at TIMESTAMP NOT NULL DEFAULT NOW(), + PRIMARY KEY (group_id, user_id) + ); `) if err != nil { panic(err) @@ -54,16 +69,16 @@ func AddNewUser(ctx context.Context, user User) (uint32, error) { return 0, errors.New("color invalid") } err = dbConnection.QueryRow(ctx, ` - INSERT INTO users (Name, PassHash, Color) + INSERT INTO users (name, pass_hash, color) VALUES ($1, $2, $3) - RETURNING Id + RETURNING id `, user.Name, user.Password, user.Color).Scan(&id) return id, err } func isPassValid(ctx context.Context, id uint32, plainPassword string) bool { var controlHash string - err := dbConnection.QueryRow(ctx, "SELECT PassHash FROM users WHERE Id = $1", id).Scan(&controlHash) + err := dbConnection.QueryRow(ctx, "SELECT pass_hash FROM users WHERE id = $1", id).Scan(&controlHash) if err != nil { return false } @@ -73,7 +88,7 @@ func isPassValid(ctx context.Context, id uint32, plainPassword string) bool { func GetUserDataById(ctx context.Context, id uint32) (*User, error) { var user User - err := dbConnection.QueryRow(ctx, "SELECT Id, Name, PassHash, Color FROM users WHERE Id = $1", id). + err := dbConnection.QueryRow(ctx, "SELECT id, name, pass_hash, color FROM users WHERE id = $1", id). Scan(&user.Id, &user.Name, &user.Password, &user.Color) if err != nil { return &User{}, err @@ -83,7 +98,7 @@ func GetUserDataById(ctx context.Context, id uint32) (*User, error) { } func GetUserDataByName(ctx context.Context, name string) (*User, error) { var user User - err := dbConnection.QueryRow(ctx, "SELECT Id, Name, PassHash, Color FROM users WHERE Name = $1", name). + err := dbConnection.QueryRow(ctx, "SELECT id, name, pass_hash, color FROM users WHERE name = $1", name). Scan(&user.Id, &user.Name, &user.Password, &user.Color) if err != nil { return &User{}, err @@ -91,3 +106,7 @@ func GetUserDataByName(ctx context.Context, name string) (*User, error) { user.IsPasswordHashed = true return &user, nil } + +func CreateGroup() { + +} diff --git a/go-socket b/go-socket index ac05f9e..9cffc85 100755 Binary files a/go-socket and b/go-socket differ diff --git a/main.go b/main.go index 5167c24..5568973 100644 --- a/main.go +++ b/main.go @@ -4,32 +4,23 @@ import ( "context" "log" "net/http" - - "github.com/coder/websocket" ) func main() { InitDatabase(context.Background()) srv := &wsServer{ - OnOpen: func(ctx context.Context, conn *websocket.Conn) { + OnOpen: func(c *Client) { log.Println("client connected") - if getConnectionDataIfAuth(conn) != nil { - mu.Lock() - unauthenticatedConnections = append(unauthenticatedConnections, conn) - mu.Unlock() - } }, - OnClose: func(ctx context.Context, conn *websocket.Conn, err error) { + OnClose: func(c *Client, err error) { log.Println("client disconnected:", err) - removeConnectionCache(conn) }, - OnMessage: func(ctx context.Context, conn *websocket.Conn, msg map[string]any) { + OnMessage: func(c *Client, msg map[string]any) { log.Printf("received: %v\n", msg) - authConnOrNil := getConnectionDataIfAuth(conn) - if authConnOrNil == nil { - handleUnauthenticatedMessage(conn, msg) + if c.User == nil { + handleUnauthenticatedMessage(c, msg) } else { - handleAuthenticatedMessage(conn, msg) + handleAuthenticatedMessage(c, msg) } }, } diff --git a/structures.go b/structures.go index 30d5b92..c38751e 100644 --- a/structures.go +++ b/structures.go @@ -9,8 +9,10 @@ type User struct { Color string IsPasswordHashed bool } - -type AuthConnection struct { - connection *websocket.Conn - user User +type Client struct { + conn *websocket.Conn + User *User +} + +type ChatGroup struct { } diff --git a/wsServer.go b/wsServer.go index 610bc06..225ae9b 100644 --- a/wsServer.go +++ b/wsServer.go @@ -12,19 +12,18 @@ import ( ) type wsServer struct { - OnOpen func(ctx context.Context, conn *websocket.Conn) - OnClose func(ctx context.Context, conn *websocket.Conn, err error) - OnMessage func(ctx context.Context, conn *websocket.Conn, msg map[string]any) + OnOpen func(c *Client) + OnClose func(c *Client, err error) + OnMessage func(c *Client, msg map[string]any) } var ( - unauthenticatedConnections []*websocket.Conn - authenticatedConnections []AuthConnection - mu sync.Mutex + clients []*Client + mu sync.Mutex ) -func (s *wsServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { - conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{ +func (s *wsServer) ServeHTTP(responseWriter http.ResponseWriter, request *http.Request) { + conn, err := websocket.Accept(responseWriter, request, &websocket.AcceptOptions{ InsecureSkipVerify: true, }) if err != nil { @@ -33,11 +32,16 @@ func (s *wsServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { } defer conn.CloseNow() + client := &Client{conn: conn} + mu.Lock() + clients = append(clients, client) + mu.Unlock() + ctx, cancel := context.WithCancel(context.Background()) defer cancel() if s.OnOpen != nil { - s.OnOpen(ctx, conn) + s.OnOpen(client) } var readErr error @@ -47,52 +51,29 @@ func (s *wsServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { break } if s.OnMessage != nil { - s.OnMessage(ctx, conn, msg) + s.OnMessage(client, msg) } } cancel() // cancel before OnClose so any in-flight queries are canceled first if s.OnClose != nil { - s.OnClose(ctx, conn, readErr) + s.OnClose(client, readErr) } + mu.Lock() + for i, c := range clients { + if c == client { + clients[i] = clients[len(clients)-1] + clients = clients[:len(clients)-1] + break + } + } + mu.Unlock() + conn.Close(websocket.StatusNormalClosure, "done") } -func removeConnectionCache(conn *websocket.Conn) { - mu.Lock() - defer mu.Unlock() - if getConnectionDataIfAuth(conn) != nil { - for i, c := range authenticatedConnections { - if c.connection == conn { - authenticatedConnections[i] = authenticatedConnections[len(authenticatedConnections)-1] - authenticatedConnections = authenticatedConnections[:len(authenticatedConnections)-1] - return - } - } - } else { - for i, c := range unauthenticatedConnections { - if c == conn { - unauthenticatedConnections[i] = unauthenticatedConnections[len(unauthenticatedConnections)-1] - unauthenticatedConnections = unauthenticatedConnections[:len(unauthenticatedConnections)-1] - return - } - } - } -} - -func getConnectionDataIfAuth(conn *websocket.Conn) *AuthConnection { - mu.Lock() - defer mu.Unlock() - for _, c := range authenticatedConnections { - if c.connection == conn { - return &c - } - } - return nil -} - func sendAndCloseIfFails(conn *websocket.Conn, message map[string]any) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() @@ -101,54 +82,39 @@ func sendAndCloseIfFails(conn *websocket.Conn, message map[string]any) { } } -func sendToAllExceptAndCloseIfFails(conn *websocket.Conn, message map[string]any) { - _, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - for _, aConn := range authenticatedConnections { - if aConn.connection != conn { - sendAndCloseIfFails(aConn.connection, message) +func sendToAllExceptAndCloseIfFails(client *Client, message map[string]any) { + for _, other := range clients { + if other != client { + sendAndCloseIfFails(other.conn, message) } } } -func handleUnauthenticatedMessage(conn *websocket.Conn, msg map[string]any) { +func handleUnauthenticatedMessage(client *Client, msg map[string]any) { token := msg["token"].(string) user, err := GetUserFromToken(token) if err != nil { - log.Println("invalid or expired token:", err) - err := conn.Close(websocket.StatusPolicyViolation, "invalid token") - if err != nil { - return - } + client.conn.Close(websocket.StatusPolicyViolation, "invalid token") return } - mu.Lock() - authenticatedConnections = append(authenticatedConnections, AuthConnection{connection: conn, user: user}) - mu.Unlock() - sendAndCloseIfFails(conn, map[string]any{ + client.User = &user + sendAndCloseIfFails(client.conn, map[string]any{ "authAs": user.Name, }) + log.Println("New User authenticated as: " + user.Name) } -func handleAuthenticatedMessage(conn *websocket.Conn, msg map[string]any) { +func handleAuthenticatedMessage(client *Client, msg map[string]any) { message := msg["message"].(string) if message == "" { - sendAndCloseIfFails(conn, map[string]any{ + sendAndCloseIfFails(client.conn, map[string]any{ "error": "no message", }) return } - auth := getConnectionDataIfAuth(conn) - if auth == nil { - sendAndCloseIfFails(conn, map[string]any{ - "error": "no auth", - }) - return - } - - sendToAllExceptAndCloseIfFails(conn, map[string]any{ - "username": auth.user.Name, + sendToAllExceptAndCloseIfFails(client, map[string]any{ + "username": client.User.Name, "message": message, }) }