new start

This commit is contained in:
2026-03-22 14:26:49 +01:00
parent da5f87d67b
commit d756023952
11 changed files with 242 additions and 237 deletions
+15 -70
View File
@@ -1,92 +1,37 @@
package main
import (
"context"
"errors"
"fmt"
"sync"
"golang.org/x/crypto/bcrypt"
)
var (
mu sync.RWMutex
clients = make(map[uint32]*Client)
chatGroups = make(map[uint32]ChatGroup)
CacheClients = make(map[uint32]*Client)
mu sync.RWMutex
Groups = make(map[uint32]*Group)
)
func CreateGroup(ctx context.Context, chatGroup *ChatGroup) {
mu.Lock()
defer mu.Unlock()
chatGroups[chatGroup.Id] = *chatGroup
}
func CacheGetClientById(id uint32) (*Client, error) {
mu.RLock()
defer mu.RUnlock()
func DeleteGroup(ctx context.Context, chatGroup *ChatGroup) {
mu.Lock()
defer mu.Unlock()
delete(chatGroups, chatGroup.Id)
}
func CreateClient(ctx context.Context, client *Client) error {
mu.Lock()
defer mu.Unlock()
clients[client.Id] = client
hashed, err := bcrypt.GenerateFromPassword([]byte(client.PasswordHash), bcrypt.DefaultCost)
if err != nil {
return err
}
client.PasswordHash = string(hashed)
//err := DbSaveClientWithoutGroups(ctx, client)
//if err != nil {
// return err
//}
return nil
}
func CheckPassword(ctx context.Context, id uint32, password string) error {
client, err := GetClientFromId(id)
if err != nil {
return err
}
err = bcrypt.CompareHashAndPassword([]byte(client.PasswordHash), []byte(password))
if err != nil {
return err
}
return nil
}
func GetIdFromClientName(ctx context.Context, name string) (uint32, error) {
for _, client := range clients {
if client.Name == name {
return client.Id, nil
}
}
return 0, errors.New("client not found")
}
func GetClientFromId(id uint32) (*Client, error) {
client, ok := clients[id]
client, ok := CacheClients[id]
if !ok {
return nil, errors.New("no such user")
return nil, fmt.Errorf("client %d not found", id)
}
return client, nil
}
func ConnectClientToGroups(ctx context.Context, client *Client) {
func CacheSetClient(id uint32, client *Client) {
mu.Lock()
defer mu.Unlock()
for _, groupIn := range client.Groups {
chatGroups[groupIn.Id].Members[client.Id] = client
}
CacheClients[id] = client
}
func DisconnectClientFromGroups(ctx context.Context, client *Client) {
func CacheDeleteClient(id uint32) {
mu.Lock()
defer mu.Unlock()
for _, groupIn := range client.Groups {
delete(chatGroups[groupIn.Id].Members, client.Id)
}
delete(CacheClients, id)
}
+48 -37
View File
@@ -3,66 +3,77 @@ package main
import (
"context"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
)
var dbConn *pgx.Conn
var dbConn *pgxpool.Pool
func InitDatabase(ctx context.Context) {
conn, err := pgx.Connect(ctx, "postgres://master:secret@localhost:5432")
func DbInit(ctx context.Context) {
var err error
dbConn, err = pgxpool.New(ctx, "postgres://master:secret@localhost:5432") // TODO change to env in production
if err != nil {
panic(err)
}
_, err = conn.Exec(ctx, `
CREATE TABLE IF NOT EXISTS users (
_, err = dbConn.Exec(ctx, `
CREATE TABLE IF NOT EXISTS client (
id SERIAL PRIMARY KEY,
name VARCHAR(20) UNIQUE NOT NULL,
pass_hash VARCHAR(60) NOT NULL,
pronouns VARCHAR(15) DEFAULT NULL,
color VARCHAR(3) DEFAULT NULL
);
CREATE TABLE IF NOT EXISTS chat_groups (
id SERIAL PRIMARY KEY,
name VARCHAR(48) NOT NULL,
creator_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
owner_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
enable_user_colors BOOLEAN NOT NULL DEFAULT true,
group_color VARCHAR(3),
color_red SMALLINT DEFAULT NULL,
color_green SMALLINT DEFAULT NULL,
color_blue SMALLINT DEFAULT NULL,
created_at TIMESTAMP NOT NULL DEFAULT NOW()
);
CREATE TABLE IF NOT EXISTS chat_group_members (
group_id INTEGER NOT NULL REFERENCES chat_groups(id) ON DELETE CASCADE,
user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
joined_at TIMESTAMP NOT NULL DEFAULT NOW(),
PRIMARY KEY (group_id, user_id)
);
)
`)
if err != nil {
panic(err)
}
dbConn = conn
_, err = dbConn.Exec(ctx, `
CREATE TABLE IF NOT EXISTS chat_groups (
id SERIAL PRIMARY KEY,
name VARCHAR(48) NOT NULL,
creator_id INTEGER NOT NULL REFERENCES client(id) ON DELETE CASCADE,
owner_id INTEGER NOT NULL REFERENCES client(id) ON DELETE CASCADE,
enable_user_colors BOOLEAN NOT NULL DEFAULT true,
color_red SMALLINT DEFAULT NULL,
color_green SMALLINT DEFAULT NULL,
color_blue SMALLINT DEFAULT NULL,
created_at TIMESTAMP NOT NULL DEFAULT NOW()
)
`)
if err != nil {
panic(err)
}
_, err = dbConn.Exec(ctx, `
CREATE TABLE IF NOT EXISTS chat_group_members (
group_id INTEGER NOT NULL REFERENCES chat_groups(id) ON DELETE CASCADE,
user_id INTEGER NOT NULL REFERENCES client(id) ON DELETE CASCADE,
joined_at TIMESTAMP NOT NULL DEFAULT NOW(),
PRIMARY KEY (group_id, user_id)
)
`)
if err != nil {
panic(err)
}
}
func DbSaveClientWithoutGroups(ctx context.Context, client *Client) error {
var id uint64
var err error
err = dbConn.QueryRow(ctx, `
INSERT INTO users (name, pass_hash, pronouns)
VALUES ($1, $2, $3)
err := dbConn.QueryRow(ctx, `
INSERT INTO clients (name, pass_hash, pronouns, color_red, color_green, color_blue, created_at)
VALUES ($1, $2, $3, $4, $5, $6, $7)
RETURNING id
`, client.Name, client.PasswordHash, client.Pronouns).Scan(&id)
`, client.Name, client.PasswordHash, client.Pronouns, client.Color[0], client.Color[1], client.Color[2], client.CreatedAt).
Scan(&client.Id)
return err
}
func DbGetClient(ctx context.Context, id uint32, client *Client) error {
func DbGetClientByName(ctx context.Context, client *Client) error {
err := dbConn.QueryRow(ctx, `
SELECT name, pass_hash, pronouns, color WHERE id = $1
`, id).Scan(client.Name, client.PasswordHash, client.Pronouns, client.Color)
if err != nil {
return err
}
return nil
SELECT name, pass_hash, color_red, color_green, color_blue, created_at FROM clients WHERE name = $1
`, client.Name).Scan(&client.Name, &client.PasswordHash, client.Pronouns, client.Color[0], client.Color[1], client.Color[2], client.CreatedAt)
return err
}
+6
View File
@@ -0,0 +1,6 @@
package main
const (
MaxGroupsForClient uint8 = 8
MaxClientsInGroup uint8 = 12
)
BIN
View File
Binary file not shown.
+43 -29
View File
@@ -1,6 +1,12 @@
package main
import "net/http"
import (
"fmt"
"net/http"
"strconv"
"strings"
"time"
)
func isMethodAllowed(response *http.ResponseWriter, request *http.Request) bool {
if request.Method != http.MethodPost {
@@ -10,7 +16,23 @@ func isMethodAllowed(response *http.ResponseWriter, request *http.Request) bool
return true
}
func RegisterHandler(response http.ResponseWriter, request *http.Request) {
func parseRgb(str string) ([3]uint8, error) {
parts := strings.SplitN(str, ",", 4)
if len(parts) != 3 {
return [3]uint8{}, fmt.Errorf("invalid rgb")
}
var rgb [3]uint8
for i, p := range parts {
n, err := strconv.ParseUint(strings.TrimSpace(p), 10, 8)
if err != nil {
return [3]uint8{}, fmt.Errorf("invalid component %d: %w", i, err)
}
rgb[i] = uint8(n)
}
return rgb, nil
}
func HttpHandleNewUser(response http.ResponseWriter, request *http.Request) {
if !isMethodAllowed(&response, request) {
return
}
@@ -29,21 +51,32 @@ func RegisterHandler(response http.ResponseWriter, request *http.Request) {
return
}
newClient := Client{
color, err := parseRgb(request.FormValue("color"))
if err != nil {
http.Error(response, "bad color", http.StatusBadRequest)
}
hashedPassword, err := PasswordHash(password)
if err != nil {
http.Error(response, "internal server error", http.StatusInternalServerError)
return
}
newClient := &Client{
Name: username,
PasswordHash: password,
Color: [3]uint8{120, 120, 120}, //"xxx"
PasswordHash: hashedPassword,
Color: color,
CreatedAt: time.Now(),
}
err := CreateClient(ctx, &newClient)
err = DbSaveClientWithoutGroups(ctx, newClient)
if err != nil {
http.Error(response, "taken", http.StatusBadRequest)
http.Error(response, "name taken", http.StatusInternalServerError)
return
}
response.Write([]byte("registered"))
}
func LoginHandler(response http.ResponseWriter, request *http.Request) {
func HttpHandleLogin(response http.ResponseWriter, request *http.Request) {
if !isMethodAllowed(&response, request) {
return
}
@@ -62,25 +95,6 @@ func LoginHandler(response http.ResponseWriter, request *http.Request) {
return
}
var client Client
_, err := CacheGetClientById()
id, err := GetIdFromClientName(ctx, username)
if err != nil {
http.Error(response, "bad login", http.StatusBadRequest)
return
}
err = DbGetClient(ctx)
if err != nil {
http.Error(response, "bad login", http.StatusBadRequest)
return
}
token, err := GetToken(id)
if err != nil {
http.Error(response, "Internal error", http.StatusInternalServerError)
return
}
response.Write([]byte(token))
}
-14
View File
@@ -1,19 +1,5 @@
package main
import (
"context"
"log"
"net/http"
)
func main() {
ctx := context.Background()
InitDatabase(ctx)
http.HandleFunc("/ws", ServeWsConnection)
http.HandleFunc("/new/client", RegisterHandler)
http.HandleFunc("/new/token", LoginHandler)
log.Println("listening on :8080")
log.Fatal(http.ListenAndServe(":8080", nil))
}
+13
View File
@@ -0,0 +1,13 @@
package main
import "golang.org/x/crypto/bcrypt"
func PasswordHash(password string) (string, error) {
bytes, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
return string(bytes), err
}
func PasswordCheckAgainstHash(password, hash string) bool {
err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(password))
return err == nil
}
-21
View File
@@ -1,21 +0,0 @@
package main
import "github.com/coder/websocket"
type Client struct {
PasswordHash string
Name string
Pronouns string
Groups [12]*ChatGroup
Connection *websocket.Conn
Id uint32
Color [3]byte
IsAuthenticated bool
}
type ChatGroup struct {
Name string
Id uint32
Members map[uint32]*Client
Color [3]byte
}
+29
View File
@@ -0,0 +1,29 @@
package main
import (
"time"
"github.com/coder/websocket"
)
type Client struct {
Name string
Pronouns string
PasswordHash string
CreatedAt time.Time
WsConn *websocket.Conn
Id uint32
Groups [MaxGroupsForClient]uint32
Color [3]uint8
}
type Group struct {
Name string
CreatedAt time.Time
Id uint32
CreatorId uint32
OwnerId uint32
Clients [MaxClientsInGroup]uint32
Color [3]uint8
EnableUserColors bool
}
+32 -20
View File
@@ -1,44 +1,56 @@
package main
import (
"fmt"
"strconv"
"time"
"github.com/golang-jwt/jwt/v5"
)
var secretKey = []byte("replace-with-env-variable")
const tokenSecret = "tmp" // TODO delete in production
const tokenExpiration = time.Hour
func GetToken(clientId uint32) (string, error) {
token := jwt.NewWithClaims(jwt.SigningMethodHS256,
jwt.RegisteredClaims{
Subject: strconv.Itoa(int(clientId)),
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)),
IssuedAt: jwt.NewNumericDate(time.Now()),
},
)
return token.SignedString(secretKey)
func TokenCreate(clientId uint32) (string, error) {
now := time.Now()
signedToken, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{
Subject: strconv.FormatUint(uint64(clientId), 10),
IssuedAt: jwt.NewNumericDate(now),
ExpiresAt: jwt.NewNumericDate(now.Add(tokenExpiration)),
}).SignedString([]byte(tokenSecret))
return signedToken, err
}
func GetClientIdFromToken(token string) (uint32, error) {
parsed, err := jwt.ParseWithClaims(token, &jwt.RegisteredClaims{}, func(t *jwt.Token) (any, error) {
func TokenValidateGetId(tokenString string) (uint32, error) {
token, err := jwt.Parse(tokenString, func(t *jwt.Token) (interface{}, error) {
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, jwt.ErrSignatureInvalid
return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
}
return secretKey, nil
return []byte(tokenSecret), nil
})
if err != nil {
return 0, err
}
claims, ok := parsed.Claims.(*jwt.RegisteredClaims)
if !ok || !parsed.Valid {
return 0, jwt.ErrTokenInvalidClaims
claims, ok := token.Claims.(jwt.MapClaims)
if !ok || !token.Valid {
return 0, fmt.Errorf("invalid token")
}
id, err := strconv.ParseUint(claims.Subject, 10, 32)
if err != nil {
return 0, err
exp, ok := claims["exp"].(float64)
if !ok || time.Now().Unix() > int64(exp) {
return 0, fmt.Errorf("token expired")
}
sub, ok := claims["sub"].(string)
if !ok {
return 0, fmt.Errorf("invalid subject claim")
}
id, err := strconv.ParseUint(sub, 10, 32)
if err != nil {
return 0, fmt.Errorf("invalid subject claim")
}
return uint32(id), nil
}
+56 -46
View File
@@ -23,85 +23,95 @@ func ServeWsConnection(responseWriter http.ResponseWriter, request *http.Request
ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
defer cancel()
client := Client{}
var client = Client{WsConn: connection}
var isAuthenticated bool
defer closeConnection(&client)
for {
var clientMessage map[string]any
err := wsjson.Read(ctx, connection, &clientMessage)
if err != nil {
log.Printf("read error: %clientMessage", err)
log.Printf("read error: %v", err)
return
}
if len(clientMessage) > 0 {
if client.IsAuthenticated {
handleAuthenticatedMessage(connection, &client, &clientMessage)
} else {
if !handleUnauthenticatedMessage(connection, &client, &clientMessage) {
closeConnection(connection)
if isAuthenticated {
if !handleAuthenticatedMessage(&client, &clientMessage) {
return
}
} else {
if !handleUnauthenticatedMessage(&client, &clientMessage) {
return
}
isAuthenticated = true
}
}
}
}
func sendMessageCloseIfTimeout(conn *websocket.Conn, message *map[string]any) {
func sendMessageCloseIfTimeout(client *Client, message *map[string]any) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel()
err := wsjson.Write(ctx, conn, message)
err := wsjson.Write(ctx, client.WsConn, message)
if err != nil {
if errors.Is(err, context.DeadlineExceeded) {
closeConnection(conn)
closeConnection(client)
}
log.Printf("write error: %v", err)
}
}
func handleUnauthenticatedMessage(conn *websocket.Conn, client *Client, message *map[string]any) bool {
token, ok := (*message)["token"].(string)
if !ok {
var errmsg = map[string]any{
"type": WsServerResponse(BadMessage),
"message": "token required",
}
sendMessageCloseIfTimeout(conn, &errmsg)
return false
func sendToAllMessageCloseIfTimeout(message *map[string]any) {
mu.RLock()
defer mu.RUnlock()
for _, client := range CacheClients {
sendMessageCloseIfTimeout(client, message)
}
}
clientId, err := GetClientIdFromToken(token)
if err != nil {
var errmsg = map[string]any{
"type": WsServerResponse(InvalidCredentials),
"message": "bad token",
}
sendMessageCloseIfTimeout(conn, &errmsg)
return false
}
client, err = GetClientFromId(clientId)
if err != nil {
var errmsg = map[string]any{
"type": WsServerResponse(InvalidCredentials),
"message": "bad token",
}
sendMessageCloseIfTimeout(conn, &errmsg)
return false
}
client.IsAuthenticated = true
func handleAuthenticatedMessage(client *Client, clientMessage *map[string]any) bool {
sendToAllMessageCloseIfTimeout(clientMessage)
return true
}
func handleAuthenticatedMessage(conn *websocket.Conn, client *Client, message *map[string]any) {
for _, sendTo := range clients {
if sendTo.IsAuthenticated && sendTo.Id != client.Id {
sendMessageCloseIfTimeout(conn, message)
func handleUnauthenticatedMessage(client *Client, clientMessage *map[string]any) bool {
token, ok := (*clientMessage)["token"].(string)
if !ok {
var msg = map[string]any{
"from": "server",
"error": "no token in message",
}
sendMessageCloseIfTimeout(client, &msg)
return false
}
clientId, err := TokenValidateGetId(token)
if err != nil {
var msg = map[string]any{
"from": "server",
"error": "invalid token",
}
sendMessageCloseIfTimeout(client, &msg)
return false
}
clientFromCache, err := CacheGetClientById(clientId)
if err != nil {
var msg = map[string]any{
"from": "server",
"error": "invalid token",
}
sendMessageCloseIfTimeout(client, &msg)
return false
}
*client = *clientFromCache
return true
}
func closeConnection(conn *websocket.Conn) {
conn.Close(websocket.StatusNormalClosure, "closing connection")
func closeConnection(client *Client) {
CacheDeleteClient(client.Id)
client.WsConn.CloseNow()
}