Add saving user, add nonpersistant register login handling
This commit is contained in:
@@ -1,8 +1,11 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/bcrypt"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -11,32 +14,48 @@ var (
|
|||||||
chatGroups = make(map[uint32]ChatGroup)
|
chatGroups = make(map[uint32]ChatGroup)
|
||||||
)
|
)
|
||||||
|
|
||||||
func AddGroupToCache(chatGroup *ChatGroup) {
|
func CreateGroup(ctx context.Context, chatGroup *ChatGroup) {
|
||||||
mu.Lock()
|
mu.Lock()
|
||||||
defer mu.Unlock()
|
defer mu.Unlock()
|
||||||
chatGroups[chatGroup.Id] = *chatGroup
|
chatGroups[chatGroup.Id] = *chatGroup
|
||||||
}
|
}
|
||||||
|
|
||||||
func RemoveGroupFromCache(chatGroup *ChatGroup) {
|
func DeleteGroup(ctx context.Context, chatGroup *ChatGroup) {
|
||||||
mu.Lock()
|
mu.Lock()
|
||||||
defer mu.Unlock()
|
defer mu.Unlock()
|
||||||
delete(chatGroups, chatGroup.Id)
|
delete(chatGroups, chatGroup.Id)
|
||||||
}
|
}
|
||||||
|
|
||||||
func AddClientConnectionsToCache(client *Client) {
|
func CreateClient(ctx context.Context, client *Client) error {
|
||||||
mu.Lock()
|
mu.Lock()
|
||||||
defer mu.Unlock()
|
defer mu.Unlock()
|
||||||
for _, groupIn := range client.Groups {
|
err := SaveClientWithoutGroups(ctx, client)
|
||||||
chatGroups[groupIn.Id].Members[client.Id] = client
|
if err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func RemoveClientConnectionsToCache(client *Client) {
|
func CheckPassword(ctx context.Context, id uint32, password string) error {
|
||||||
mu.Lock()
|
client, err := GetClientFromId(id)
|
||||||
defer mu.Unlock()
|
if err != nil {
|
||||||
for _, groupIn := range client.Groups {
|
return err
|
||||||
delete(chatGroups[groupIn.Id].Members, client.Id)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = bcrypt.CompareHashAndPassword([]byte(client.PasswordHash), []byte(password))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetIdFromClientName(ctx context.Context, name string) (uint32, error) {
|
||||||
|
for _, client := range clients {
|
||||||
|
if client.Name == name {
|
||||||
|
return client.Id, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 0, errors.New("client not found")
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetClientFromId(id uint32) (*Client, error) {
|
func GetClientFromId(id uint32) (*Client, error) {
|
||||||
@@ -46,3 +65,19 @@ func GetClientFromId(id uint32) (*Client, error) {
|
|||||||
}
|
}
|
||||||
return client, nil
|
return client, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ConnectClientToGroups(ctx context.Context, client *Client) {
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
|
for _, groupIn := range client.Groups {
|
||||||
|
chatGroups[groupIn.Id].Members[client.Id] = client
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func DisconnectClientFromGroups(ctx context.Context, client *Client) {
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
|
for _, groupIn := range client.Groups {
|
||||||
|
delete(chatGroups[groupIn.Id].Members, client.Id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
+23
-2
@@ -3,9 +3,8 @@ package main
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
|
||||||
//"golang.org/x/crypto/bcrypt"
|
|
||||||
|
|
||||||
"github.com/jackc/pgx/v5"
|
"github.com/jackc/pgx/v5"
|
||||||
|
"golang.org/x/crypto/bcrypt"
|
||||||
)
|
)
|
||||||
|
|
||||||
var dbConn *pgx.Conn
|
var dbConn *pgx.Conn
|
||||||
@@ -21,6 +20,7 @@ func InitDatabase(ctx context.Context) {
|
|||||||
id SERIAL PRIMARY KEY,
|
id SERIAL PRIMARY KEY,
|
||||||
name VARCHAR(20) UNIQUE NOT NULL,
|
name VARCHAR(20) UNIQUE NOT NULL,
|
||||||
pass_hash VARCHAR(60) NOT NULL,
|
pass_hash VARCHAR(60) NOT NULL,
|
||||||
|
pronouns VARCHAR(15) DEFAULT NULL,
|
||||||
color VARCHAR(3) DEFAULT NULL
|
color VARCHAR(3) DEFAULT NULL
|
||||||
);
|
);
|
||||||
CREATE TABLE IF NOT EXISTS chat_groups (
|
CREATE TABLE IF NOT EXISTS chat_groups (
|
||||||
@@ -45,3 +45,24 @@ func InitDatabase(ctx context.Context) {
|
|||||||
|
|
||||||
dbConn = conn
|
dbConn = conn
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func SaveClientWithoutGroups(ctx context.Context, client *Client) error {
|
||||||
|
var id uint64
|
||||||
|
var err error
|
||||||
|
|
||||||
|
var hashed []byte
|
||||||
|
hashed, err = bcrypt.GenerateFromPassword([]byte(client.PasswordHash), bcrypt.DefaultCost)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
password := string(hashed)
|
||||||
|
|
||||||
|
c := string(client.Color[:])
|
||||||
|
err = dbConn.QueryRow(ctx, `
|
||||||
|
INSERT INTO users (name, pass_hash, pronouns, color)
|
||||||
|
VALUES ($1, $2, $3)
|
||||||
|
RETURNING id
|
||||||
|
`, client.Name, password, client.Pronouns, c).Scan(&id)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
type WSServerResponse uint8
|
type WsServerResponse uint8
|
||||||
|
|
||||||
const (
|
const (
|
||||||
BadMessage WSServerResponse = iota
|
BadMessage WsServerResponse = iota
|
||||||
InvalidCredentials
|
InvalidCredentials
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -16,4 +16,65 @@ func RegisterHandler(response http.ResponseWriter, request *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
ctx := request.Context()
|
ctx := request.Context()
|
||||||
|
|
||||||
|
username := request.FormValue("username")
|
||||||
|
if len(username) < 4 {
|
||||||
|
http.Error(response, "no or short username", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
password := request.FormValue("password")
|
||||||
|
if len(password) < 8 {
|
||||||
|
http.Error(response, "no or short password", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
newClient := Client{
|
||||||
|
Name: username,
|
||||||
|
PasswordHash: password,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := CreateClient(ctx, &newClient)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(response, "taken", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Write([]byte("registered"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func LoginHandler(response http.ResponseWriter, request *http.Request) {
|
||||||
|
if !isMethodAllowed(&response, request) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := request.Context()
|
||||||
|
|
||||||
|
username := request.FormValue("username")
|
||||||
|
if len(username) < 4 {
|
||||||
|
http.Error(response, "no or short username", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
password := request.FormValue("password")
|
||||||
|
if len(password) < 8 {
|
||||||
|
http.Error(response, "no or short password", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
id, err := GetIdFromClientName(ctx, username)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(response, "bad login", http.StatusBadRequest)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = CheckPassword(ctx, id, password)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(response, "bad login", http.StatusBadRequest)
|
||||||
|
}
|
||||||
|
|
||||||
|
token, err := GetToken(id)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(response, "Internal error", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Write([]byte(token))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,5 +1,19 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
func main() {
|
import (
|
||||||
|
"context"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
ctx := context.Background()
|
||||||
|
InitDatabase(ctx)
|
||||||
|
|
||||||
|
http.HandleFunc("/ws", ServeWsConnection)
|
||||||
|
http.HandleFunc("/register", RegisterHandler)
|
||||||
|
http.HandleFunc("/login", LoginHandler)
|
||||||
|
|
||||||
|
log.Println("listening on :8080")
|
||||||
|
log.Fatal(http.ListenAndServe(":8080", nil))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ package main
|
|||||||
import "github.com/coder/websocket"
|
import "github.com/coder/websocket"
|
||||||
|
|
||||||
type Client struct {
|
type Client struct {
|
||||||
Password string
|
PasswordHash string
|
||||||
Name string
|
Name string
|
||||||
Pronouns string
|
Pronouns string
|
||||||
Groups [12]*ChatGroup
|
Groups [12]*ChatGroup
|
||||||
|
|||||||
@@ -9,10 +9,10 @@ import (
|
|||||||
|
|
||||||
var secretKey = []byte("replace-with-env-variable")
|
var secretKey = []byte("replace-with-env-variable")
|
||||||
|
|
||||||
func GetToken(client *Client) (string, error) {
|
func GetToken(clientId uint32) (string, error) {
|
||||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256,
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256,
|
||||||
jwt.RegisteredClaims{
|
jwt.RegisteredClaims{
|
||||||
Subject: strconv.Itoa(int(client.Id)),
|
Subject: strconv.Itoa(int(clientId)),
|
||||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)),
|
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)),
|
||||||
IssuedAt: jwt.NewNumericDate(time.Now()),
|
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||||
},
|
},
|
||||||
|
|||||||
+11
-7
@@ -11,7 +11,7 @@ import (
|
|||||||
"github.com/coder/websocket/wsjson"
|
"github.com/coder/websocket/wsjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
func ServeConnection(responseWriter http.ResponseWriter, request *http.Request) {
|
func ServeWsConnection(responseWriter http.ResponseWriter, request *http.Request) {
|
||||||
connection, err := websocket.Accept(responseWriter, request, nil)
|
connection, err := websocket.Accept(responseWriter, request, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("websocket accept error: %v", err)
|
log.Printf("websocket accept error: %v", err)
|
||||||
@@ -33,7 +33,7 @@ func ServeConnection(responseWriter http.ResponseWriter, request *http.Request)
|
|||||||
|
|
||||||
if len(clientMessage) > 0 {
|
if len(clientMessage) > 0 {
|
||||||
if client.IsAuthenticated {
|
if client.IsAuthenticated {
|
||||||
handleAuthenticatedMessage()
|
handleAuthenticatedMessage(connection, &client, &clientMessage)
|
||||||
} else {
|
} else {
|
||||||
if !handleUnauthenticatedMessage(connection, &client, &clientMessage) {
|
if !handleUnauthenticatedMessage(connection, &client, &clientMessage) {
|
||||||
closeConnection(connection)
|
closeConnection(connection)
|
||||||
@@ -61,7 +61,7 @@ func handleUnauthenticatedMessage(conn *websocket.Conn, client *Client, message
|
|||||||
token, ok := (*message)["token"].(string)
|
token, ok := (*message)["token"].(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
var errmsg = map[string]any{
|
var errmsg = map[string]any{
|
||||||
"type": WSServerResponse(BadMessage),
|
"type": WsServerResponse(BadMessage),
|
||||||
"message": "token required",
|
"message": "token required",
|
||||||
}
|
}
|
||||||
sendMessageCloseIfTimeout(conn, &errmsg)
|
sendMessageCloseIfTimeout(conn, &errmsg)
|
||||||
@@ -71,7 +71,7 @@ func handleUnauthenticatedMessage(conn *websocket.Conn, client *Client, message
|
|||||||
clientId, err := GetClientIdFromToken(token)
|
clientId, err := GetClientIdFromToken(token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
var errmsg = map[string]any{
|
var errmsg = map[string]any{
|
||||||
"type": WSServerResponse(InvalidCredentials),
|
"type": WsServerResponse(InvalidCredentials),
|
||||||
"message": "bad token",
|
"message": "bad token",
|
||||||
}
|
}
|
||||||
sendMessageCloseIfTimeout(conn, &errmsg)
|
sendMessageCloseIfTimeout(conn, &errmsg)
|
||||||
@@ -81,7 +81,7 @@ func handleUnauthenticatedMessage(conn *websocket.Conn, client *Client, message
|
|||||||
client, err = GetClientFromId(clientId)
|
client, err = GetClientFromId(clientId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
var errmsg = map[string]any{
|
var errmsg = map[string]any{
|
||||||
"type": WSServerResponse(InvalidCredentials),
|
"type": WsServerResponse(InvalidCredentials),
|
||||||
"message": "bad token",
|
"message": "bad token",
|
||||||
}
|
}
|
||||||
sendMessageCloseIfTimeout(conn, &errmsg)
|
sendMessageCloseIfTimeout(conn, &errmsg)
|
||||||
@@ -92,8 +92,12 @@ func handleUnauthenticatedMessage(conn *websocket.Conn, client *Client, message
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleAuthenticatedMessage() {
|
func handleAuthenticatedMessage(conn *websocket.Conn, client *Client, message *map[string]any) {
|
||||||
|
for _, sendTo := range clients {
|
||||||
|
if sendTo.IsAuthenticated && sendTo.Id != client.Id {
|
||||||
|
sendMessageCloseIfTimeout(conn, message)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func closeConnection(conn *websocket.Conn) {
|
func closeConnection(conn *websocket.Conn) {
|
||||||
|
|||||||
Reference in New Issue
Block a user