Add saving user, add nonpersistant register login handling

This commit is contained in:
2026-03-17 21:40:48 +01:00
parent e496cb0017
commit d3fc2a65d9
9 changed files with 160 additions and 25 deletions
+45 -10
View File
@@ -1,8 +1,11 @@
package main package main
import ( import (
"context"
"errors" "errors"
"sync" "sync"
"golang.org/x/crypto/bcrypt"
) )
var ( var (
@@ -11,32 +14,48 @@ var (
chatGroups = make(map[uint32]ChatGroup) chatGroups = make(map[uint32]ChatGroup)
) )
func AddGroupToCache(chatGroup *ChatGroup) { func CreateGroup(ctx context.Context, chatGroup *ChatGroup) {
mu.Lock() mu.Lock()
defer mu.Unlock() defer mu.Unlock()
chatGroups[chatGroup.Id] = *chatGroup chatGroups[chatGroup.Id] = *chatGroup
} }
func RemoveGroupFromCache(chatGroup *ChatGroup) { func DeleteGroup(ctx context.Context, chatGroup *ChatGroup) {
mu.Lock() mu.Lock()
defer mu.Unlock() defer mu.Unlock()
delete(chatGroups, chatGroup.Id) delete(chatGroups, chatGroup.Id)
} }
func AddClientConnectionsToCache(client *Client) { func CreateClient(ctx context.Context, client *Client) error {
mu.Lock() mu.Lock()
defer mu.Unlock() defer mu.Unlock()
for _, groupIn := range client.Groups { err := SaveClientWithoutGroups(ctx, client)
chatGroups[groupIn.Id].Members[client.Id] = client if err != nil {
return err
} }
return nil
} }
func RemoveClientConnectionsToCache(client *Client) { func CheckPassword(ctx context.Context, id uint32, password string) error {
mu.Lock() client, err := GetClientFromId(id)
defer mu.Unlock() if err != nil {
for _, groupIn := range client.Groups { return err
delete(chatGroups[groupIn.Id].Members, client.Id)
} }
err = bcrypt.CompareHashAndPassword([]byte(client.PasswordHash), []byte(password))
if err != nil {
return err
}
return nil
}
func GetIdFromClientName(ctx context.Context, name string) (uint32, error) {
for _, client := range clients {
if client.Name == name {
return client.Id, nil
}
}
return 0, errors.New("client not found")
} }
func GetClientFromId(id uint32) (*Client, error) { func GetClientFromId(id uint32) (*Client, error) {
@@ -46,3 +65,19 @@ func GetClientFromId(id uint32) (*Client, error) {
} }
return client, nil return client, nil
} }
func ConnectClientToGroups(ctx context.Context, client *Client) {
mu.Lock()
defer mu.Unlock()
for _, groupIn := range client.Groups {
chatGroups[groupIn.Id].Members[client.Id] = client
}
}
func DisconnectClientFromGroups(ctx context.Context, client *Client) {
mu.Lock()
defer mu.Unlock()
for _, groupIn := range client.Groups {
delete(chatGroups[groupIn.Id].Members, client.Id)
}
}
+23 -2
View File
@@ -3,9 +3,8 @@ package main
import ( import (
"context" "context"
//"golang.org/x/crypto/bcrypt"
"github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5"
"golang.org/x/crypto/bcrypt"
) )
var dbConn *pgx.Conn var dbConn *pgx.Conn
@@ -21,6 +20,7 @@ func InitDatabase(ctx context.Context) {
id SERIAL PRIMARY KEY, id SERIAL PRIMARY KEY,
name VARCHAR(20) UNIQUE NOT NULL, name VARCHAR(20) UNIQUE NOT NULL,
pass_hash VARCHAR(60) NOT NULL, pass_hash VARCHAR(60) NOT NULL,
pronouns VARCHAR(15) DEFAULT NULL,
color VARCHAR(3) DEFAULT NULL color VARCHAR(3) DEFAULT NULL
); );
CREATE TABLE IF NOT EXISTS chat_groups ( CREATE TABLE IF NOT EXISTS chat_groups (
@@ -45,3 +45,24 @@ func InitDatabase(ctx context.Context) {
dbConn = conn dbConn = conn
} }
func SaveClientWithoutGroups(ctx context.Context, client *Client) error {
var id uint64
var err error
var hashed []byte
hashed, err = bcrypt.GenerateFromPassword([]byte(client.PasswordHash), bcrypt.DefaultCost)
if err != nil {
return err
}
password := string(hashed)
c := string(client.Color[:])
err = dbConn.QueryRow(ctx, `
INSERT INTO users (name, pass_hash, pronouns, color)
VALUES ($1, $2, $3)
RETURNING id
`, client.Name, password, client.Pronouns, c).Scan(&id)
return err
}
+2 -2
View File
@@ -1,8 +1,8 @@
package main package main
type WSServerResponse uint8 type WsServerResponse uint8
const ( const (
BadMessage WSServerResponse = iota BadMessage WsServerResponse = iota
InvalidCredentials InvalidCredentials
) )
BIN
View File
Binary file not shown.
+61
View File
@@ -16,4 +16,65 @@ func RegisterHandler(response http.ResponseWriter, request *http.Request) {
} }
ctx := request.Context() ctx := request.Context()
username := request.FormValue("username")
if len(username) < 4 {
http.Error(response, "no or short username", http.StatusBadRequest)
return
}
password := request.FormValue("password")
if len(password) < 8 {
http.Error(response, "no or short password", http.StatusBadRequest)
return
}
newClient := Client{
Name: username,
PasswordHash: password,
}
err := CreateClient(ctx, &newClient)
if err != nil {
http.Error(response, "taken", http.StatusBadRequest)
return
}
response.Write([]byte("registered"))
}
func LoginHandler(response http.ResponseWriter, request *http.Request) {
if !isMethodAllowed(&response, request) {
return
}
ctx := request.Context()
username := request.FormValue("username")
if len(username) < 4 {
http.Error(response, "no or short username", http.StatusBadRequest)
return
}
password := request.FormValue("password")
if len(password) < 8 {
http.Error(response, "no or short password", http.StatusBadRequest)
return
}
id, err := GetIdFromClientName(ctx, username)
if err != nil {
http.Error(response, "bad login", http.StatusBadRequest)
}
err = CheckPassword(ctx, id, password)
if err != nil {
http.Error(response, "bad login", http.StatusBadRequest)
}
token, err := GetToken(id)
if err != nil {
http.Error(response, "Internal error", http.StatusInternalServerError)
}
response.Write([]byte(token))
} }
+15 -1
View File
@@ -1,5 +1,19 @@
package main package main
func main() { import (
"context"
"log"
"net/http"
)
func main() {
ctx := context.Background()
InitDatabase(ctx)
http.HandleFunc("/ws", ServeWsConnection)
http.HandleFunc("/register", RegisterHandler)
http.HandleFunc("/login", LoginHandler)
log.Println("listening on :8080")
log.Fatal(http.ListenAndServe(":8080", nil))
} }
+1 -1
View File
@@ -3,7 +3,7 @@ package main
import "github.com/coder/websocket" import "github.com/coder/websocket"
type Client struct { type Client struct {
Password string PasswordHash string
Name string Name string
Pronouns string Pronouns string
Groups [12]*ChatGroup Groups [12]*ChatGroup
+2 -2
View File
@@ -9,10 +9,10 @@ import (
var secretKey = []byte("replace-with-env-variable") var secretKey = []byte("replace-with-env-variable")
func GetToken(client *Client) (string, error) { func GetToken(clientId uint32) (string, error) {
token := jwt.NewWithClaims(jwt.SigningMethodHS256, token := jwt.NewWithClaims(jwt.SigningMethodHS256,
jwt.RegisteredClaims{ jwt.RegisteredClaims{
Subject: strconv.Itoa(int(client.Id)), Subject: strconv.Itoa(int(clientId)),
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)), ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)),
IssuedAt: jwt.NewNumericDate(time.Now()), IssuedAt: jwt.NewNumericDate(time.Now()),
}, },
+11 -7
View File
@@ -11,7 +11,7 @@ import (
"github.com/coder/websocket/wsjson" "github.com/coder/websocket/wsjson"
) )
func ServeConnection(responseWriter http.ResponseWriter, request *http.Request) { func ServeWsConnection(responseWriter http.ResponseWriter, request *http.Request) {
connection, err := websocket.Accept(responseWriter, request, nil) connection, err := websocket.Accept(responseWriter, request, nil)
if err != nil { if err != nil {
log.Printf("websocket accept error: %v", err) log.Printf("websocket accept error: %v", err)
@@ -33,7 +33,7 @@ func ServeConnection(responseWriter http.ResponseWriter, request *http.Request)
if len(clientMessage) > 0 { if len(clientMessage) > 0 {
if client.IsAuthenticated { if client.IsAuthenticated {
handleAuthenticatedMessage() handleAuthenticatedMessage(connection, &client, &clientMessage)
} else { } else {
if !handleUnauthenticatedMessage(connection, &client, &clientMessage) { if !handleUnauthenticatedMessage(connection, &client, &clientMessage) {
closeConnection(connection) closeConnection(connection)
@@ -61,7 +61,7 @@ func handleUnauthenticatedMessage(conn *websocket.Conn, client *Client, message
token, ok := (*message)["token"].(string) token, ok := (*message)["token"].(string)
if !ok { if !ok {
var errmsg = map[string]any{ var errmsg = map[string]any{
"type": WSServerResponse(BadMessage), "type": WsServerResponse(BadMessage),
"message": "token required", "message": "token required",
} }
sendMessageCloseIfTimeout(conn, &errmsg) sendMessageCloseIfTimeout(conn, &errmsg)
@@ -71,7 +71,7 @@ func handleUnauthenticatedMessage(conn *websocket.Conn, client *Client, message
clientId, err := GetClientIdFromToken(token) clientId, err := GetClientIdFromToken(token)
if err != nil { if err != nil {
var errmsg = map[string]any{ var errmsg = map[string]any{
"type": WSServerResponse(InvalidCredentials), "type": WsServerResponse(InvalidCredentials),
"message": "bad token", "message": "bad token",
} }
sendMessageCloseIfTimeout(conn, &errmsg) sendMessageCloseIfTimeout(conn, &errmsg)
@@ -81,7 +81,7 @@ func handleUnauthenticatedMessage(conn *websocket.Conn, client *Client, message
client, err = GetClientFromId(clientId) client, err = GetClientFromId(clientId)
if err != nil { if err != nil {
var errmsg = map[string]any{ var errmsg = map[string]any{
"type": WSServerResponse(InvalidCredentials), "type": WsServerResponse(InvalidCredentials),
"message": "bad token", "message": "bad token",
} }
sendMessageCloseIfTimeout(conn, &errmsg) sendMessageCloseIfTimeout(conn, &errmsg)
@@ -92,8 +92,12 @@ func handleUnauthenticatedMessage(conn *websocket.Conn, client *Client, message
return true return true
} }
func handleAuthenticatedMessage() { func handleAuthenticatedMessage(conn *websocket.Conn, client *Client, message *map[string]any) {
for _, sendTo := range clients {
if sendTo.IsAuthenticated && sendTo.Id != client.Id {
sendMessageCloseIfTimeout(conn, message)
}
}
} }
func closeConnection(conn *websocket.Conn) { func closeConnection(conn *websocket.Conn) {