started adding token system
This commit is contained in:
+48
-10
@@ -6,6 +6,7 @@ from prompt_toolkit.key_binding import KeyBindings
|
|||||||
from prompt_toolkit.keys import Keys
|
from prompt_toolkit.keys import Keys
|
||||||
|
|
||||||
URI = "ws://localhost:8080/ws"
|
URI = "ws://localhost:8080/ws"
|
||||||
|
RETRY_DELAY = 2
|
||||||
|
|
||||||
bindings = KeyBindings()
|
bindings = KeyBindings()
|
||||||
|
|
||||||
@@ -19,6 +20,7 @@ def newline(event):
|
|||||||
|
|
||||||
send_queue: asyncio.Queue = None
|
send_queue: asyncio.Queue = None
|
||||||
loop: asyncio.AbstractEventLoop = None
|
loop: asyncio.AbstractEventLoop = None
|
||||||
|
shutdown_event: asyncio.Event = None
|
||||||
|
|
||||||
async def receiver(ws):
|
async def receiver(ws):
|
||||||
async for message in ws:
|
async for message in ws:
|
||||||
@@ -34,14 +36,47 @@ async def sender(ws):
|
|||||||
async def run():
|
async def run():
|
||||||
global send_queue
|
global send_queue
|
||||||
send_queue = asyncio.Queue()
|
send_queue = asyncio.Queue()
|
||||||
async with websockets.connect(URI) as ws:
|
|
||||||
print(f"Connected to {URI}")
|
input_thread = threading.Thread(target=input_loop, daemon=True)
|
||||||
print("Alt+Enter = newline | Enter = send | Ctrl+C = quit\n")
|
input_thread_started = False
|
||||||
recv_task = asyncio.create_task(receiver(ws))
|
|
||||||
send_task = asyncio.create_task(sender(ws))
|
while not shutdown_event.is_set():
|
||||||
input_thread = threading.Thread(target=input_loop, daemon=True)
|
try:
|
||||||
input_thread.start()
|
async with websockets.connect(URI) as ws:
|
||||||
await asyncio.gather(recv_task, send_task)
|
print(f"Connected to {URI}")
|
||||||
|
if not input_thread_started:
|
||||||
|
print("Alt+Enter = newline | Enter = send | Ctrl+C = quit\n")
|
||||||
|
input_thread.start()
|
||||||
|
input_thread_started = True
|
||||||
|
|
||||||
|
recv_task = asyncio.create_task(receiver(ws))
|
||||||
|
send_task = asyncio.create_task(sender(ws))
|
||||||
|
shutdown_task = asyncio.create_task(shutdown_event.wait())
|
||||||
|
|
||||||
|
done, pending = await asyncio.wait(
|
||||||
|
[recv_task, send_task, shutdown_task],
|
||||||
|
return_when=asyncio.FIRST_COMPLETED,
|
||||||
|
)
|
||||||
|
|
||||||
|
for t in pending:
|
||||||
|
t.cancel()
|
||||||
|
|
||||||
|
if shutdown_event.is_set():
|
||||||
|
return
|
||||||
|
|
||||||
|
# reconnect
|
||||||
|
send_queue = asyncio.Queue()
|
||||||
|
print(f"\n[Disconnected] Reconnecting in {RETRY_DELAY}s...")
|
||||||
|
|
||||||
|
except (OSError, websockets.exceptions.WebSocketException):
|
||||||
|
if shutdown_event.is_set():
|
||||||
|
return
|
||||||
|
print(f"[Waiting for server] Retrying in {RETRY_DELAY}s...", flush=True)
|
||||||
|
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(shutdown_event.wait(), timeout=RETRY_DELAY)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
pass
|
||||||
|
|
||||||
def input_loop():
|
def input_loop():
|
||||||
session = PromptSession(
|
session = PromptSession(
|
||||||
@@ -54,17 +89,20 @@ def input_loop():
|
|||||||
if text is not None:
|
if text is not None:
|
||||||
asyncio.run_coroutine_threadsafe(send_queue.put(text), loop)
|
asyncio.run_coroutine_threadsafe(send_queue.put(text), loop)
|
||||||
except (EOFError, KeyboardInterrupt):
|
except (EOFError, KeyboardInterrupt):
|
||||||
asyncio.run_coroutine_threadsafe(send_queue.put(None), loop)
|
asyncio.run_coroutine_threadsafe(shutdown_event.set(), loop)
|
||||||
break
|
break
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
global loop
|
global loop, shutdown_event
|
||||||
loop = asyncio.new_event_loop()
|
loop = asyncio.new_event_loop()
|
||||||
asyncio.set_event_loop(loop)
|
asyncio.set_event_loop(loop)
|
||||||
|
shutdown_event = asyncio.Event()
|
||||||
try:
|
try:
|
||||||
loop.run_until_complete(run())
|
loop.run_until_complete(run())
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
pass
|
pass
|
||||||
|
finally:
|
||||||
|
loop.close()
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|||||||
@@ -62,7 +62,7 @@ func main() {
|
|||||||
log.Printf("received: %v\n", msg)
|
log.Printf("received: %v\n", msg)
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
if err := wsjson.Write(ctx, conn, map[string]any{"echo": msg}); err != nil {
|
if err := wsjson.Write(ctx, conn, msg); err != nil {
|
||||||
log.Println("write error:", err)
|
log.Println("write error:", err)
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -0,0 +1,23 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
_ "context"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
|
_ "fmt"
|
||||||
|
_ "time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func hashToken(token string) string {
|
||||||
|
var hash = sha256.Sum256([]byte(token))
|
||||||
|
return hex.EncodeToString(hash[:])
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateToken() string {
|
||||||
|
var bytes = make([]byte, 32)
|
||||||
|
if _, err := rand.Read(bytes); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return hex.EncodeToString(bytes)
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user