diff --git a/.gitignore b/.gitignore index 9dddfc8..4230805 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ .idea go-socket +go.sum \ No newline at end of file diff --git a/Structures.go b/Structures.go new file mode 100644 index 0000000..b0394fd --- /dev/null +++ b/Structures.go @@ -0,0 +1,9 @@ +package main + +type User struct { + id uint + name string + password string + color string + isPasswordHashed bool +} diff --git a/database.go b/database.go new file mode 100644 index 0000000..b885aa0 --- /dev/null +++ b/database.go @@ -0,0 +1,74 @@ +package main + +import ( + "context" + "crypto/subtle" + "errors" + + "github.com/jackc/pgx/v5" +) + +var dbConnection *pgx.Conn + +func InitDatabase(ctx context.Context) { + conn, err := pgx.Connect(ctx, "postgres://master:secret@localhost:5432") + if err != nil { + panic(err) + } + + _, err = conn.Exec(ctx, ` + CREATE TABLE IF NOT EXISTS users ( + id SERIAL PRIMARY KEY, + name VARCHAR(20) UNIQUE NOT NULL, + pass_hash VARCHAR(255) NOT NULL, + color VARCHAR(3) NOT NULL + ) + `) + if err != nil { + panic(err) + } + dbConnection = conn +} + +func AddNewUser(ctx context.Context, user User) (uint, error) { + if len(user.name) == 0 || len(user.name) > 20 { + return 0, errors.New("username bad length") + } + if user.isPasswordHashed { + if len(user.password) != 255 { + return 0, errors.New("password bad length") + } + } else { + // TODO + } + if len(user.color) != 1 && len(user.color) != 3 { + return 0, errors.New("color invalid") + } + var id uint + err := dbConnection.QueryRow(ctx, ` + INSERT INTO users (name, pass_hash, color) + VALUES ($1, $2, $3) + RETURNING id + `, user.name, user.isPasswordHashed, user.color).Scan(&id) + return id, err +} + +func CheckPassword(ctx context.Context, id string, hash string) bool { + var controlHash string + err := dbConnection.QueryRow(ctx, "SELECT pass_hash FROM users WHERE id = $1", id).Scan(&controlHash) + if err != nil { + return false + } + return subtle.ConstantTimeCompare([]byte(controlHash), []byte(hash)) == 1 +} + +func GetUserData(ctx context.Context, id string) (User, error) { + var user User + err := dbConnection.QueryRow(ctx, "SELECT id, name, pass_hash, color FROM users WHERE id = $1", id). + Scan(&user.id, &user.name, &user.password, &user.color) + if err != nil { + return User{}, err + } + user.isPasswordHashed = true + return user, nil +} diff --git a/go-socket b/go-socket index 4b14da8..2a68a41 100755 Binary files a/go-socket and b/go-socket differ diff --git a/go.mod b/go.mod index 4ee0756..a7af6df 100644 --- a/go.mod +++ b/go.mod @@ -11,5 +11,9 @@ require ( github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/golang-jwt/jwt/v5 v5.3.1 // indirect + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect + github.com/jackc/pgx/v5 v5.8.0 // indirect go.uber.org/atomic v1.11.0 // indirect + golang.org/x/text v0.29.0 // indirect ) diff --git a/go.sum b/go.sum index 1c8f8cd..6daf9b2 100644 --- a/go.sum +++ b/go.sum @@ -4,23 +4,37 @@ github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY= github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.8.0 h1:TYPDoleBBme0xGSAX3/+NujXXtpZn9HBONkQC7IEZSo= +github.com/jackc/pgx/v5 v5.8.0/go.mod h1:QVeDInX2m9VyzvNeiCJVjCkNFqzsNb43204HshNSZKw= github.com/klauspost/cpuid/v2 v2.0.9 h1:lgaqFMSdTdQYdZ04uHyN2d/eKdOMyi2YLSvlQIBFYa4= github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/redis/go-redis/v9 v9.18.0 h1:pMkxYPkEbMPwRdenAzUNyFNrDgHx9U+DrBabWNfSRQs= github.com/redis/go-redis/v9 v9.18.0/go.mod h1:k3ufPphLU5YXwNTUcCRXGxUoF1fqxnhFQmscfkCoDA0= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/zeebo/xxh3 v1.0.2 h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0= github.com/zeebo/xxh3 v1.0.2/go.mod h1:5NWz9Sef7zIDm2JHfFlcQvNekmcEl9ekUZQQKCYaDcA= go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= +golang.org/x/text v0.29.0 h1:1neNs90w9YzJ9BocxfsQNHKuAT4pkghyXc4nhZ6sJvk= +golang.org/x/text v0.29.0/go.mod h1:7MhJOA9CD2qZyOKYazxdYMF85OwPdEr9jTtBpO7ydH4= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= nhooyr.io/websocket v1.8.17 h1:KEVeLJkUywCKVsnLIDlD/5gtayKp8VoCkksHCGGfT9Y= nhooyr.io/websocket v1.8.17/go.mod h1:rN9OFWIUwuxg4fR5tELlYC04bXYowCP9GX47ivo2l+c= diff --git a/main.go b/main.go index 394f566..ce613bd 100644 --- a/main.go +++ b/main.go @@ -51,6 +51,7 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { } func main() { + InitDatabase(context.Background()) srv := &Server{ OnOpen: func(conn *websocket.Conn) { log.Println("client connected") diff --git a/postgres/.gitignore b/postgres/.gitignore new file mode 100644 index 0000000..6320cd2 --- /dev/null +++ b/postgres/.gitignore @@ -0,0 +1 @@ +data \ No newline at end of file diff --git a/postgres/docker-compose.yml b/postgres/docker-compose.yml new file mode 100644 index 0000000..2201298 --- /dev/null +++ b/postgres/docker-compose.yml @@ -0,0 +1,10 @@ +services: + db: + image: postgres:17 + environment: + POSTGRES_USER: master + POSTGRES_PASSWORD: secret + ports: + - "5432:5432" + volumes: + - ./data:/var/lib/postgresql/data \ No newline at end of file diff --git a/tokens.go b/tokens.go index d73e4c6..bfb6e26 100644 --- a/tokens.go +++ b/tokens.go @@ -1,101 +1,42 @@ package main import ( - "context" - "crypto/rand" - "crypto/sha256" - "encoding/hex" + "errors" "fmt" "time" - "github.com/redis/go-redis/v9" + "github.com/golang-jwt/jwt/v5" + _ "github.com/golang-jwt/jwt/v5" ) -var rdb *redis.Client +var secretKey = []byte("replace-with-env-variable") -func init() { - rdb = redis.NewClient(&redis.Options{ - Addr: "localhost:6379", - Password: "", - DB: 0, +func GetToken(userID string) (string, error) { + token := jwt.NewWithClaims(jwt.SigningMethodHS256, + jwt.RegisteredClaims{ + Subject: userID, + ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)), + IssuedAt: jwt.NewNumericDate(time.Now()), + }, + ) + return token.SignedString(secretKey) +} + +func GetSubject(tokenString string) (string, error) { + token, err := jwt.ParseWithClaims(tokenString, &jwt.RegisteredClaims{}, func(t *jwt.Token) (interface{}, error) { + if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"]) + } + return secretKey, nil }) -} - -func getAndSaveTokenFor(ctx context.Context, userID string, timeToDie time.Duration) string { - bytes := make([]byte, 32) - if _, err := rand.Read(bytes); err != nil { - panic(fmt.Sprintf("failed to generate random bytes: %v", err)) - } - token := hex.EncodeToString(bytes) - tokenHash := hashToken(token) - tokenKey := fmt.Sprintf("token:%s", tokenHash) - userKey := fmt.Sprintf("user_tokens:%s", userID) - - now := time.Now() - expiresAt := now.Add(timeToDie) - - pipe := rdb.Pipeline() - pipe.Set(ctx, tokenKey, userID, timeToDie) - // score = expiration unix timestamp - pipe.ZAdd(ctx, userKey, redis.Z{Score: float64(expiresAt.Unix()), Member: tokenHash}) - // remove already-expired members - pipe.ZRemRangeByScore(ctx, userKey, "-inf", fmt.Sprintf("%d", now.Unix())) - // set key expiry to latest possible token death - pipe.ExpireAt(ctx, userKey, expiresAt) - - if _, err := pipe.Exec(ctx); err != nil { - panic(fmt.Sprintf("failed to execute redis pipeline: %v", err)) - } - - return token -} - -func hashToken(plaintext string) string { - h := sha256.Sum256([]byte(plaintext)) - return hex.EncodeToString(h[:]) -} - -// GetUserID resolves a plaintext token to its owning userID. -// Returns an error if the token is missing or expired. -func GetUserID(ctx context.Context, plaintext string) (string, error) { - key := fmt.Sprintf("token:%s", hashToken(plaintext)) - - userID, err := rdb.Get(ctx, key).Result() - if err == redis.Nil { - return "", fmt.Errorf("token not found or expired") - } - return userID, err -} - -// RevokeToken deletes a single token by its plaintext value. -func RevokeToken(ctx context.Context, plaintext, userID string) error { - hashed := hashToken(plaintext) - tokenKey := fmt.Sprintf("token:%s", hashed) - userKey := fmt.Sprintf("user_tokens:%s", userID) - - pipe := rdb.Pipeline() - pipe.Del(ctx, tokenKey) - pipe.SRem(ctx, userKey, hashed) - - _, err := pipe.Exec(ctx) - return err -} - -// RevokeAllUserTokens deletes every token belonging to a userID. -func RevokeAllUserTokens(ctx context.Context, userID string) error { - userKey := fmt.Sprintf("user_tokens:%s", userID) - - hashedTokens, err := rdb.SMembers(ctx, userKey).Result() if err != nil { - return err + return "", err } - pipe := rdb.Pipeline() - for _, hashed := range hashedTokens { - pipe.Del(ctx, fmt.Sprintf("token:%s", hashed)) + claims, ok := token.Claims.(*jwt.RegisteredClaims) + if !ok || !token.Valid { + return "", errors.New("invalid token") } - pipe.Del(ctx, userKey) - _, err = pipe.Exec(ctx) - return err + return claims.Subject, nil }