new start

This commit is contained in:
2026-03-22 14:26:49 +01:00
parent da5f87d67b
commit d756023952
11 changed files with 242 additions and 237 deletions
+15 -70
View File
@@ -1,92 +1,37 @@
package main package main
import ( import (
"context" "fmt"
"errors"
"sync" "sync"
"golang.org/x/crypto/bcrypt"
) )
var ( var (
mu sync.RWMutex CacheClients = make(map[uint32]*Client)
clients = make(map[uint32]*Client) mu sync.RWMutex
chatGroups = make(map[uint32]ChatGroup) Groups = make(map[uint32]*Group)
) )
func CreateGroup(ctx context.Context, chatGroup *ChatGroup) { func CacheGetClientById(id uint32) (*Client, error) {
mu.Lock() mu.RLock()
defer mu.Unlock() defer mu.RUnlock()
chatGroups[chatGroup.Id] = *chatGroup
}
func DeleteGroup(ctx context.Context, chatGroup *ChatGroup) { client, ok := CacheClients[id]
mu.Lock()
defer mu.Unlock()
delete(chatGroups, chatGroup.Id)
}
func CreateClient(ctx context.Context, client *Client) error {
mu.Lock()
defer mu.Unlock()
clients[client.Id] = client
hashed, err := bcrypt.GenerateFromPassword([]byte(client.PasswordHash), bcrypt.DefaultCost)
if err != nil {
return err
}
client.PasswordHash = string(hashed)
//err := DbSaveClientWithoutGroups(ctx, client)
//if err != nil {
// return err
//}
return nil
}
func CheckPassword(ctx context.Context, id uint32, password string) error {
client, err := GetClientFromId(id)
if err != nil {
return err
}
err = bcrypt.CompareHashAndPassword([]byte(client.PasswordHash), []byte(password))
if err != nil {
return err
}
return nil
}
func GetIdFromClientName(ctx context.Context, name string) (uint32, error) {
for _, client := range clients {
if client.Name == name {
return client.Id, nil
}
}
return 0, errors.New("client not found")
}
func GetClientFromId(id uint32) (*Client, error) {
client, ok := clients[id]
if !ok { if !ok {
return nil, errors.New("no such user") return nil, fmt.Errorf("client %d not found", id)
} }
return client, nil return client, nil
} }
func ConnectClientToGroups(ctx context.Context, client *Client) { func CacheSetClient(id uint32, client *Client) {
mu.Lock() mu.Lock()
defer mu.Unlock() defer mu.Unlock()
for _, groupIn := range client.Groups {
chatGroups[groupIn.Id].Members[client.Id] = client CacheClients[id] = client
}
} }
func DisconnectClientFromGroups(ctx context.Context, client *Client) { func CacheDeleteClient(id uint32) {
mu.Lock() mu.Lock()
defer mu.Unlock() defer mu.Unlock()
for _, groupIn := range client.Groups {
delete(chatGroups[groupIn.Id].Members, client.Id) delete(CacheClients, id)
}
} }
+48 -37
View File
@@ -3,66 +3,77 @@ package main
import ( import (
"context" "context"
"github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool"
) )
var dbConn *pgx.Conn var dbConn *pgxpool.Pool
func InitDatabase(ctx context.Context) { func DbInit(ctx context.Context) {
conn, err := pgx.Connect(ctx, "postgres://master:secret@localhost:5432") var err error
dbConn, err = pgxpool.New(ctx, "postgres://master:secret@localhost:5432") // TODO change to env in production
if err != nil { if err != nil {
panic(err) panic(err)
} }
_, err = conn.Exec(ctx, ` _, err = dbConn.Exec(ctx, `
CREATE TABLE IF NOT EXISTS users ( CREATE TABLE IF NOT EXISTS client (
id SERIAL PRIMARY KEY, id SERIAL PRIMARY KEY,
name VARCHAR(20) UNIQUE NOT NULL, name VARCHAR(20) UNIQUE NOT NULL,
pass_hash VARCHAR(60) NOT NULL, pass_hash VARCHAR(60) NOT NULL,
pronouns VARCHAR(15) DEFAULT NULL, pronouns VARCHAR(15) DEFAULT NULL,
color VARCHAR(3) DEFAULT NULL color_red SMALLINT DEFAULT NULL,
); color_green SMALLINT DEFAULT NULL,
CREATE TABLE IF NOT EXISTS chat_groups ( color_blue SMALLINT DEFAULT NULL,
id SERIAL PRIMARY KEY,
name VARCHAR(48) NOT NULL,
creator_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
owner_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
enable_user_colors BOOLEAN NOT NULL DEFAULT true,
group_color VARCHAR(3),
created_at TIMESTAMP NOT NULL DEFAULT NOW() created_at TIMESTAMP NOT NULL DEFAULT NOW()
); )
CREATE TABLE IF NOT EXISTS chat_group_members (
group_id INTEGER NOT NULL REFERENCES chat_groups(id) ON DELETE CASCADE,
user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
joined_at TIMESTAMP NOT NULL DEFAULT NOW(),
PRIMARY KEY (group_id, user_id)
);
`) `)
if err != nil { if err != nil {
panic(err) panic(err)
} }
dbConn = conn _, err = dbConn.Exec(ctx, `
CREATE TABLE IF NOT EXISTS chat_groups (
id SERIAL PRIMARY KEY,
name VARCHAR(48) NOT NULL,
creator_id INTEGER NOT NULL REFERENCES client(id) ON DELETE CASCADE,
owner_id INTEGER NOT NULL REFERENCES client(id) ON DELETE CASCADE,
enable_user_colors BOOLEAN NOT NULL DEFAULT true,
color_red SMALLINT DEFAULT NULL,
color_green SMALLINT DEFAULT NULL,
color_blue SMALLINT DEFAULT NULL,
created_at TIMESTAMP NOT NULL DEFAULT NOW()
)
`)
if err != nil {
panic(err)
}
_, err = dbConn.Exec(ctx, `
CREATE TABLE IF NOT EXISTS chat_group_members (
group_id INTEGER NOT NULL REFERENCES chat_groups(id) ON DELETE CASCADE,
user_id INTEGER NOT NULL REFERENCES client(id) ON DELETE CASCADE,
joined_at TIMESTAMP NOT NULL DEFAULT NOW(),
PRIMARY KEY (group_id, user_id)
)
`)
if err != nil {
panic(err)
}
} }
func DbSaveClientWithoutGroups(ctx context.Context, client *Client) error { func DbSaveClientWithoutGroups(ctx context.Context, client *Client) error {
var id uint64 err := dbConn.QueryRow(ctx, `
var err error INSERT INTO clients (name, pass_hash, pronouns, color_red, color_green, color_blue, created_at)
VALUES ($1, $2, $3, $4, $5, $6, $7)
err = dbConn.QueryRow(ctx, `
INSERT INTO users (name, pass_hash, pronouns)
VALUES ($1, $2, $3)
RETURNING id RETURNING id
`, client.Name, client.PasswordHash, client.Pronouns).Scan(&id) `, client.Name, client.PasswordHash, client.Pronouns, client.Color[0], client.Color[1], client.Color[2], client.CreatedAt).
Scan(&client.Id)
return err return err
} }
func DbGetClient(ctx context.Context, id uint32, client *Client) error { func DbGetClientByName(ctx context.Context, client *Client) error {
err := dbConn.QueryRow(ctx, ` err := dbConn.QueryRow(ctx, `
SELECT name, pass_hash, pronouns, color WHERE id = $1 SELECT name, pass_hash, color_red, color_green, color_blue, created_at FROM clients WHERE name = $1
`, id).Scan(client.Name, client.PasswordHash, client.Pronouns, client.Color) `, client.Name).Scan(&client.Name, &client.PasswordHash, client.Pronouns, client.Color[0], client.Color[1], client.Color[2], client.CreatedAt)
if err != nil { return err
return err
}
return nil
} }
+6
View File
@@ -0,0 +1,6 @@
package main
const (
MaxGroupsForClient uint8 = 8
MaxClientsInGroup uint8 = 12
)
BIN
View File
Binary file not shown.
+43 -29
View File
@@ -1,6 +1,12 @@
package main package main
import "net/http" import (
"fmt"
"net/http"
"strconv"
"strings"
"time"
)
func isMethodAllowed(response *http.ResponseWriter, request *http.Request) bool { func isMethodAllowed(response *http.ResponseWriter, request *http.Request) bool {
if request.Method != http.MethodPost { if request.Method != http.MethodPost {
@@ -10,7 +16,23 @@ func isMethodAllowed(response *http.ResponseWriter, request *http.Request) bool
return true return true
} }
func RegisterHandler(response http.ResponseWriter, request *http.Request) { func parseRgb(str string) ([3]uint8, error) {
parts := strings.SplitN(str, ",", 4)
if len(parts) != 3 {
return [3]uint8{}, fmt.Errorf("invalid rgb")
}
var rgb [3]uint8
for i, p := range parts {
n, err := strconv.ParseUint(strings.TrimSpace(p), 10, 8)
if err != nil {
return [3]uint8{}, fmt.Errorf("invalid component %d: %w", i, err)
}
rgb[i] = uint8(n)
}
return rgb, nil
}
func HttpHandleNewUser(response http.ResponseWriter, request *http.Request) {
if !isMethodAllowed(&response, request) { if !isMethodAllowed(&response, request) {
return return
} }
@@ -29,21 +51,32 @@ func RegisterHandler(response http.ResponseWriter, request *http.Request) {
return return
} }
newClient := Client{ color, err := parseRgb(request.FormValue("color"))
if err != nil {
http.Error(response, "bad color", http.StatusBadRequest)
}
hashedPassword, err := PasswordHash(password)
if err != nil {
http.Error(response, "internal server error", http.StatusInternalServerError)
return
}
newClient := &Client{
Name: username, Name: username,
PasswordHash: password, PasswordHash: hashedPassword,
Color: [3]uint8{120, 120, 120}, //"xxx" Color: color,
CreatedAt: time.Now(),
} }
err := CreateClient(ctx, &newClient) err = DbSaveClientWithoutGroups(ctx, newClient)
if err != nil { if err != nil {
http.Error(response, "taken", http.StatusBadRequest) http.Error(response, "name taken", http.StatusInternalServerError)
return return
} }
response.Write([]byte("registered"))
} }
func LoginHandler(response http.ResponseWriter, request *http.Request) { func HttpHandleLogin(response http.ResponseWriter, request *http.Request) {
if !isMethodAllowed(&response, request) { if !isMethodAllowed(&response, request) {
return return
} }
@@ -62,25 +95,6 @@ func LoginHandler(response http.ResponseWriter, request *http.Request) {
return return
} }
var client Client _, err := CacheGetClientById()
id, err := GetIdFromClientName(ctx, username)
if err != nil {
http.Error(response, "bad login", http.StatusBadRequest)
return
}
err = DbGetClient(ctx)
if err != nil {
http.Error(response, "bad login", http.StatusBadRequest)
return
}
token, err := GetToken(id)
if err != nil {
http.Error(response, "Internal error", http.StatusInternalServerError)
return
}
response.Write([]byte(token))
} }
-14
View File
@@ -1,19 +1,5 @@
package main package main
import (
"context"
"log"
"net/http"
)
func main() { func main() {
ctx := context.Background()
InitDatabase(ctx)
http.HandleFunc("/ws", ServeWsConnection)
http.HandleFunc("/new/client", RegisterHandler)
http.HandleFunc("/new/token", LoginHandler)
log.Println("listening on :8080")
log.Fatal(http.ListenAndServe(":8080", nil))
} }
+13
View File
@@ -0,0 +1,13 @@
package main
import "golang.org/x/crypto/bcrypt"
func PasswordHash(password string) (string, error) {
bytes, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
return string(bytes), err
}
func PasswordCheckAgainstHash(password, hash string) bool {
err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(password))
return err == nil
}
-21
View File
@@ -1,21 +0,0 @@
package main
import "github.com/coder/websocket"
type Client struct {
PasswordHash string
Name string
Pronouns string
Groups [12]*ChatGroup
Connection *websocket.Conn
Id uint32
Color [3]byte
IsAuthenticated bool
}
type ChatGroup struct {
Name string
Id uint32
Members map[uint32]*Client
Color [3]byte
}
+29
View File
@@ -0,0 +1,29 @@
package main
import (
"time"
"github.com/coder/websocket"
)
type Client struct {
Name string
Pronouns string
PasswordHash string
CreatedAt time.Time
WsConn *websocket.Conn
Id uint32
Groups [MaxGroupsForClient]uint32
Color [3]uint8
}
type Group struct {
Name string
CreatedAt time.Time
Id uint32
CreatorId uint32
OwnerId uint32
Clients [MaxClientsInGroup]uint32
Color [3]uint8
EnableUserColors bool
}
+32 -20
View File
@@ -1,44 +1,56 @@
package main package main
import ( import (
"fmt"
"strconv" "strconv"
"time" "time"
"github.com/golang-jwt/jwt/v5" "github.com/golang-jwt/jwt/v5"
) )
var secretKey = []byte("replace-with-env-variable") const tokenSecret = "tmp" // TODO delete in production
const tokenExpiration = time.Hour
func GetToken(clientId uint32) (string, error) { func TokenCreate(clientId uint32) (string, error) {
token := jwt.NewWithClaims(jwt.SigningMethodHS256, now := time.Now()
jwt.RegisteredClaims{ signedToken, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{
Subject: strconv.Itoa(int(clientId)), Subject: strconv.FormatUint(uint64(clientId), 10),
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)), IssuedAt: jwt.NewNumericDate(now),
IssuedAt: jwt.NewNumericDate(time.Now()), ExpiresAt: jwt.NewNumericDate(now.Add(tokenExpiration)),
}, }).SignedString([]byte(tokenSecret))
) return signedToken, err
return token.SignedString(secretKey)
} }
func GetClientIdFromToken(token string) (uint32, error) { func TokenValidateGetId(tokenString string) (uint32, error) {
parsed, err := jwt.ParseWithClaims(token, &jwt.RegisteredClaims{}, func(t *jwt.Token) (any, error) { token, err := jwt.Parse(tokenString, func(t *jwt.Token) (interface{}, error) {
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok { if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, jwt.ErrSignatureInvalid return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
} }
return secretKey, nil return []byte(tokenSecret), nil
}) })
if err != nil { if err != nil {
return 0, err return 0, err
} }
claims, ok := parsed.Claims.(*jwt.RegisteredClaims) claims, ok := token.Claims.(jwt.MapClaims)
if !ok || !parsed.Valid { if !ok || !token.Valid {
return 0, jwt.ErrTokenInvalidClaims return 0, fmt.Errorf("invalid token")
} }
id, err := strconv.ParseUint(claims.Subject, 10, 32) exp, ok := claims["exp"].(float64)
if err != nil { if !ok || time.Now().Unix() > int64(exp) {
return 0, err return 0, fmt.Errorf("token expired")
} }
sub, ok := claims["sub"].(string)
if !ok {
return 0, fmt.Errorf("invalid subject claim")
}
id, err := strconv.ParseUint(sub, 10, 32)
if err != nil {
return 0, fmt.Errorf("invalid subject claim")
}
return uint32(id), nil return uint32(id), nil
} }
+56 -46
View File
@@ -23,85 +23,95 @@ func ServeWsConnection(responseWriter http.ResponseWriter, request *http.Request
ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
defer cancel() defer cancel()
client := Client{} var client = Client{WsConn: connection}
var isAuthenticated bool
defer closeConnection(&client)
for { for {
var clientMessage map[string]any var clientMessage map[string]any
err := wsjson.Read(ctx, connection, &clientMessage) err := wsjson.Read(ctx, connection, &clientMessage)
if err != nil { if err != nil {
log.Printf("read error: %clientMessage", err) log.Printf("read error: %v", err)
return return
} }
if len(clientMessage) > 0 { if len(clientMessage) > 0 {
if client.IsAuthenticated { if isAuthenticated {
handleAuthenticatedMessage(connection, &client, &clientMessage) if !handleAuthenticatedMessage(&client, &clientMessage) {
} else {
if !handleUnauthenticatedMessage(connection, &client, &clientMessage) {
closeConnection(connection)
return return
} }
} else {
if !handleUnauthenticatedMessage(&client, &clientMessage) {
return
}
isAuthenticated = true
} }
} }
} }
} }
func sendMessageCloseIfTimeout(conn *websocket.Conn, message *map[string]any) { func sendMessageCloseIfTimeout(client *Client, message *map[string]any) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel() defer cancel()
err := wsjson.Write(ctx, conn, message) err := wsjson.Write(ctx, client.WsConn, message)
if err != nil { if err != nil {
if errors.Is(err, context.DeadlineExceeded) { if errors.Is(err, context.DeadlineExceeded) {
closeConnection(conn) closeConnection(client)
} }
log.Printf("write error: %v", err) log.Printf("write error: %v", err)
} }
} }
func handleUnauthenticatedMessage(conn *websocket.Conn, client *Client, message *map[string]any) bool { func sendToAllMessageCloseIfTimeout(message *map[string]any) {
token, ok := (*message)["token"].(string) mu.RLock()
if !ok { defer mu.RUnlock()
var errmsg = map[string]any{ for _, client := range CacheClients {
"type": WsServerResponse(BadMessage), sendMessageCloseIfTimeout(client, message)
"message": "token required",
}
sendMessageCloseIfTimeout(conn, &errmsg)
return false
} }
}
clientId, err := GetClientIdFromToken(token) func handleAuthenticatedMessage(client *Client, clientMessage *map[string]any) bool {
if err != nil { sendToAllMessageCloseIfTimeout(clientMessage)
var errmsg = map[string]any{
"type": WsServerResponse(InvalidCredentials),
"message": "bad token",
}
sendMessageCloseIfTimeout(conn, &errmsg)
return false
}
client, err = GetClientFromId(clientId)
if err != nil {
var errmsg = map[string]any{
"type": WsServerResponse(InvalidCredentials),
"message": "bad token",
}
sendMessageCloseIfTimeout(conn, &errmsg)
return false
}
client.IsAuthenticated = true
return true return true
} }
func handleAuthenticatedMessage(conn *websocket.Conn, client *Client, message *map[string]any) { func handleUnauthenticatedMessage(client *Client, clientMessage *map[string]any) bool {
for _, sendTo := range clients { token, ok := (*clientMessage)["token"].(string)
if sendTo.IsAuthenticated && sendTo.Id != client.Id { if !ok {
sendMessageCloseIfTimeout(conn, message) var msg = map[string]any{
"from": "server",
"error": "no token in message",
} }
sendMessageCloseIfTimeout(client, &msg)
return false
} }
clientId, err := TokenValidateGetId(token)
if err != nil {
var msg = map[string]any{
"from": "server",
"error": "invalid token",
}
sendMessageCloseIfTimeout(client, &msg)
return false
}
clientFromCache, err := CacheGetClientById(clientId)
if err != nil {
var msg = map[string]any{
"from": "server",
"error": "invalid token",
}
sendMessageCloseIfTimeout(client, &msg)
return false
}
*client = *clientFromCache
return true
} }
func closeConnection(conn *websocket.Conn) { func closeConnection(client *Client) {
conn.Close(websocket.StatusNormalClosure, "closing connection") CacheDeleteClient(client.Id)
client.WsConn.CloseNow()
} }