102 lines
2.5 KiB
Go
102 lines
2.5 KiB
Go
package main
|
|
|
|
import (
	"context"
	"crypto/rand"
	"crypto/sha256"
	"encoding/hex"
	"errors"
	"fmt"
	"time"

	"github.com/redis/go-redis/v9"
)
|
|
|
|
var rdb *redis.Client
|
|
|
|
func init() {
|
|
rdb = redis.NewClient(&redis.Options{
|
|
Addr: "localhost:6379",
|
|
Password: "",
|
|
DB: 0,
|
|
})
|
|
}
|
|
|
|
func getAndSaveTokenFor(ctx context.Context, userID string, timeToDie time.Duration) string {
|
|
bytes := make([]byte, 32)
|
|
if _, err := rand.Read(bytes); err != nil {
|
|
panic(fmt.Sprintf("failed to generate random bytes: %v", err))
|
|
}
|
|
token := hex.EncodeToString(bytes)
|
|
tokenHash := hashToken(token)
|
|
tokenKey := fmt.Sprintf("token:%s", tokenHash)
|
|
userKey := fmt.Sprintf("user_tokens:%s", userID)
|
|
|
|
now := time.Now()
|
|
expiresAt := now.Add(timeToDie)
|
|
|
|
pipe := rdb.Pipeline()
|
|
pipe.Set(ctx, tokenKey, userID, timeToDie)
|
|
// score = expiration unix timestamp
|
|
pipe.ZAdd(ctx, userKey, redis.Z{Score: float64(expiresAt.Unix()), Member: tokenHash})
|
|
// remove already-expired members
|
|
pipe.ZRemRangeByScore(ctx, userKey, "-inf", fmt.Sprintf("%d", now.Unix()))
|
|
// set key expiry to latest possible token death
|
|
pipe.ExpireAt(ctx, userKey, expiresAt)
|
|
|
|
if _, err := pipe.Exec(ctx); err != nil {
|
|
panic(fmt.Sprintf("failed to execute redis pipeline: %v", err))
|
|
}
|
|
|
|
return token
|
|
}
|
|
|
|
// hashToken returns the lowercase hex encoding of the SHA-256 digest of
// plaintext. Callers use this digest, not the plaintext, in Redis keys.
func hashToken(plaintext string) string {
	hasher := sha256.New()
	// Writes to a hash.Hash never fail, so the error is deliberately ignored.
	_, _ = hasher.Write([]byte(plaintext))
	return hex.EncodeToString(hasher.Sum(nil))
}
|
|
|
|
// GetUserID resolves a plaintext token to its owning userID.
|
|
// Returns an error if the token is missing or expired.
|
|
func GetUserID(ctx context.Context, plaintext string) (string, error) {
|
|
key := fmt.Sprintf("token:%s", hashToken(plaintext))
|
|
|
|
userID, err := rdb.Get(ctx, key).Result()
|
|
if err == redis.Nil {
|
|
return "", fmt.Errorf("token not found or expired")
|
|
}
|
|
return userID, err
|
|
}
|
|
|
|
// RevokeToken deletes a single token by its plaintext value.
|
|
func RevokeToken(ctx context.Context, plaintext, userID string) error {
|
|
hashed := hashToken(plaintext)
|
|
tokenKey := fmt.Sprintf("token:%s", hashed)
|
|
userKey := fmt.Sprintf("user_tokens:%s", userID)
|
|
|
|
pipe := rdb.Pipeline()
|
|
pipe.Del(ctx, tokenKey)
|
|
pipe.SRem(ctx, userKey, hashed)
|
|
|
|
_, err := pipe.Exec(ctx)
|
|
return err
|
|
}
|
|
|
|
// RevokeAllUserTokens deletes every token belonging to a userID.
|
|
func RevokeAllUserTokens(ctx context.Context, userID string) error {
|
|
userKey := fmt.Sprintf("user_tokens:%s", userID)
|
|
|
|
hashedTokens, err := rdb.SMembers(ctx, userKey).Result()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
pipe := rdb.Pipeline()
|
|
for _, hashed := range hashedTokens {
|
|
pipe.Del(ctx, fmt.Sprintf("token:%s", hashed))
|
|
}
|
|
pipe.Del(ctx, userKey)
|
|
|
|
_, err = pipe.Exec(ctx)
|
|
return err
|
|
}
|