add postgress user handling support
This commit is contained in:
@@ -1,2 +1,3 @@
|
||||
.idea
|
||||
go-socket
|
||||
go.sum
|
||||
@@ -0,0 +1,9 @@
|
||||
package main
|
||||
|
||||
type User struct {
|
||||
id uint
|
||||
name string
|
||||
password string
|
||||
color string
|
||||
isPasswordHashed bool
|
||||
}
|
||||
+74
@@ -0,0 +1,74 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/subtle"
|
||||
"errors"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
)
|
||||
|
||||
var dbConnection *pgx.Conn
|
||||
|
||||
func InitDatabase(ctx context.Context) {
|
||||
conn, err := pgx.Connect(ctx, "postgres://master:secret@localhost:5432")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
_, err = conn.Exec(ctx, `
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
id SERIAL PRIMARY KEY,
|
||||
name VARCHAR(20) UNIQUE NOT NULL,
|
||||
pass_hash VARCHAR(255) NOT NULL,
|
||||
color VARCHAR(3) NOT NULL
|
||||
)
|
||||
`)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
dbConnection = conn
|
||||
}
|
||||
|
||||
func AddNewUser(ctx context.Context, user User) (uint, error) {
|
||||
if len(user.name) == 0 || len(user.name) > 20 {
|
||||
return 0, errors.New("username bad length")
|
||||
}
|
||||
if user.isPasswordHashed {
|
||||
if len(user.password) != 255 {
|
||||
return 0, errors.New("password bad length")
|
||||
}
|
||||
} else {
|
||||
// TODO
|
||||
}
|
||||
if len(user.color) != 1 && len(user.color) != 3 {
|
||||
return 0, errors.New("color invalid")
|
||||
}
|
||||
var id uint
|
||||
err := dbConnection.QueryRow(ctx, `
|
||||
INSERT INTO users (name, pass_hash, color)
|
||||
VALUES ($1, $2, $3)
|
||||
RETURNING id
|
||||
`, user.name, user.isPasswordHashed, user.color).Scan(&id)
|
||||
return id, err
|
||||
}
|
||||
|
||||
func CheckPassword(ctx context.Context, id string, hash string) bool {
|
||||
var controlHash string
|
||||
err := dbConnection.QueryRow(ctx, "SELECT pass_hash FROM users WHERE id = $1", id).Scan(&controlHash)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return subtle.ConstantTimeCompare([]byte(controlHash), []byte(hash)) == 1
|
||||
}
|
||||
|
||||
func GetUserData(ctx context.Context, id string) (User, error) {
|
||||
var user User
|
||||
err := dbConnection.QueryRow(ctx, "SELECT id, name, pass_hash, color FROM users WHERE id = $1", id).
|
||||
Scan(&user.id, &user.name, &user.password, &user.color)
|
||||
if err != nil {
|
||||
return User{}, err
|
||||
}
|
||||
user.isPasswordHashed = true
|
||||
return user, nil
|
||||
}
|
||||
@@ -11,5 +11,9 @@ require (
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||
github.com/golang-jwt/jwt/v5 v5.3.1 // indirect
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
||||
github.com/jackc/pgx/v5 v5.8.0 // indirect
|
||||
go.uber.org/atomic v1.11.0 // indirect
|
||||
golang.org/x/text v0.29.0 // indirect
|
||||
)
|
||||
|
||||
@@ -4,23 +4,37 @@ github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
|
||||
github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0=
|
||||
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
|
||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
|
||||
github.com/jackc/pgx/v5 v5.8.0 h1:TYPDoleBBme0xGSAX3/+NujXXtpZn9HBONkQC7IEZSo=
|
||||
github.com/jackc/pgx/v5 v5.8.0/go.mod h1:QVeDInX2m9VyzvNeiCJVjCkNFqzsNb43204HshNSZKw=
|
||||
github.com/klauspost/cpuid/v2 v2.0.9 h1:lgaqFMSdTdQYdZ04uHyN2d/eKdOMyi2YLSvlQIBFYa4=
|
||||
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/redis/go-redis/v9 v9.18.0 h1:pMkxYPkEbMPwRdenAzUNyFNrDgHx9U+DrBabWNfSRQs=
|
||||
github.com/redis/go-redis/v9 v9.18.0/go.mod h1:k3ufPphLU5YXwNTUcCRXGxUoF1fqxnhFQmscfkCoDA0=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/zeebo/xxh3 v1.0.2 h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0=
|
||||
github.com/zeebo/xxh3 v1.0.2/go.mod h1:5NWz9Sef7zIDm2JHfFlcQvNekmcEl9ekUZQQKCYaDcA=
|
||||
go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE=
|
||||
go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0=
|
||||
golang.org/x/text v0.29.0 h1:1neNs90w9YzJ9BocxfsQNHKuAT4pkghyXc4nhZ6sJvk=
|
||||
golang.org/x/text v0.29.0/go.mod h1:7MhJOA9CD2qZyOKYazxdYMF85OwPdEr9jTtBpO7ydH4=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
nhooyr.io/websocket v1.8.17 h1:KEVeLJkUywCKVsnLIDlD/5gtayKp8VoCkksHCGGfT9Y=
|
||||
nhooyr.io/websocket v1.8.17/go.mod h1:rN9OFWIUwuxg4fR5tELlYC04bXYowCP9GX47ivo2l+c=
|
||||
|
||||
@@ -51,6 +51,7 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func main() {
|
||||
InitDatabase(context.Background())
|
||||
srv := &Server{
|
||||
OnOpen: func(conn *websocket.Conn) {
|
||||
log.Println("client connected")
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
data
|
||||
@@ -0,0 +1,10 @@
|
||||
services:
|
||||
db:
|
||||
image: postgres:17
|
||||
environment:
|
||||
POSTGRES_USER: master
|
||||
POSTGRES_PASSWORD: secret
|
||||
ports:
|
||||
- "5432:5432"
|
||||
volumes:
|
||||
- ./data:/var/lib/postgresql/data
|
||||
@@ -1,101 +1,42 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
_ "github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
var rdb *redis.Client
|
||||
var secretKey = []byte("replace-with-env-variable")
|
||||
|
||||
func init() {
|
||||
rdb = redis.NewClient(&redis.Options{
|
||||
Addr: "localhost:6379",
|
||||
Password: "",
|
||||
DB: 0,
|
||||
func GetToken(userID string) (string, error) {
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256,
|
||||
jwt.RegisteredClaims{
|
||||
Subject: userID,
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)),
|
||||
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||
},
|
||||
)
|
||||
return token.SignedString(secretKey)
|
||||
}
|
||||
|
||||
func GetSubject(tokenString string) (string, error) {
|
||||
token, err := jwt.ParseWithClaims(tokenString, &jwt.RegisteredClaims{}, func(t *jwt.Token) (interface{}, error) {
|
||||
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
|
||||
}
|
||||
return secretKey, nil
|
||||
})
|
||||
}
|
||||
|
||||
func getAndSaveTokenFor(ctx context.Context, userID string, timeToDie time.Duration) string {
|
||||
bytes := make([]byte, 32)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
panic(fmt.Sprintf("failed to generate random bytes: %v", err))
|
||||
}
|
||||
token := hex.EncodeToString(bytes)
|
||||
tokenHash := hashToken(token)
|
||||
tokenKey := fmt.Sprintf("token:%s", tokenHash)
|
||||
userKey := fmt.Sprintf("user_tokens:%s", userID)
|
||||
|
||||
now := time.Now()
|
||||
expiresAt := now.Add(timeToDie)
|
||||
|
||||
pipe := rdb.Pipeline()
|
||||
pipe.Set(ctx, tokenKey, userID, timeToDie)
|
||||
// score = expiration unix timestamp
|
||||
pipe.ZAdd(ctx, userKey, redis.Z{Score: float64(expiresAt.Unix()), Member: tokenHash})
|
||||
// remove already-expired members
|
||||
pipe.ZRemRangeByScore(ctx, userKey, "-inf", fmt.Sprintf("%d", now.Unix()))
|
||||
// set key expiry to latest possible token death
|
||||
pipe.ExpireAt(ctx, userKey, expiresAt)
|
||||
|
||||
if _, err := pipe.Exec(ctx); err != nil {
|
||||
panic(fmt.Sprintf("failed to execute redis pipeline: %v", err))
|
||||
}
|
||||
|
||||
return token
|
||||
}
|
||||
|
||||
func hashToken(plaintext string) string {
|
||||
h := sha256.Sum256([]byte(plaintext))
|
||||
return hex.EncodeToString(h[:])
|
||||
}
|
||||
|
||||
// GetUserID resolves a plaintext token to its owning userID.
|
||||
// Returns an error if the token is missing or expired.
|
||||
func GetUserID(ctx context.Context, plaintext string) (string, error) {
|
||||
key := fmt.Sprintf("token:%s", hashToken(plaintext))
|
||||
|
||||
userID, err := rdb.Get(ctx, key).Result()
|
||||
if err == redis.Nil {
|
||||
return "", fmt.Errorf("token not found or expired")
|
||||
}
|
||||
return userID, err
|
||||
}
|
||||
|
||||
// RevokeToken deletes a single token by its plaintext value.
|
||||
func RevokeToken(ctx context.Context, plaintext, userID string) error {
|
||||
hashed := hashToken(plaintext)
|
||||
tokenKey := fmt.Sprintf("token:%s", hashed)
|
||||
userKey := fmt.Sprintf("user_tokens:%s", userID)
|
||||
|
||||
pipe := rdb.Pipeline()
|
||||
pipe.Del(ctx, tokenKey)
|
||||
pipe.SRem(ctx, userKey, hashed)
|
||||
|
||||
_, err := pipe.Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
// RevokeAllUserTokens deletes every token belonging to a userID.
|
||||
func RevokeAllUserTokens(ctx context.Context, userID string) error {
|
||||
userKey := fmt.Sprintf("user_tokens:%s", userID)
|
||||
|
||||
hashedTokens, err := rdb.SMembers(ctx, userKey).Result()
|
||||
if err != nil {
|
||||
return err
|
||||
return "", err
|
||||
}
|
||||
|
||||
pipe := rdb.Pipeline()
|
||||
for _, hashed := range hashedTokens {
|
||||
pipe.Del(ctx, fmt.Sprintf("token:%s", hashed))
|
||||
claims, ok := token.Claims.(*jwt.RegisteredClaims)
|
||||
if !ok || !token.Valid {
|
||||
return "", errors.New("invalid token")
|
||||
}
|
||||
pipe.Del(ctx, userKey)
|
||||
|
||||
_, err = pipe.Exec(ctx)
|
||||
return err
|
||||
return claims.Subject, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user