main.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421
  1. import argparse
  2. import base64
  3. import io
  4. import json
  5. import queue
  6. import secrets
  7. import socket
  8. import sqlite3
  9. import sys
  10. import threading
  11. import zipfile
  12. from datetime import datetime
  13. from flask import (
  14. Flask,
  15. make_response,
  16. redirect,
  17. request,
  18. send_file,
  19. send_from_directory,
  20. )
  21. from waitress import serve
  # Runtime configuration. The __main__ section rebinds custom_address and
  # verbose from the command line.
  CHATNAME = "lainchat"
  PORT = 2004  # default port if none provided via CLI
  PRUNE_COUNT = 500  # the messages table is trimmed back to this many rows
  custom_address = None
  verbose = False

  # Every handler takes this lock around its chat.db connection, so only one
  # thread talks to SQLite at a time.
  db_lock = threading.Lock()

  # Files in the "public" folder are served straight from the URL root.
  app = Flask(__name__, static_folder="public", static_url_path="")
  def init_db():
      """Create the chat schema in ``chat.db`` if it does not exist yet.

      Tables:
        * ``messages`` -- message payloads as received from clients, keyed by
          an autoincrementing id (so id order is arrival order).
        * ``users``    -- username -> current session_id and last activity time.
        * ``admins``   -- admin username -> password (filled from ``--admin``).

      Safe to call repeatedly: every statement uses ``IF NOT EXISTS``.
      """
      schema = (
          """
          CREATE TABLE IF NOT EXISTS messages (
              id INTEGER PRIMARY KEY AUTOINCREMENT,
              payload TEXT
          )
          """,
          """
          CREATE TABLE IF NOT EXISTS users (
              username TEXT PRIMARY KEY,
              session_id TEXT,
              last_active TIMESTAMP
          )
          """,
          """
          CREATE TABLE IF NOT EXISTS admins (
              username TEXT PRIMARY KEY,
              password TEXT
          )
          """,
      )
      with db_lock:
          conn = sqlite3.connect("chat.db", check_same_thread=False)
          try:
              cur = conn.cursor()
              for statement in schema:
                  cur.execute(statement)
              conn.commit()
          finally:
              # Close even if a statement fails so the handle is not leaked.
              conn.close()
  # Make sure the schema exists before any request handler opens chat.db.
  init_db()

  # One queue per request currently long-polling GET /api/messages; a
  # broadcast puts the message on every queue in this list.
  clients = list()
  @app.route("/")
  @app.route("/<path:filename>")
  def static_files(filename="index.html"):
      """Serve files from ``public``, gating the chat page behind a session.

      * Any file other than ``index.html`` is served directly.
      * ``/?stream=1`` serves ``index.html`` with no login check.
      * ``/?u=<name>`` serves ``index.html`` only when the ``session_id``
        cookie matches the session stored for ``<name>``; it also refreshes
        that user's ``last_active``. Otherwise the client goes to the login page.
      * A bare ``/`` (no ``u``/``stream``) redirects to ``/login.html``.
      """
      if filename != "index.html":
          return send_from_directory("public", filename)
      if request.args.get("stream") == "1":
          return send_from_directory("public", filename)
      if "u" not in request.args:
          return redirect("/login.html")

      username = request.args.get("u")
      session_id = request.cookies.get("session_id")
      authorized = False
      with db_lock:
          conn = sqlite3.connect("chat.db", check_same_thread=False)
          try:
              c = conn.cursor()
              c.execute("SELECT session_id FROM users WHERE username = ?", (username,))
              row = c.fetchone()
              authorized = row is not None and row[0] == session_id
              if authorized:
                  c.execute(
                      "UPDATE users SET last_active = datetime('now') WHERE username = ?",
                      (username,),
                  )
                  conn.commit()
          finally:
              # Always release the connection, even if a query raises.
              conn.close()

      if not authorized:
          return redirect("/login.html")
      # A join and a reload look the same here, so no welcome is sent.
      log(f"{request.remote_addr} connected with username: {username}")
      return send_from_directory("public", filename)
  @app.route("/api/check_username", methods=["GET"])
  def check_username():
      """Tell the login page whether ``?u=<name>`` is an admin account.

      Responds ``{"status": "admin"}`` when the stripped name is in the
      ``admins`` table and ``{"status": "user"}`` otherwise.
      """
      name = request.args.get("u", "").strip()
      with db_lock:
          conn = sqlite3.connect("chat.db", check_same_thread=False)
          found = conn.execute(
              "SELECT 1 FROM admins WHERE username = ?", (name,)
          ).fetchone()
          conn.close()
      status = "user" if found is None else "admin"
      return {"status": status}
  @app.route("/api/login", methods=["POST"])
  def api_login():
      """Log a user in and issue a fresh ``session_id`` cookie.

      Expects a JSON object ``{"u": <username>, "password": <admins only>}``.

      * Users idle for more than an hour are purged first, freeing their names.
      * Names listed in ``admins`` require the matching password.
      * Any other name is reserved by whoever holds its current session; a
        request with a different ``session_id`` cookie is refused.

      On success the user row is (re)written with a new session id, the cookie
      is set, and an entry notice is broadcast to listening clients.

      Returns:
          ``{"success": True}`` on success; 400 for a missing/invalid username,
          401 for a wrong admin password, 403 for a name held by another session.
      """
      # silent=True: a non-JSON body is treated like an empty one (-> 400)
      # instead of crashing with an AttributeError on None.
      data = request.get_json(silent=True)
      if not isinstance(data, dict):
          data = {}
      username = data.get("u", "")
      password = data.get("password", "")
      if not isinstance(username, str) or not username.strip():
          return {"error": "Invalid username"}, 400
      username = username.strip()

      with db_lock:
          conn = sqlite3.connect("chat.db", check_same_thread=False)
          try:
              c = conn.cursor()
              c.execute("DELETE FROM users WHERE last_active < datetime('now', '-1 hour')")
              c.execute("SELECT password FROM admins WHERE username = ?", (username,))
              admin_row = c.fetchone()
              if admin_row is not None:
                  stored = admin_row[0]
                  # Constant-time comparison reduces timing side channels;
                  # surrogatepass keeps lone surrogates (legal in JSON) from raising.
                  password_ok = (
                      isinstance(stored, str)
                      and isinstance(password, str)
                      and secrets.compare_digest(
                          stored.encode("utf-8", "surrogatepass"),
                          password.encode("utf-8", "surrogatepass"),
                      )
                  )
                  if not password_ok:
                      return {"error": "Invalid password for admin"}, 401
              else:
                  session_id = request.cookies.get("session_id")
                  c.execute("SELECT session_id FROM users WHERE username = ?", (username,))
                  user_row = c.fetchone()
                  if user_row is not None and user_row[0] != session_id:
                      return {"error": "Username reserved by another user"}, 403
              new_session_id = secrets.token_hex(16)
              c.execute(
                  "INSERT OR REPLACE INTO users (username, session_id, last_active) VALUES (?, ?, datetime('now'))",
                  (username, new_session_id),
              )
              conn.commit()
          finally:
              # Closing without commit (error paths) rolls back the purge too.
              conn.close()

      resp = make_response({"success": True})
      # HttpOnly keeps page scripts (and any injected ones) from reading the
      # session token; SameSite=Lax keeps it off cross-site POSTs.
      resp.set_cookie("session_id", new_session_id, httponly=True, samesite="Lax")
      send_new_user_message(username)
      return resp
  # POST: receives a message from one client and broadcasts it to every
  # long-polling client (including the sender, if it is polling).
  @app.route("/api/messages", methods=["POST"])
  def post_message():
      """Accept a chat message and broadcast it to every long-polling client.

      The JSON body must be an object whose ``from`` is the sender's username
      (checked against the ``session_id`` cookie); ``text`` is optional. The
      raw request body is stored verbatim in ``messages`` (trimmed to the newest
      ``PRUNE_COUNT`` rows) and put on every queue in ``clients``.

      ``/kick <name>`` sent by an admin deletes ``<name>``'s session and
      broadcasts a system notice instead of storing a message; sent by anyone
      else it is treated as ordinary text.

      Returns:
          204 on success, 400 for a non-object body, 401 if the sender's
          session does not match.
      """
      data = request.json
      if not isinstance(data, dict):
          return {"error": "Invalid message"}, 400
      username = data.get("from")
      if not isinstance(username, str):
          # Non-string usernames can't match a row (and can't be bound by sqlite3).
          return {"error": "Unauthorized"}, 401
      text = data.get("text", "")
      if not isinstance(text, str):
          text = ""  # non-string text can never be a command
      session_id = request.cookies.get("session_id")

      def broadcast(message):
          """Put *message* on every registered queue; return how many."""
          # Snapshot: other request threads append/remove queues concurrently.
          listeners = clients[:]
          for q in listeners:
              try:
                  q.put(message)
              except Exception:
                  # Best effort: drop a queue that cannot take messages.
                  try:
                      clients.remove(q)
                  except ValueError:
                      pass  # already unregistered by another thread
          return len(listeners)

      kicked = None
      with db_lock:
          conn = sqlite3.connect("chat.db", check_same_thread=False)
          try:
              c = conn.cursor()
              c.execute("SELECT session_id FROM users WHERE username = ?", (username,))
              row = c.fetchone()
              if row is None or row[0] != session_id:
                  return {"error": "Unauthorized"}, 401
              c.execute(
                  "UPDATE users SET last_active = datetime('now') WHERE username = ?",
                  (username,),
              )
              if text.startswith("/kick "):
                  c.execute("SELECT 1 FROM admins WHERE username = ?", (username,))
                  if c.fetchone() is not None:
                      kicked = text.split(" ")[1].strip()
                      c.execute("DELETE FROM users WHERE username = ?", (kicked,))
              conn.commit()
          finally:
              conn.close()

      if kicked is not None:
          # json.dumps escapes quotes/backslashes in the name so the notice
          # stays valid JSON and cannot smuggle in extra fields.
          broadcast(
              json.dumps(
                  {"type": "system", "content": f"{kicked} was disconnected by an admin."}
              )
          )
          return "", 204

      payload = request.get_data(as_text=True)
      log(f"Message received by {request.remote_addr}: {payload}")
      with db_lock:
          conn = sqlite3.connect("chat.db", check_same_thread=False)
          try:
              c = conn.cursor()
              c.execute("INSERT INTO messages (payload) VALUES (?)", (payload,))
              # Drop the oldest rows once the table exceeds PRUNE_COUNT.
              c.execute("SELECT COUNT(*) FROM messages")
              count = c.fetchone()[0]
              if count > PRUNE_COUNT:
                  c.execute(
                      "DELETE FROM messages WHERE id IN (SELECT id FROM messages ORDER BY id ASC LIMIT ?)",
                      (count - PRUNE_COUNT,),
                  )
              conn.commit()
          finally:
              conn.close()

      delivered = broadcast(payload)
      log(f"Message from {request.remote_addr} forwarded to {delivered} listener(s)")
      return "", 204
  # GET: all clients listen here, with long-polling
  @app.route("/api/messages", methods=["GET"])
  def get_messages():
      """Long-poll for the next broadcast message.

      Unless ``?stream=1`` is given, the ``session_id`` cookie must belong to
      a logged-in user (401 otherwise). The request registers a fresh queue in
      ``clients`` and waits up to 30 seconds: it returns the first message put
      on that queue (200, as JSON) or 204 if none arrived, after which the
      client is expected to poll again.

      NOTE(review): only the first queued message is returned; anything else
      put on this queue before it is unregistered is dropped, and messages
      broadcast between two polls never reach this client. Fixing that needs
      a protocol change (e.g. returning a batch or resuming from a message id).
      """
      if request.args.get("stream") != "1":
          session_id = request.cookies.get("session_id")
          with db_lock:
              conn = sqlite3.connect("chat.db", check_same_thread=False)
              try:
                  known = conn.execute(
                      "SELECT 1 FROM users WHERE session_id = ?", (session_id,)
                  ).fetchone()
              finally:
                  conn.close()
          if known is None:
              return {"error": "Unauthorized"}, 401

      q = queue.Queue()
      clients.append(q)
      try:
          msg = q.get(timeout=30)
      except queue.Empty:
          return "", 204  # no message, client retries
      finally:
          # Unregister on every exit path; another thread may already have
          # dropped this queue, and a ValueError must not mask the response.
          try:
              clients.remove(q)
          except ValueError:
              pass
      # Messages are JSON text that may contain user-supplied markup; label
      # them as JSON so a browser never renders them as HTML.
      return msg, 200, {"Content-Type": "application/json"}
  @app.route("/api/download_images", methods=["GET"])
  def download_images():
      """Return a ZIP of every PNG doodle still in the stored message history.

      Requires a ``session_id`` cookie belonging to a logged-in user (401
      otherwise). Messages are read newest first, so ``0.png`` is the most
      recent doodle. Rows that are not JSON objects, have no ``doodle`` string
      starting with ``data:image/png;base64,``, or hold undecodable base64 are
      skipped.
      """
      prefix = "data:image/png;base64,"
      session_id = request.cookies.get("session_id")
      with db_lock:
          conn = sqlite3.connect("chat.db", check_same_thread=False)
          try:
              c = conn.cursor()
              c.execute("SELECT 1 FROM users WHERE session_id = ?", (session_id,))
              if c.fetchone() is None:
                  return {"error": "Unauthorized"}, 401
              c.execute("SELECT payload FROM messages ORDER BY id DESC")
              rows = c.fetchall()
          finally:
              conn.close()

      memory_file = io.BytesIO()
      with zipfile.ZipFile(memory_file, "w", zipfile.ZIP_DEFLATED) as zf:
          index = 0
          for (payload,) in rows:
              try:
                  data = json.loads(payload)
              except (TypeError, ValueError):
                  continue  # NULL or non-JSON payload
              doodle = data.get("doodle", "") if isinstance(data, dict) else ""
              if not isinstance(doodle, str) or not doodle.startswith(prefix):
                  continue
              try:
                  img_bytes = base64.b64decode(doodle.split(",")[1])
              except ValueError:  # binascii.Error is a ValueError subclass
                  continue
              zf.writestr(f"{index}.png", img_bytes)
              index += 1
      memory_file.seek(0)
      return send_file(
          memory_file,
          mimetype="application/zip",
          as_attachment=True,
          download_name=f"{CHATNAME}_archive.zip",
      )
  @app.route("/api/backlog", methods=["GET"])
  def get_backlog():
      """Return every stored message, oldest first, as a JSON array.

      Unless ``?stream=1`` is given, requires a ``session_id`` cookie that
      belongs to a logged-in user (401 otherwise). The stored payloads are
      spliced into the array verbatim.
      """
      is_stream = request.args.get("stream") == "1"
      session_id = request.cookies.get("session_id")
      with db_lock:
          conn = sqlite3.connect("chat.db", check_same_thread=False)
          try:
              c = conn.cursor()
              if not is_stream:
                  c.execute("SELECT 1 FROM users WHERE session_id = ?", (session_id,))
                  if c.fetchone() is None:
                      return {"error": "Unauthorized"}, 401
              c.execute("SELECT payload FROM messages ORDER BY id ASC")
              rows = c.fetchall()
          finally:
              conn.close()
      log(f"Backlog requested by {request.remote_addr}")
      body = "[" + ",".join(row[0] for row in rows) + "]"
      # Payloads may contain user-supplied markup; serving them as JSON (not
      # Flask's default text/html) stops a browser from rendering them.
      return body, 200, {"Content-Type": "application/json"}
  @app.route("/api/room/details", methods=["GET"])
  def get_room_details():
      """Describe where this chat server can be reached.

      Responds 200 with ``{"serverIP": custom_address, "port": PORT}``, read
      from the module-level values at request time.
      """
      log(f"Room details requested by {request.remote_addr}")
      details = dict(serverIP=custom_address, port=PORT)
      return details, 200
  def send_new_user_message(username):
      """Broadcast a system notice that *username* entered the room.

      The notice is ``{"type": "system", "content": "Now entering room: <name>"}``
      serialized with ``json.dumps``, so quotes, backslashes or control
      characters in the name cannot break the JSON or inject extra fields.
      It is put on every queue currently registered in ``clients``.
      """
      notice = json.dumps(
          {"type": "system", "content": f"Now entering room: {username}"}
      )
      # Iterate a snapshot: other threads may add/remove queues meanwhile.
      for q in clients[:]:
          try:
              q.put(notice)
          except Exception:
              # Best effort: drop a queue that cannot take messages.
              try:
                  clients.remove(q)
              except ValueError:
                  pass  # already unregistered by another thread
  def log(msg):
      """Print *msg* behind a local ``[YYYY-mm-dd HH:MM:SS]`` stamp when verbose."""
      if not verbose:
          return
      stamp = f"[{datetime.now():%Y-%m-%d %H:%M:%S}]"
      print(f"{stamp}: {msg}")
  if __name__ == "__main__":
      parser = argparse.ArgumentParser(description="run web server")
      parser.add_argument(
          "--port",
          "-p",
          type=int,
          help="port to listen on (default: %(default)s)",
          default=PORT,
      )
      parser.add_argument(
          "--server",
          "-s",
          action="store_true",
          help="run server in headless mode without opening a browser",
      )
      parser.add_argument(
          "--threads",
          "-t",
          type=int,
          help="number of threads to use (default: %(default)s)",
          default=32,
      )
      parser.add_argument(
          "--address",
          "-a",
          type=str,
          help="address displayed to users in browser",
          default="0.0.0.0",
      )
      parser.add_argument(
          "--verbose", "-v", action="store_true", help="enable verbose logging"
      )
      parser.add_argument(
          "--admin",
          "-A",
          action="append",
          help="Declare admin username:password",
          default=[],
      )
      args = parser.parse_args()

      # Validate before announcing anything. parser.error prints to stderr and
      # exits with status 2 whether or not --verbose is set.
      if not 1 <= args.port <= 65535:
          parser.error(f"port {args.port} is out of range (1-65535)")
      if args.threads < 1:
          parser.error(f"--threads must be at least 1 (got {args.threads})")

      # Rebind the module-level settings the request handlers read. PORT must
      # change too, or /api/room/details keeps reporting the default port.
      PORT = port = args.port
      custom_address = args.address
      verbose = args.verbose
      threads = args.threads
      # NOTE: nothing below opens a browser, so --server currently has no effect.
      open_browser = not args.server

      admins = []
      for admin in args.admin:
          if ":" not in admin:
              # Don't echo the value: it may be a mistyped password.
              print(
                  "warning: ignoring malformed --admin value (expected username:password)",
                  file=sys.stderr,
              )
              continue
          u, p = admin.split(":", 1)
          admins.append((u, p))
      if admins:
          with db_lock:
              conn = sqlite3.connect("chat.db", check_same_thread=False)
              try:
                  conn.executemany(
                      "INSERT OR REPLACE INTO admins (username, password) VALUES (?, ?)",
                      admins,
                  )
                  conn.commit()
              finally:
                  conn.close()

      print(f"\n{CHATNAME} Server running!")
      print(f" → Local: http://0.0.0.0:{port}")
      serve(app, host="0.0.0.0", port=port, threads=threads)