Add saving user, add nonpersistant register login handling
This commit is contained in:
@@ -1,8 +1,11 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -11,32 +14,48 @@ var (
|
||||
chatGroups = make(map[uint32]ChatGroup)
|
||||
)
|
||||
|
||||
func AddGroupToCache(chatGroup *ChatGroup) {
|
||||
func CreateGroup(ctx context.Context, chatGroup *ChatGroup) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
chatGroups[chatGroup.Id] = *chatGroup
|
||||
}
|
||||
|
||||
func RemoveGroupFromCache(chatGroup *ChatGroup) {
|
||||
func DeleteGroup(ctx context.Context, chatGroup *ChatGroup) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
delete(chatGroups, chatGroup.Id)
|
||||
}
|
||||
|
||||
func AddClientConnectionsToCache(client *Client) {
|
||||
func CreateClient(ctx context.Context, client *Client) error {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
for _, groupIn := range client.Groups {
|
||||
chatGroups[groupIn.Id].Members[client.Id] = client
|
||||
err := SaveClientWithoutGroups(ctx, client)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func RemoveClientConnectionsToCache(client *Client) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
for _, groupIn := range client.Groups {
|
||||
delete(chatGroups[groupIn.Id].Members, client.Id)
|
||||
func CheckPassword(ctx context.Context, id uint32, password string) error {
|
||||
client, err := GetClientFromId(id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = bcrypt.CompareHashAndPassword([]byte(client.PasswordHash), []byte(password))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetIdFromClientName(ctx context.Context, name string) (uint32, error) {
|
||||
for _, client := range clients {
|
||||
if client.Name == name {
|
||||
return client.Id, nil
|
||||
}
|
||||
}
|
||||
return 0, errors.New("client not found")
|
||||
}
|
||||
|
||||
func GetClientFromId(id uint32) (*Client, error) {
|
||||
@@ -46,3 +65,19 @@ func GetClientFromId(id uint32) (*Client, error) {
|
||||
}
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func ConnectClientToGroups(ctx context.Context, client *Client) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
for _, groupIn := range client.Groups {
|
||||
chatGroups[groupIn.Id].Members[client.Id] = client
|
||||
}
|
||||
}
|
||||
|
||||
func DisconnectClientFromGroups(ctx context.Context, client *Client) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
for _, groupIn := range client.Groups {
|
||||
delete(chatGroups[groupIn.Id].Members, client.Id)
|
||||
}
|
||||
}
|
||||
|
||||
+23
-2
@@ -3,9 +3,8 @@ package main
|
||||
import (
|
||||
"context"
|
||||
|
||||
//"golang.org/x/crypto/bcrypt"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
var dbConn *pgx.Conn
|
||||
@@ -21,6 +20,7 @@ func InitDatabase(ctx context.Context) {
|
||||
id SERIAL PRIMARY KEY,
|
||||
name VARCHAR(20) UNIQUE NOT NULL,
|
||||
pass_hash VARCHAR(60) NOT NULL,
|
||||
pronouns VARCHAR(15) DEFAULT NULL,
|
||||
color VARCHAR(3) DEFAULT NULL
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS chat_groups (
|
||||
@@ -45,3 +45,24 @@ func InitDatabase(ctx context.Context) {
|
||||
|
||||
dbConn = conn
|
||||
}
|
||||
|
||||
func SaveClientWithoutGroups(ctx context.Context, client *Client) error {
|
||||
var id uint64
|
||||
var err error
|
||||
|
||||
var hashed []byte
|
||||
hashed, err = bcrypt.GenerateFromPassword([]byte(client.PasswordHash), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
password := string(hashed)
|
||||
|
||||
c := string(client.Color[:])
|
||||
err = dbConn.QueryRow(ctx, `
|
||||
INSERT INTO users (name, pass_hash, pronouns, color)
|
||||
VALUES ($1, $2, $3)
|
||||
RETURNING id
|
||||
`, client.Name, password, client.Pronouns, c).Scan(&id)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
package main
|
||||
|
||||
type WSServerResponse uint8
|
||||
type WsServerResponse uint8
|
||||
|
||||
const (
|
||||
BadMessage WSServerResponse = iota
|
||||
BadMessage WsServerResponse = iota
|
||||
InvalidCredentials
|
||||
)
|
||||
|
||||
@@ -16,4 +16,65 @@ func RegisterHandler(response http.ResponseWriter, request *http.Request) {
|
||||
}
|
||||
|
||||
ctx := request.Context()
|
||||
|
||||
username := request.FormValue("username")
|
||||
if len(username) < 4 {
|
||||
http.Error(response, "no or short username", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
password := request.FormValue("password")
|
||||
if len(password) < 8 {
|
||||
http.Error(response, "no or short password", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
newClient := Client{
|
||||
Name: username,
|
||||
PasswordHash: password,
|
||||
}
|
||||
|
||||
err := CreateClient(ctx, &newClient)
|
||||
if err != nil {
|
||||
http.Error(response, "taken", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
response.Write([]byte("registered"))
|
||||
}
|
||||
|
||||
func LoginHandler(response http.ResponseWriter, request *http.Request) {
|
||||
if !isMethodAllowed(&response, request) {
|
||||
return
|
||||
}
|
||||
|
||||
ctx := request.Context()
|
||||
|
||||
username := request.FormValue("username")
|
||||
if len(username) < 4 {
|
||||
http.Error(response, "no or short username", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
password := request.FormValue("password")
|
||||
if len(password) < 8 {
|
||||
http.Error(response, "no or short password", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
id, err := GetIdFromClientName(ctx, username)
|
||||
if err != nil {
|
||||
http.Error(response, "bad login", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
err = CheckPassword(ctx, id, password)
|
||||
if err != nil {
|
||||
http.Error(response, "bad login", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
token, err := GetToken(id)
|
||||
if err != nil {
|
||||
http.Error(response, "Internal error", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
response.Write([]byte(token))
|
||||
}
|
||||
|
||||
@@ -1,5 +1,19 @@
|
||||
package main
|
||||
|
||||
func main() {
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func main() {
|
||||
ctx := context.Background()
|
||||
InitDatabase(ctx)
|
||||
|
||||
http.HandleFunc("/ws", ServeWsConnection)
|
||||
http.HandleFunc("/register", RegisterHandler)
|
||||
http.HandleFunc("/login", LoginHandler)
|
||||
|
||||
log.Println("listening on :8080")
|
||||
log.Fatal(http.ListenAndServe(":8080", nil))
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@ package main
|
||||
import "github.com/coder/websocket"
|
||||
|
||||
type Client struct {
|
||||
Password string
|
||||
PasswordHash string
|
||||
Name string
|
||||
Pronouns string
|
||||
Groups [12]*ChatGroup
|
||||
|
||||
@@ -9,10 +9,10 @@ import (
|
||||
|
||||
var secretKey = []byte("replace-with-env-variable")
|
||||
|
||||
func GetToken(client *Client) (string, error) {
|
||||
func GetToken(clientId uint32) (string, error) {
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256,
|
||||
jwt.RegisteredClaims{
|
||||
Subject: strconv.Itoa(int(client.Id)),
|
||||
Subject: strconv.Itoa(int(clientId)),
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)),
|
||||
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||
},
|
||||
|
||||
+11
-7
@@ -11,7 +11,7 @@ import (
|
||||
"github.com/coder/websocket/wsjson"
|
||||
)
|
||||
|
||||
func ServeConnection(responseWriter http.ResponseWriter, request *http.Request) {
|
||||
func ServeWsConnection(responseWriter http.ResponseWriter, request *http.Request) {
|
||||
connection, err := websocket.Accept(responseWriter, request, nil)
|
||||
if err != nil {
|
||||
log.Printf("websocket accept error: %v", err)
|
||||
@@ -33,7 +33,7 @@ func ServeConnection(responseWriter http.ResponseWriter, request *http.Request)
|
||||
|
||||
if len(clientMessage) > 0 {
|
||||
if client.IsAuthenticated {
|
||||
handleAuthenticatedMessage()
|
||||
handleAuthenticatedMessage(connection, &client, &clientMessage)
|
||||
} else {
|
||||
if !handleUnauthenticatedMessage(connection, &client, &clientMessage) {
|
||||
closeConnection(connection)
|
||||
@@ -61,7 +61,7 @@ func handleUnauthenticatedMessage(conn *websocket.Conn, client *Client, message
|
||||
token, ok := (*message)["token"].(string)
|
||||
if !ok {
|
||||
var errmsg = map[string]any{
|
||||
"type": WSServerResponse(BadMessage),
|
||||
"type": WsServerResponse(BadMessage),
|
||||
"message": "token required",
|
||||
}
|
||||
sendMessageCloseIfTimeout(conn, &errmsg)
|
||||
@@ -71,7 +71,7 @@ func handleUnauthenticatedMessage(conn *websocket.Conn, client *Client, message
|
||||
clientId, err := GetClientIdFromToken(token)
|
||||
if err != nil {
|
||||
var errmsg = map[string]any{
|
||||
"type": WSServerResponse(InvalidCredentials),
|
||||
"type": WsServerResponse(InvalidCredentials),
|
||||
"message": "bad token",
|
||||
}
|
||||
sendMessageCloseIfTimeout(conn, &errmsg)
|
||||
@@ -81,7 +81,7 @@ func handleUnauthenticatedMessage(conn *websocket.Conn, client *Client, message
|
||||
client, err = GetClientFromId(clientId)
|
||||
if err != nil {
|
||||
var errmsg = map[string]any{
|
||||
"type": WSServerResponse(InvalidCredentials),
|
||||
"type": WsServerResponse(InvalidCredentials),
|
||||
"message": "bad token",
|
||||
}
|
||||
sendMessageCloseIfTimeout(conn, &errmsg)
|
||||
@@ -92,8 +92,12 @@ func handleUnauthenticatedMessage(conn *websocket.Conn, client *Client, message
|
||||
return true
|
||||
}
|
||||
|
||||
func handleAuthenticatedMessage() {
|
||||
|
||||
func handleAuthenticatedMessage(conn *websocket.Conn, client *Client, message *map[string]any) {
|
||||
for _, sendTo := range clients {
|
||||
if sendTo.IsAuthenticated && sendTo.Id != client.Id {
|
||||
sendMessageCloseIfTimeout(conn, message)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func closeConnection(conn *websocket.Conn) {
|
||||
|
||||
Reference in New Issue
Block a user