| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401 |
- import argparse
- import base64
- import io
- import json
- import queue
- import secrets
- import socket
- import sqlite3
- import sys
- import threading
- import zipfile
- from datetime import datetime
- from flask import (
- Flask,
- make_response,
- redirect,
- request,
- send_file,
- send_from_directory,
- )
- from waitress import serve
# Fallback listening port, used when the command line does not supply one.
PORT = 2004
# Chat name shown in the startup banner and used in the archive filename.
CHATNAME = "lainchat"

# Both values are overwritten from the command line when run as a script.
custom_address = None
verbose = False

# Static assets live in ./public and are exposed at the URL root.
app = Flask(__name__, static_folder="public", static_url_path="")

# Every SQLite access in this file is performed while holding this lock.
db_lock = threading.Lock()
def init_db():
    """Create the ``messages``, ``users`` and ``admins`` tables if missing.

    Opens ``chat.db`` under ``db_lock``, runs the three
    ``CREATE TABLE IF NOT EXISTS`` statements, commits, and always closes
    the connection, even when one of the statements raises.
    """
    schema = (
        """
        CREATE TABLE IF NOT EXISTS messages (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            payload TEXT
        )
        """,
        """
        CREATE TABLE IF NOT EXISTS users (
            username TEXT PRIMARY KEY,
            session_id TEXT,
            last_active TIMESTAMP
        )
        """,
        """
        CREATE TABLE IF NOT EXISTS admins (
            username TEXT PRIMARY KEY,
            password TEXT
        )
        """,
    )
    with db_lock:
        conn = sqlite3.connect("chat.db", check_same_thread=False)
        try:
            cursor = conn.cursor()
            for statement in schema:
                cursor.execute(statement)
            conn.commit()
        finally:
            # Close on the error path too, so the handle is never leaked.
            conn.close()


# Make sure the schema exists before any request handler runs.
init_db()

# each client connection has its own queue
clients = []
@app.route("/")
@app.route("/<path:filename>")
def static_files(filename="index.html"):
    """Serve a file from ``public``, gating ``index.html`` behind a session.

    Any path other than ``index.html`` is served directly. ``index.html``
    is served only when the ``u`` query parameter names a user whose stored
    session id equals the ``session_id`` cookie; that user's ``last_active``
    is refreshed first. Every other request for it is redirected to
    ``/login.html``.
    """
    # Plain static assets need no session check.
    if filename != "index.html":
        return send_from_directory("public", filename)

    # The chat page is only reachable with a username in the query string.
    if "u" not in request.args:
        return redirect("/login.html")

    username = request.args.get("u")
    cookie_session = request.cookies.get("session_id")
    with db_lock:
        conn = sqlite3.connect("chat.db", check_same_thread=False)
        cursor = conn.cursor()
        cursor.execute("SELECT session_id FROM users WHERE username = ?", (username,))
        stored = cursor.fetchone()
        if not stored or stored[0] != cookie_session:
            conn.close()
            return redirect("/login.html")

        # Session matches: record the activity and hand out the page.
        cursor.execute(
            "UPDATE users SET last_active = datetime('now') WHERE username = ?",
            (username,),
        )
        conn.commit()
        conn.close()
        log(f"{request.remote_addr} connected with username: {username}")
        return send_from_directory("public", filename)
@app.route("/api/check_username", methods=["GET"])
def check_username():
    """Report whether the ``u`` query parameter names a registered admin.

    Returns ``{"status": "admin"}`` when the stripped name has a row in the
    ``admins`` table and ``{"status": "user"}`` otherwise.
    """
    candidate = request.args.get("u", "").strip()
    with db_lock:
        conn = sqlite3.connect("chat.db", check_same_thread=False)
        cursor = conn.cursor()
        cursor.execute("SELECT 1 FROM admins WHERE username = ?", (candidate,))
        match = cursor.fetchone()
        conn.close()
    role = "user" if match is None else "admin"
    return {"status": role}
def _password_matches(stored, supplied):
    """Compare two passwords in constant time; non-strings never match."""
    if not isinstance(stored, str) or not isinstance(supplied, str):
        return False
    # compare_digest only accepts ASCII str, so compare the encoded bytes.
    # "surrogatepass" keeps lone surrogates from JSON escapes from raising.
    return secrets.compare_digest(
        stored.encode("utf-8", "surrogatepass"),
        supplied.encode("utf-8", "surrogatepass"),
    )


@app.route("/api/login", methods=["POST"])
def api_login():
    """Log a user in and hand out a fresh session cookie.

    Expects a JSON object with ``u`` (username) and, for admin names,
    ``password``. Users inactive for more than an hour are purged first.
    An admin name requires the matching password; any other name is
    rejected when it is currently held by a different session.

    Returns:
        ``{"success": True}`` with a ``session_id`` cookie on success,
        400 for a missing or malformed username, 401 for a wrong admin
        password, 403 when the name is reserved by another session.
    """
    data = request.json
    # A JSON array/number/null body or a non-string name is a client error.
    if not isinstance(data, dict):
        return {"error": "Invalid username"}, 400
    username = data.get("u", "")
    if not isinstance(username, str):
        return {"error": "Invalid username"}, 400
    username = username.strip()
    password = data.get("password", "")
    if not username:
        return {"error": "Invalid username"}, 400

    with db_lock:
        conn = sqlite3.connect("chat.db", check_same_thread=False)
        try:
            cursor = conn.cursor()
            cursor.execute(
                "DELETE FROM users WHERE last_active < datetime('now', '-1 hour')"
            )
            cursor.execute(
                "SELECT password FROM admins WHERE username = ?", (username,)
            )
            admin_row = cursor.fetchone()
            if admin_row:
                if not _password_matches(admin_row[0], password):
                    return {"error": "Invalid password for admin"}, 401
            else:
                session_id = request.cookies.get("session_id")
                cursor.execute(
                    "SELECT session_id FROM users WHERE username = ?", (username,)
                )
                user_row = cursor.fetchone()
                if user_row and user_row[0] != session_id:
                    return {"error": "Username reserved by another user"}, 403

            new_session_id = secrets.token_hex(16)
            cursor.execute(
                "INSERT OR REPLACE INTO users (username, session_id, last_active) VALUES (?, ?, datetime('now'))",
                (username, new_session_id),
            )
            conn.commit()
        finally:
            # Runs on the rejection paths as well; nothing is committed there.
            conn.close()

    resp = make_response({"success": True})
    resp.set_cookie("session_id", new_session_id)
    send_new_user_message(username)
    return resp
def _broadcast(message):
    """Put *message* on every registered client queue; return the count."""
    delivered = 0
    # Iterate over a snapshot: request handlers add and remove queues.
    for client_queue in clients[:]:
        try:
            client_queue.put(message)
            delivered += 1
        except Exception:
            # Best effort: forget a queue that cannot take the message.
            try:
                clients.remove(client_queue)
            except ValueError:
                pass  # already removed elsewhere
    return delivered


# POST: receives a message from one client and forwards it to all other connections
@app.route("/api/messages", methods=["POST"])
def post_message():
    """Accept a chat message, store it, and fan it out to waiting clients.

    The JSON body must be an object whose ``from`` names a user whose
    stored session id equals the ``session_id`` cookie. The sender's
    ``last_active`` is refreshed on every authorised request.

    Text starting with ``/kick `` is treated as a command and is never
    stored or broadcast: when the sender is an admin and a target is given,
    the target's user row is deleted and a system notice is broadcast.

    Any other message is stored verbatim (raw request body), the table is
    pruned to the newest 500 rows, and the body is broadcast.

    Returns:
        ``("", 204)`` on success, 400 for a malformed body, 401 when the
        sender/session pair is not recognised.
    """
    data = request.json
    if not isinstance(data, dict):
        return {"error": "Invalid message"}, 400
    username = data.get("from")
    text = data.get("text", "")
    if not isinstance(text, str):
        return {"error": "Invalid message"}, 400
    if not isinstance(username, str):
        # Nothing but a string can match a stored username.
        return {"error": "Unauthorized"}, 401
    session_id = request.cookies.get("session_id")

    is_kick_command = text.startswith("/kick ")
    kicked = None
    with db_lock:
        conn = sqlite3.connect("chat.db", check_same_thread=False)
        try:
            cursor = conn.cursor()
            cursor.execute(
                "SELECT session_id FROM users WHERE username = ?", (username,)
            )
            row = cursor.fetchone()
            if not row or row[0] != session_id:
                return {"error": "Unauthorized"}, 401
            cursor.execute(
                "UPDATE users SET last_active = datetime('now') WHERE username = ?",
                (username,),
            )
            if is_kick_command:
                # The prefix check guarantees at least two space-separated parts.
                target = text.split(" ")[1].strip()
                cursor.execute(
                    "SELECT 1 FROM admins WHERE username = ?", (username,)
                )
                if cursor.fetchone() and target:
                    cursor.execute(
                        "DELETE FROM users WHERE username = ?", (target,)
                    )
                    kicked = target
            # Commit on every authorised path so the activity refresh is kept.
            conn.commit()
        finally:
            conn.close()

    if is_kick_command:
        if kicked is not None:
            # json.dumps escapes quotes/backslashes in the user-supplied name.
            _broadcast(
                json.dumps(
                    {
                        "type": "system",
                        "content": f"{kicked} was disconnected by an admin.",
                    }
                )
            )
        return "", 204

    payload = request.get_data(as_text=True)
    log(f"Message received by {request.remote_addr}: {payload}")

    # Save to sqlite, keeping only the newest 500 messages.
    with db_lock:
        conn = sqlite3.connect("chat.db", check_same_thread=False)
        try:
            cursor = conn.cursor()
            cursor.execute("INSERT INTO messages (payload) VALUES (?)", (payload,))
            cursor.execute("SELECT COUNT(*) FROM messages")
            excess = cursor.fetchone()[0] - 500
            if excess > 0:
                cursor.execute(
                    "DELETE FROM messages WHERE id IN (SELECT id FROM messages ORDER BY id ASC LIMIT ?)",
                    (excess,),
                )
            conn.commit()
        finally:
            conn.close()

    delivered = _broadcast(payload)
    log(f"Message from {request.remote_addr} forwarded to {delivered} listener(s)")
    return "", 204
# GET: all clients listen here, with long-polling
@app.route("/api/messages", methods=["GET"])
def get_messages():
    """Long-poll for the next broadcast message.

    The ``session_id`` cookie must match a row in ``users``; otherwise a
    401 is returned. An authorised caller gets its own queue registered in
    ``clients`` and waits up to 30 seconds: the first message put on the
    queue is returned with 200, a timeout yields an empty 204. The queue is
    unregistered in both cases.
    """
    cookie_session = request.cookies.get("session_id")
    with db_lock:
        conn = sqlite3.connect("chat.db", check_same_thread=False)
        cursor = conn.cursor()
        cursor.execute("SELECT 1 FROM users WHERE session_id = ?", (cookie_session,))
        known_session = cursor.fetchone()
        conn.close()
    if not known_session:
        return {"error": "Unauthorized"}, 401

    inbox = queue.Queue()
    clients.append(inbox)
    try:
        # block for at most 30 seconds waiting for a broadcast
        return inbox.get(timeout=30), 200
    except queue.Empty:
        # nothing arrived; the client is expected to poll again
        return "", 204
    finally:
        # always unregister this request's queue
        clients.remove(inbox)
def _decode_doodle(payload):
    """Return the PNG bytes embedded in a stored message, or ``None``.

    A message qualifies when it is a JSON object whose ``doodle`` value is
    a ``data:image/png;base64,`` URI. Anything that cannot be parsed or
    decoded is reported as ``None`` rather than raising.
    """
    try:
        doodle = json.loads(payload).get("doodle", "")
        if not doodle.startswith("data:image/png;base64,"):
            return None
        return base64.b64decode(doodle.split(",")[1])
    except (ValueError, TypeError, AttributeError):
        # ValueError covers bad JSON and bad base64; the other two cover
        # payloads that are not objects or doodles that are not strings.
        return None


@app.route("/api/download_images", methods=["GET"])
def download_images():
    """Send a ZIP archive of every doodle found in the stored messages.

    Requires a ``session_id`` cookie matching a row in ``users`` (401
    otherwise). Messages are read newest first and each decodable doodle is
    written as ``<n>.png``, numbered from 0 in that order. Messages without
    a usable doodle are skipped.
    """
    session_id = request.cookies.get("session_id")
    with db_lock:
        conn = sqlite3.connect("chat.db", check_same_thread=False)
        try:
            cursor = conn.cursor()
            cursor.execute("SELECT 1 FROM users WHERE session_id = ?", (session_id,))
            if not cursor.fetchone():
                return {"error": "Unauthorized"}, 401
            cursor.execute("SELECT payload FROM messages ORDER BY id DESC")
            rows = cursor.fetchall()
        finally:
            conn.close()

    memory_file = io.BytesIO()
    with zipfile.ZipFile(memory_file, "w", zipfile.ZIP_DEFLATED) as archive:
        index = 0
        for (payload,) in rows:
            image = _decode_doodle(payload)
            if image is None:
                continue
            archive.writestr(f"{index}.png", image)
            index += 1

    # Rewind so send_file streams the archive from its first byte.
    memory_file.seek(0)
    return send_file(
        memory_file,
        mimetype="application/zip",
        as_attachment=True,
        download_name=f"{CHATNAME}_archive.zip",
    )
@app.route("/api/backlog", methods=["GET"])
def get_backlog():
    """Return every stored message, oldest first, as one JSON array.

    Each stored payload is already JSON text, so the array is assembled by
    joining the payloads with commas. The response is labelled
    ``application/json`` so browsers do not interpret user-supplied message
    content as HTML.
    """
    log(f"Backlog requested by {request.remote_addr}")
    with db_lock:
        conn = sqlite3.connect("chat.db", check_same_thread=False)
        try:
            cursor = conn.cursor()
            cursor.execute("SELECT payload FROM messages ORDER BY id ASC")
            rows = cursor.fetchall()
        finally:
            conn.close()

    # return a JSON array built from the stored JSON strings
    body = "[" + ",".join(row[0] for row in rows) + "]"
    resp = make_response(body, 200)
    resp.mimetype = "application/json"
    return resp
@app.route("/api/room/details", methods=["GET"])
def get_room_details():
    """Return the module-level ``custom_address`` and ``PORT`` as JSON."""
    log(f"Room details requested by {request.remote_addr}")
    details = {"serverIP": custom_address, "port": PORT}
    return details, 200
def send_new_user_message(username):
    """Broadcast a system notice that *username* has entered the room.

    The notice is a JSON object with ``type`` ``"system"`` and a
    ``content`` string. It is serialised with ``json.dumps`` so quotes or
    backslashes in the name cannot break or alter the message.
    """
    welcome_message = json.dumps(
        {"type": "system", "content": f"Now entering room: {username}"}
    )
    # Iterate over a snapshot: request handlers add and remove queues.
    for client_queue in clients[:]:
        try:
            client_queue.put(welcome_message)
        except Exception:
            # Best effort: forget a queue that cannot take the message.
            try:
                clients.remove(client_queue)
            except ValueError:
                pass  # already removed elsewhere
def log(msg):
    """Print *msg* with a ``[YYYY-MM-DD HH:MM:SS]`` prefix when ``verbose`` is set."""
    if not verbose:
        return
    stamp = datetime.now().strftime("[%Y-%m-%d %H:%M:%S]")
    print(f"{stamp}: {msg}")
if __name__ == "__main__":
    # Command-line entry point: parse options, register admins, start waitress.
    parser = argparse.ArgumentParser(description="run web server")
    parser.add_argument(
        "--port",
        "-p",
        type=int,
        help="port to listen on (default: %(default)s)",
        default=PORT,
    )
    parser.add_argument(
        "--server",
        "-s",
        action="store_true",
        help="run server in headless mode without opening a browser",
    )
    parser.add_argument(
        "--threads",
        "-t",
        type=int,
        help="number of threads to use (default: %(default)s)",
        default=32,
    )
    parser.add_argument(
        "--address",
        "-a",
        type=str,
        help="address displayed to users in browser",
        default="0.0.0.0",
    )
    parser.add_argument(
        "--verbose", "-v", action="store_true", help="enable verbose logging"
    )
    parser.add_argument(
        "--admin",
        "-A",
        action="append",
        help="Declare admin username:password",
        default=[],
    )
    args = parser.parse_args()

    port = args.port or PORT
    # Validate before touching the database or announcing startup.
    # parser.error prints to stderr and exits with status 2, so the message
    # is visible even without --verbose.
    if not (1 <= port <= 65535):
        parser.error(f"port {port} is out of range (1-65535)")
    # Keep the module-level value in sync so /api/room/details reports the
    # port the server is actually listening on.
    PORT = port

    # NOTE(review): open_browser is assigned but not read anywhere in the
    # visible code.
    open_browser = not args.server
    threads = args.threads
    custom_address = args.address
    verbose = args.verbose

    admin_credentials = []
    for admin in args.admin:
        if ":" not in admin:
            # Do not echo the value: it may contain a password.
            print(
                "Ignoring malformed --admin entry (expected username:password)",
                file=sys.stderr,
            )
            continue
        admin_credentials.append(tuple(admin.split(":", 1)))

    if admin_credentials:
        with db_lock:
            conn = sqlite3.connect("chat.db", check_same_thread=False)
            try:
                conn.executemany(
                    "INSERT OR REPLACE INTO admins (username, password) VALUES (?, ?)",
                    admin_credentials,
                )
                conn.commit()
            finally:
                conn.close()

    print(f"\n{CHATNAME} Server running!")
    print(f" → Local: http://0.0.0.0:{port}")
    serve(app, host="0.0.0.0", port=port, threads=threads)
|