diff --git a/Structures.go b/Structures.go deleted file mode 100644 index b0394fd..0000000 --- a/Structures.go +++ /dev/null @@ -1,9 +0,0 @@ -package main - -type User struct { - id uint - name string - password string - color string - isPasswordHashed bool -} diff --git a/database.go b/database.go index b885aa0..5360bf0 100644 --- a/database.go +++ b/database.go @@ -2,9 +2,10 @@ package main import ( "context" - "crypto/subtle" "errors" + "golang.org/x/crypto/bcrypt" + "github.com/jackc/pgx/v5" ) @@ -18,10 +19,10 @@ func InitDatabase(ctx context.Context) { _, err = conn.Exec(ctx, ` CREATE TABLE IF NOT EXISTS users ( - id SERIAL PRIMARY KEY, - name VARCHAR(20) UNIQUE NOT NULL, - pass_hash VARCHAR(255) NOT NULL, - color VARCHAR(3) NOT NULL + Id SERIAL PRIMARY KEY, + Name VARCHAR(20) UNIQUE NOT NULL, + pass_hash VARCHAR(60) NOT NULL, + Color VARCHAR(3) NOT NULL ) `) if err != nil { @@ -31,44 +32,52 @@ func InitDatabase(ctx context.Context) { } func AddNewUser(ctx context.Context, user User) (uint, error) { - if len(user.name) == 0 || len(user.name) > 20 { + var id uint + var err error + + if len(user.Name) == 0 || len(user.Name) > 20 { return 0, errors.New("username bad length") } - if user.isPasswordHashed { - if len(user.password) != 255 { + if user.IsPasswordHashed { + if len(user.Password) != 60 { return 0, errors.New("password bad length") } } else { - // TODO + var hashed []byte + hashed, err = bcrypt.GenerateFromPassword([]byte(user.Password), bcrypt.DefaultCost) + if err != nil { + return 0, err + } + user.Password = string(hashed) } - if len(user.color) != 1 && len(user.color) != 3 { + if len(user.Color) != 1 && len(user.Color) != 3 { return 0, errors.New("color invalid") } - var id uint - err := dbConnection.QueryRow(ctx, ` - INSERT INTO users (name, pass_hash, color) + err = dbConnection.QueryRow(ctx, ` + INSERT INTO users (Name, pass_hash, Color) VALUES ($1, $2, $3) - RETURNING id - `, user.name, user.isPasswordHashed, user.color).Scan(&id) + RETURNING Id + `, user.Name, user.IsPasswordHashed, user.Color).Scan(&id) return id, err } -func CheckPassword(ctx context.Context, id string, hash string) bool { +func CheckPassword(ctx context.Context, id string, plainPassword string) bool { var controlHash string - err := dbConnection.QueryRow(ctx, "SELECT pass_hash FROM users WHERE id = $1", id).Scan(&controlHash) + err := dbConnection.QueryRow(ctx, "SELECT pass_hash FROM users WHERE Id = $1", id).Scan(&controlHash) if err != nil { return false } - return subtle.ConstantTimeCompare([]byte(controlHash), []byte(hash)) == 1 + + return bcrypt.CompareHashAndPassword([]byte(controlHash), []byte(plainPassword)) == nil } func GetUserData(ctx context.Context, id string) (User, error) { var user User - err := dbConnection.QueryRow(ctx, "SELECT id, name, pass_hash, color FROM users WHERE id = $1", id). - Scan(&user.id, &user.name, &user.password, &user.color) + err := dbConnection.QueryRow(ctx, "SELECT Id, Name, pass_hash, Color FROM users WHERE Id = $1", id). + Scan(&user.Id, &user.Name, &user.Password, &user.Color) if err != nil { return User{}, err } - user.isPasswordHashed = true + user.IsPasswordHashed = true return user, nil } diff --git a/go-socket b/go-socket index 2a68a41..45e67b3 100755 Binary files a/go-socket and b/go-socket differ diff --git a/go.mod b/go.mod index a7af6df..06a6a5f 100644 --- a/go.mod +++ b/go.mod @@ -3,17 +3,14 @@ module go-socket go 1.26 require ( - github.com/redis/go-redis/v9 v9.18.0 - nhooyr.io/websocket v1.8.17 + github.com/coder/websocket v1.8.14 + github.com/golang-jwt/jwt/v5 v5.3.1 + github.com/jackc/pgx/v5 v5.8.0 + golang.org/x/crypto v0.48.0 ) require ( - github.com/cespare/xxhash/v2 v2.3.0 // indirect - github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect - github.com/golang-jwt/jwt/v5 v5.3.1 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect - github.com/jackc/pgx/v5 v5.8.0 // indirect - go.uber.org/atomic v1.11.0 // indirect - golang.org/x/text v0.29.0 // indirect + golang.org/x/text v0.34.0 // indirect ) diff --git a/go.sum b/go.sum index 6daf9b2..73a59b1 100644 --- a/go.sum +++ b/go.sum @@ -1,14 +1,8 @@ -github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= -github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= -github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= -github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= -github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= -github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g= +github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= -github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY= github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= @@ -17,24 +11,22 @@ github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7Ulw github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= github.com/jackc/pgx/v5 v5.8.0 h1:TYPDoleBBme0xGSAX3/+NujXXtpZn9HBONkQC7IEZSo= github.com/jackc/pgx/v5 v5.8.0/go.mod h1:QVeDInX2m9VyzvNeiCJVjCkNFqzsNb43204HshNSZKw= -github.com/klauspost/cpuid/v2 v2.0.9 h1:lgaqFMSdTdQYdZ04uHyN2d/eKdOMyi2YLSvlQIBFYa4= -github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= +github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= +github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/redis/go-redis/v9 v9.18.0 h1:pMkxYPkEbMPwRdenAzUNyFNrDgHx9U+DrBabWNfSRQs= -github.com/redis/go-redis/v9 v9.18.0/go.mod h1:k3ufPphLU5YXwNTUcCRXGxUoF1fqxnhFQmscfkCoDA0= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= -github.com/zeebo/xxh3 v1.0.2 h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0= -github.com/zeebo/xxh3 v1.0.2/go.mod h1:5NWz9Sef7zIDm2JHfFlcQvNekmcEl9ekUZQQKCYaDcA= -go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= -go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= -golang.org/x/text v0.29.0 h1:1neNs90w9YzJ9BocxfsQNHKuAT4pkghyXc4nhZ6sJvk= -golang.org/x/text v0.29.0/go.mod h1:7MhJOA9CD2qZyOKYazxdYMF85OwPdEr9jTtBpO7ydH4= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts= +golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos= +golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= +golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk= +golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -nhooyr.io/websocket v1.8.17 h1:KEVeLJkUywCKVsnLIDlD/5gtayKp8VoCkksHCGGfT9Y= -nhooyr.io/websocket v1.8.17/go.mod h1:rN9OFWIUwuxg4fR5tELlYC04bXYowCP9GX47ivo2l+c= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/main.go b/main.go index ce613bd..355bf4a 100644 --- a/main.go +++ b/main.go @@ -4,19 +4,51 @@ import ( "context" "log" "net/http" + "sync" "time" - "nhooyr.io/websocket" - "nhooyr.io/websocket/wsjson" + "github.com/coder/websocket" + "github.com/coder/websocket/wsjson" ) -type Server struct { +type wsServer struct { OnOpen func(conn *websocket.Conn) OnClose func(conn *websocket.Conn, err error) OnMessage func(conn *websocket.Conn, msg map[string]any) } -func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { +var ( + unauthenticatedConnections []*websocket.Conn + authenticatedConnections []*websocket.Conn + mu sync.Mutex +) + +func removeConnection(conn *websocket.Conn) { + mu.Lock() + defer mu.Unlock() + if isConnectionAuthenticated(conn) { + for i, c := range unauthenticatedConnections { + if c == conn { + unauthenticatedConnections[i] = unauthenticatedConnections[len(unauthenticatedConnections)-1] + unauthenticatedConnections = unauthenticatedConnections[:len(unauthenticatedConnections)-1] + return + } + } + } +} + +func isConnectionAuthenticated(conn *websocket.Conn) bool { + mu.Lock() + defer mu.Unlock() + for _, c := range unauthenticatedConnections { + if c == conn { + return true + } + } + return false +} + +func (s *wsServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{ InsecureSkipVerify: true, }) @@ -52,9 +84,12 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { func main() { InitDatabase(context.Background()) - srv := &Server{ + srv := &wsServer{ OnOpen: func(conn *websocket.Conn) { log.Println("client connected") + mu.Lock() + unauthenticatedConnections = append(unauthenticatedConnections, conn) + mu.Unlock() }, OnClose: func(conn *websocket.Conn, err error) { log.Println("client disconnected:", err) @@ -64,7 +99,7 @@ func main() { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() if err := wsjson.Write(ctx, conn, msg); err != nil { - log.Println("write error:", err) + removeConnection(conn) } }, } diff --git a/structures.go b/structures.go new file mode 100644 index 0000000..d48c1b0 --- /dev/null +++ b/structures.go @@ -0,0 +1,16 @@ +package main + +import "github.com/coder/websocket" + +type User struct { + Id uint + Name string + Password string + Color string + IsPasswordHashed bool +} + +type authConnection struct { + connection *websocket.Conn + user User +}