main.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401
  1. import argparse
  2. import base64
  3. import io
  4. import json
  5. import queue
  6. import secrets
  7. import socket
  8. import sqlite3
  9. import sys
  10. import threading
  11. import zipfile
  12. from datetime import datetime
  13. from flask import (
  14. Flask,
  15. make_response,
  16. redirect,
  17. request,
  18. send_file,
  19. send_from_directory,
  20. )
  21. from waitress import serve
  22. # default port if none provided via CLI
  23. PORT = 2004
  24. custom_address = None
  25. verbose = False
  26. CHATNAME = "lainchat"
  27. app = Flask(__name__, static_folder="public", static_url_path="")
  28. db_lock = threading.Lock()
  29. def init_db():
  30. with db_lock:
  31. conn = sqlite3.connect("chat.db", check_same_thread=False)
  32. c = conn.cursor()
  33. c.execute("""
  34. CREATE TABLE IF NOT EXISTS messages (
  35. id INTEGER PRIMARY KEY AUTOINCREMENT,
  36. payload TEXT
  37. )
  38. """)
  39. c.execute("""
  40. CREATE TABLE IF NOT EXISTS users (
  41. username TEXT PRIMARY KEY,
  42. session_id TEXT,
  43. last_active TIMESTAMP
  44. )
  45. """)
  46. c.execute("""
  47. CREATE TABLE IF NOT EXISTS admins (
  48. username TEXT PRIMARY KEY,
  49. password TEXT
  50. )
  51. """)
  52. conn.commit()
  53. conn.close()
  54. init_db()
  55. # each client connection has its own queue
  56. clients = []
  57. @app.route("/")
  58. @app.route("/<path:filename>")
  59. def static_files(filename="index.html"):
  60. if "u" in request.args and filename == "index.html":
  61. username = request.args.get("u")
  62. session_id = request.cookies.get("session_id")
  63. with db_lock:
  64. conn = sqlite3.connect("chat.db", check_same_thread=False)
  65. c = conn.cursor()
  66. c.execute("SELECT session_id FROM users WHERE username = ?", (username,))
  67. row = c.fetchone()
  68. if row and row[0] == session_id:
  69. c.execute(
  70. "UPDATE users SET last_active = datetime('now') WHERE username = ?",
  71. (username,),
  72. )
  73. conn.commit()
  74. conn.close()
  75. log(f"{request.remote_addr} connected with username: {username}")
  76. # We can't easily know if they just joined or reloaded, but let's send a welcome
  77. # Actually, let's just let them load the page
  78. return send_from_directory("public", filename)
  79. else:
  80. conn.close()
  81. return redirect("/login.html")
  82. elif filename != "index.html":
  83. return send_from_directory("public", filename)
  84. else:
  85. return redirect("/login.html")
  86. @app.route("/api/check_username", methods=["GET"])
  87. def check_username():
  88. username = request.args.get("u", "").strip()
  89. with db_lock:
  90. conn = sqlite3.connect("chat.db", check_same_thread=False)
  91. c = conn.cursor()
  92. c.execute("SELECT 1 FROM admins WHERE username = ?", (username,))
  93. is_admin = c.fetchone() is not None
  94. conn.close()
  95. return {"status": "admin" if is_admin else "user"}
  96. @app.route("/api/login", methods=["POST"])
  97. def api_login():
  98. data = request.json
  99. username = data.get("u", "").strip()
  100. password = data.get("password", "")
  101. if not username:
  102. return {"error": "Invalid username"}, 400
  103. with db_lock:
  104. conn = sqlite3.connect("chat.db", check_same_thread=False)
  105. c = conn.cursor()
  106. c.execute("DELETE FROM users WHERE last_active < datetime('now', '-1 hour')")
  107. c.execute("SELECT password FROM admins WHERE username = ?", (username,))
  108. admin_row = c.fetchone()
  109. if admin_row:
  110. if admin_row[0] != password:
  111. conn.close()
  112. return {"error": "Invalid password for admin"}, 401
  113. else:
  114. session_id = request.cookies.get("session_id")
  115. c.execute("SELECT session_id FROM users WHERE username = ?", (username,))
  116. user_row = c.fetchone()
  117. if user_row and user_row[0] != session_id:
  118. conn.close()
  119. return {"error": "Username reserved by another user"}, 403
  120. new_session_id = secrets.token_hex(16)
  121. c.execute(
  122. "INSERT OR REPLACE INTO users (username, session_id, last_active) VALUES (?, ?, datetime('now'))",
  123. (username, new_session_id),
  124. )
  125. conn.commit()
  126. conn.close()
  127. resp = make_response({"success": True})
  128. resp.set_cookie("session_id", new_session_id)
  129. send_new_user_message(username)
  130. return resp
  131. # POST: receives a message from one client and forwards it to all other connections
  132. @app.route("/api/messages", methods=["POST"])
  133. def post_message():
  134. data = request.json
  135. username = data.get("from")
  136. text = data.get("text", "")
  137. session_id = request.cookies.get("session_id")
  138. with db_lock:
  139. conn = sqlite3.connect("chat.db", check_same_thread=False)
  140. c = conn.cursor()
  141. c.execute("SELECT session_id FROM users WHERE username = ?", (username,))
  142. row = c.fetchone()
  143. if not row or row[0] != session_id:
  144. conn.close()
  145. return {"error": "Unauthorized"}, 401
  146. c.execute(
  147. "UPDATE users SET last_active = datetime('now') WHERE username = ?",
  148. (username,),
  149. )
  150. if text.startswith("/kick "):
  151. target = text.split(" ")[1].strip()
  152. c.execute("SELECT 1 FROM admins WHERE username = ?", (username,))
  153. if c.fetchone():
  154. c.execute("DELETE FROM users WHERE username = ?", (target,))
  155. conn.commit()
  156. disconnect_msg = f'{{"type": "system", "content": "{target} was disconnected by an admin."}}'
  157. for q in clients[:]:
  158. try:
  159. q.put(disconnect_msg)
  160. except:
  161. pass
  162. conn.close()
  163. return "", 204
  164. conn.commit()
  165. conn.close()
  166. payload = request.get_data(as_text=True)
  167. log(f"Message received by {request.remote_addr}: {payload}")
  168. # Save to sqlite
  169. with db_lock:
  170. conn = sqlite3.connect("chat.db", check_same_thread=False)
  171. c = conn.cursor()
  172. c.execute("INSERT INTO messages (payload) VALUES (?)", (payload,))
  173. # Prune if over 500
  174. c.execute("SELECT COUNT(*) FROM messages")
  175. count = c.fetchone()[0]
  176. if count > 500:
  177. excess = count - 500
  178. c.execute(
  179. "DELETE FROM messages WHERE id IN (SELECT id FROM messages ORDER BY id ASC LIMIT ?)",
  180. (excess,),
  181. )
  182. conn.commit()
  183. conn.close()
  184. for q in clients[:]:
  185. try:
  186. q.put(payload)
  187. log(
  188. f"Message from {request.remote_addr} forwarded to {q.qsize()} listener(s)"
  189. )
  190. except:
  191. clients.remove(q)
  192. return "", 204
  193. # GET: all clients listen here, with long-polling
  194. @app.route("/api/messages", methods=["GET"])
  195. def get_messages():
  196. session_id = request.cookies.get("session_id")
  197. with db_lock:
  198. conn = sqlite3.connect("chat.db", check_same_thread=False)
  199. c = conn.cursor()
  200. c.execute("SELECT 1 FROM users WHERE session_id = ?", (session_id,))
  201. if not c.fetchone():
  202. conn.close()
  203. return {"error": "Unauthorized"}, 401
  204. conn.close()
  205. q = queue.Queue()
  206. clients.append(q)
  207. try:
  208. # wait up to 30 seconds for a message
  209. msg = q.get(timeout=30)
  210. return msg, 200
  211. except queue.Empty:
  212. return "", 204 # no message, client retries
  213. finally:
  214. clients.remove(q) # clean up client queue on disconnect
  215. @app.route("/api/download_images", methods=["GET"])
  216. def download_images():
  217. session_id = request.cookies.get("session_id")
  218. with db_lock:
  219. conn = sqlite3.connect("chat.db", check_same_thread=False)
  220. c = conn.cursor()
  221. c.execute("SELECT 1 FROM users WHERE session_id = ?", (session_id,))
  222. if not c.fetchone():
  223. conn.close()
  224. return {"error": "Unauthorized"}, 401
  225. c.execute("SELECT payload FROM messages ORDER BY id DESC")
  226. rows = c.fetchall()
  227. conn.close()
  228. memory_file = io.BytesIO()
  229. with zipfile.ZipFile(memory_file, "w", zipfile.ZIP_DEFLATED) as zf:
  230. index = 0
  231. for row in rows:
  232. try:
  233. data = json.loads(row[0])
  234. doodle = data.get("doodle", "")
  235. if doodle.startswith("data:image/png;base64,"):
  236. b64_data = doodle.split(",")[1]
  237. img_bytes = base64.b64decode(b64_data)
  238. zf.writestr(f"{index}.png", img_bytes)
  239. index += 1
  240. except Exception:
  241. pass
  242. memory_file.seek(0)
  243. return send_file(
  244. memory_file,
  245. mimetype="application/zip",
  246. as_attachment=True,
  247. download_name=f"{CHATNAME}_archive.zip",
  248. )
  249. @app.route("/api/backlog", methods=["GET"])
  250. def get_backlog():
  251. log(f"Backlog requested by {request.remote_addr}")
  252. with db_lock:
  253. conn = sqlite3.connect("chat.db", check_same_thread=False)
  254. c = conn.cursor()
  255. c.execute("SELECT payload FROM messages ORDER BY id ASC")
  256. rows = c.fetchall()
  257. conn.close()
  258. # return a JSON array of JSON strings
  259. return "[" + ",".join(row[0] for row in rows) + "]", 200
  260. @app.route("/api/room/details", methods=["GET"])
  261. def get_room_details():
  262. log(f"Room details requested by {request.remote_addr}")
  263. return {
  264. "serverIP": custom_address,
  265. "port": PORT,
  266. }, 200
  267. def send_new_user_message(username):
  268. welcome_message = (
  269. f'{{"type": "system", "content": "Now entering room: {username}"}}'
  270. )
  271. for q in clients[:]:
  272. try:
  273. q.put(welcome_message)
  274. except:
  275. clients.remove(q)
  276. def log(msg):
  277. if verbose:
  278. timestamp = "[{:%Y-%m-%d %H:%M:%S}]".format(datetime.now())
  279. print(f"{timestamp}: {msg}")
  280. if __name__ == "__main__":
  281. parser = argparse.ArgumentParser(description="run web server")
  282. parser.add_argument(
  283. "--port",
  284. "-p",
  285. type=int,
  286. help="port to listen on (default: %(default)s)",
  287. default=PORT,
  288. )
  289. parser.add_argument(
  290. "--server",
  291. "-s",
  292. action="store_true",
  293. help="run server in headless mode without opening a browser",
  294. )
  295. parser.add_argument(
  296. "--threads",
  297. "-t",
  298. type=int,
  299. help="number of threads to use (default: %(default)s)",
  300. default=32,
  301. )
  302. parser.add_argument(
  303. "--address",
  304. "-a",
  305. type=str,
  306. help="address displayed to users in browser",
  307. default="0.0.0.0",
  308. )
  309. parser.add_argument(
  310. "--verbose", "-v", action="store_true", help="enable verbose logging"
  311. )
  312. parser.add_argument(
  313. "--admin",
  314. "-A",
  315. action="append",
  316. help="Declare admin username:password",
  317. default=[],
  318. )
  319. args = parser.parse_args()
  320. port = args.port or PORT
  321. open_browser = not args.server
  322. threads = args.threads
  323. custom_address = args.address
  324. verbose = args.verbose
  325. for admin in args.admin:
  326. if ":" in admin:
  327. u, p = admin.split(":", 1)
  328. with db_lock:
  329. conn = sqlite3.connect("chat.db", check_same_thread=False)
  330. c = conn.cursor()
  331. c.execute(
  332. "INSERT OR REPLACE INTO admins (username, password) VALUES (?, ?)",
  333. (u, p),
  334. )
  335. conn.commit()
  336. conn.close()
  337. print(f"\n{CHATNAME} Server running!")
  338. print(f" → Local: http://0.0.0.0:{port}")
  339. if not (1 <= port <= 65535):
  340. log(f"Error: port {port} is out of range (1-65535)")
  341. sys.exit(2)
  342. serve(app, host="0.0.0.0", port=port, threads=threads)