main.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420
import argparse
import base64
import io
import json
import queue
import secrets
import socket
import sqlite3
import sys
import threading
import zipfile
from datetime import datetime

from flask import (
    Flask,
    make_response,
    redirect,
    request,
    send_file,
    send_from_directory,
)
from waitress import serve
  22. # default port if none provided via CLI
  23. PORT = 2004
  24. custom_address = None
  25. verbose = False
  26. CHATNAME = "lainchat"
  27. app = Flask(__name__, static_folder="public", static_url_path="")
  28. db_lock = threading.Lock()
  29. def init_db():
  30. with db_lock:
  31. conn = sqlite3.connect("chat.db", check_same_thread=False)
  32. c = conn.cursor()
  33. c.execute("""
  34. CREATE TABLE IF NOT EXISTS messages (
  35. id INTEGER PRIMARY KEY AUTOINCREMENT,
  36. payload TEXT
  37. )
  38. """)
  39. c.execute("""
  40. CREATE TABLE IF NOT EXISTS users (
  41. username TEXT PRIMARY KEY,
  42. session_id TEXT,
  43. last_active TIMESTAMP
  44. )
  45. """)
  46. c.execute("""
  47. CREATE TABLE IF NOT EXISTS admins (
  48. username TEXT PRIMARY KEY,
  49. password TEXT
  50. )
  51. """)
  52. conn.commit()
  53. conn.close()
  54. init_db()
  55. # each client connection has its own queue
  56. clients = []
  57. @app.route("/")
  58. @app.route("/<path:filename>")
  59. def static_files(filename="index.html"):
  60. if request.args.get("stream") == "1" and filename == "index.html":
  61. return send_from_directory("public", filename)
  62. if "u" in request.args and filename == "index.html":
  63. username = request.args.get("u")
  64. session_id = request.cookies.get("session_id")
  65. with db_lock:
  66. conn = sqlite3.connect("chat.db", check_same_thread=False)
  67. c = conn.cursor()
  68. c.execute("SELECT session_id FROM users WHERE username = ?", (username,))
  69. row = c.fetchone()
  70. if row and row[0] == session_id:
  71. c.execute(
  72. "UPDATE users SET last_active = datetime('now') WHERE username = ?",
  73. (username,),
  74. )
  75. conn.commit()
  76. conn.close()
  77. log(f"{request.remote_addr} connected with username: {username}")
  78. # We can't easily know if they just joined or reloaded, but let's send a welcome
  79. # Actually, let's just let them load the page
  80. return send_from_directory("public", filename)
  81. else:
  82. conn.close()
  83. return redirect("/login.html")
  84. elif filename != "index.html":
  85. return send_from_directory("public", filename)
  86. else:
  87. return redirect("/login.html")
  88. @app.route("/api/check_username", methods=["GET"])
  89. def check_username():
  90. username = request.args.get("u", "").strip()
  91. with db_lock:
  92. conn = sqlite3.connect("chat.db", check_same_thread=False)
  93. c = conn.cursor()
  94. c.execute("SELECT 1 FROM admins WHERE username = ?", (username,))
  95. is_admin = c.fetchone() is not None
  96. conn.close()
  97. return {"status": "admin" if is_admin else "user"}
  98. @app.route("/api/login", methods=["POST"])
  99. def api_login():
  100. data = request.json
  101. username = data.get("u", "").strip()
  102. password = data.get("password", "")
  103. if not username:
  104. return {"error": "Invalid username"}, 400
  105. with db_lock:
  106. conn = sqlite3.connect("chat.db", check_same_thread=False)
  107. c = conn.cursor()
  108. c.execute("DELETE FROM users WHERE last_active < datetime('now', '-1 hour')")
  109. c.execute("SELECT password FROM admins WHERE username = ?", (username,))
  110. admin_row = c.fetchone()
  111. if admin_row:
  112. if admin_row[0] != password:
  113. conn.close()
  114. return {"error": "Invalid password for admin"}, 401
  115. else:
  116. session_id = request.cookies.get("session_id")
  117. c.execute("SELECT session_id FROM users WHERE username = ?", (username,))
  118. user_row = c.fetchone()
  119. if user_row and user_row[0] != session_id:
  120. conn.close()
  121. return {"error": "Username reserved by another user"}, 403
  122. new_session_id = secrets.token_hex(16)
  123. c.execute(
  124. "INSERT OR REPLACE INTO users (username, session_id, last_active) VALUES (?, ?, datetime('now'))",
  125. (username, new_session_id),
  126. )
  127. conn.commit()
  128. conn.close()
  129. resp = make_response({"success": True})
  130. resp.set_cookie("session_id", new_session_id)
  131. send_new_user_message(username)
  132. return resp
  133. # POST: receives a message from one client and forwards it to all other connections
  134. @app.route("/api/messages", methods=["POST"])
  135. def post_message():
  136. data = request.json
  137. username = data.get("from")
  138. text = data.get("text", "")
  139. session_id = request.cookies.get("session_id")
  140. with db_lock:
  141. conn = sqlite3.connect("chat.db", check_same_thread=False)
  142. c = conn.cursor()
  143. c.execute("SELECT session_id FROM users WHERE username = ?", (username,))
  144. row = c.fetchone()
  145. if not row or row[0] != session_id:
  146. conn.close()
  147. return {"error": "Unauthorized"}, 401
  148. c.execute(
  149. "UPDATE users SET last_active = datetime('now') WHERE username = ?",
  150. (username,),
  151. )
  152. if text.startswith("/kick "):
  153. target = text.split(" ")[1].strip()
  154. c.execute("SELECT 1 FROM admins WHERE username = ?", (username,))
  155. if c.fetchone():
  156. c.execute("DELETE FROM users WHERE username = ?", (target,))
  157. conn.commit()
  158. disconnect_msg = f'{{"type": "system", "content": "{target} was disconnected by an admin."}}'
  159. for q in clients[:]:
  160. try:
  161. q.put(disconnect_msg)
  162. except:
  163. pass
  164. conn.close()
  165. return "", 204
  166. conn.commit()
  167. conn.close()
  168. payload = request.get_data(as_text=True)
  169. log(f"Message received by {request.remote_addr}: {payload}")
  170. # Save to sqlite
  171. with db_lock:
  172. conn = sqlite3.connect("chat.db", check_same_thread=False)
  173. c = conn.cursor()
  174. c.execute("INSERT INTO messages (payload) VALUES (?)", (payload,))
  175. # Prune if over 500
  176. c.execute("SELECT COUNT(*) FROM messages")
  177. count = c.fetchone()[0]
  178. if count > 500:
  179. excess = count - 500
  180. c.execute(
  181. "DELETE FROM messages WHERE id IN (SELECT id FROM messages ORDER BY id ASC LIMIT ?)",
  182. (excess,),
  183. )
  184. conn.commit()
  185. conn.close()
  186. for q in clients[:]:
  187. try:
  188. q.put(payload)
  189. log(
  190. f"Message from {request.remote_addr} forwarded to {q.qsize()} listener(s)"
  191. )
  192. except:
  193. clients.remove(q)
  194. return "", 204
  195. # GET: all clients listen here, with long-polling
  196. @app.route("/api/messages", methods=["GET"])
  197. def get_messages():
  198. is_stream = request.args.get("stream") == "1"
  199. if not is_stream:
  200. session_id = request.cookies.get("session_id")
  201. with db_lock:
  202. conn = sqlite3.connect('chat.db', check_same_thread=False)
  203. c = conn.cursor()
  204. c.execute("SELECT 1 FROM users WHERE session_id = ?", (session_id,))
  205. if not c.fetchone():
  206. conn.close()
  207. return {"error": "Unauthorized"}, 401
  208. conn.close()
  209. q = queue.Queue()
  210. clients.append(q)
  211. try:
  212. # wait up to 30 seconds for a message
  213. msg = q.get(timeout=30)
  214. return msg, 200
  215. except queue.Empty:
  216. return "", 204 # no message, client retries
  217. finally:
  218. clients.remove(q) # clean up client queue on disconnect
  219. @app.route("/api/download_images", methods=["GET"])
  220. def download_images():
  221. session_id = request.cookies.get("session_id")
  222. with db_lock:
  223. conn = sqlite3.connect("chat.db", check_same_thread=False)
  224. c = conn.cursor()
  225. c.execute("SELECT 1 FROM users WHERE session_id = ?", (session_id,))
  226. if not c.fetchone():
  227. conn.close()
  228. return {"error": "Unauthorized"}, 401
  229. c.execute("SELECT payload FROM messages ORDER BY id DESC")
  230. rows = c.fetchall()
  231. conn.close()
  232. memory_file = io.BytesIO()
  233. with zipfile.ZipFile(memory_file, "w", zipfile.ZIP_DEFLATED) as zf:
  234. index = 0
  235. for row in rows:
  236. try:
  237. data = json.loads(row[0])
  238. doodle = data.get("doodle", "")
  239. if doodle.startswith("data:image/png;base64,"):
  240. b64_data = doodle.split(",")[1]
  241. img_bytes = base64.b64decode(b64_data)
  242. zf.writestr(f"{index}.png", img_bytes)
  243. index += 1
  244. except Exception:
  245. pass
  246. memory_file.seek(0)
  247. return send_file(
  248. memory_file,
  249. mimetype="application/zip",
  250. as_attachment=True,
  251. download_name=f"{CHATNAME}_archive.zip",
  252. )
  253. @app.route("/api/backlog", methods=["GET"])
  254. def get_backlog():
  255. is_stream = request.args.get("stream") == "1"
  256. if not is_stream:
  257. session_id = request.cookies.get("session_id")
  258. with db_lock:
  259. conn = sqlite3.connect('chat.db', check_same_thread=False)
  260. c = conn.cursor()
  261. c.execute("SELECT 1 FROM users WHERE session_id = ?", (session_id,))
  262. if not c.fetchone():
  263. conn.close()
  264. return {"error": "Unauthorized"}, 401
  265. conn.close()
  266. log(f"Backlog requested by {request.remote_addr}")
  267. with db_lock:
  268. conn = sqlite3.connect("chat.db", check_same_thread=False)
  269. c = conn.cursor()
  270. c.execute("SELECT payload FROM messages ORDER BY id ASC")
  271. rows = c.fetchall()
  272. conn.close()
  273. # return a JSON array of JSON strings
  274. return "[" + ",".join(row[0] for row in rows) + "]", 200
  275. @app.route("/api/room/details", methods=["GET"])
  276. def get_room_details():
  277. log(f"Room details requested by {request.remote_addr}")
  278. return {
  279. "serverIP": custom_address,
  280. "port": PORT,
  281. }, 200
  282. def send_new_user_message(username):
  283. welcome_message = (
  284. f'{{"type": "system", "content": "Now entering room: {username}"}}'
  285. )
  286. for q in clients[:]:
  287. try:
  288. q.put(welcome_message)
  289. except:
  290. clients.remove(q)
  291. def log(msg):
  292. if verbose:
  293. timestamp = "[{:%Y-%m-%d %H:%M:%S}]".format(datetime.now())
  294. print(f"{timestamp}: {msg}")
  295. if __name__ == "__main__":
  296. parser = argparse.ArgumentParser(description="run web server")
  297. parser.add_argument(
  298. "--port",
  299. "-p",
  300. type=int,
  301. help="port to listen on (default: %(default)s)",
  302. default=PORT,
  303. )
  304. parser.add_argument(
  305. "--server",
  306. "-s",
  307. action="store_true",
  308. help="run server in headless mode without opening a browser",
  309. )
  310. parser.add_argument(
  311. "--threads",
  312. "-t",
  313. type=int,
  314. help="number of threads to use (default: %(default)s)",
  315. default=32,
  316. )
  317. parser.add_argument(
  318. "--address",
  319. "-a",
  320. type=str,
  321. help="address displayed to users in browser",
  322. default="0.0.0.0",
  323. )
  324. parser.add_argument(
  325. "--verbose", "-v", action="store_true", help="enable verbose logging"
  326. )
  327. parser.add_argument(
  328. "--admin",
  329. "-A",
  330. action="append",
  331. help="Declare admin username:password",
  332. default=[],
  333. )
  334. args = parser.parse_args()
  335. port = args.port or PORT
  336. open_browser = not args.server
  337. threads = args.threads
  338. custom_address = args.address
  339. verbose = args.verbose
  340. for admin in args.admin:
  341. if ":" in admin:
  342. u, p = admin.split(":", 1)
  343. with db_lock:
  344. conn = sqlite3.connect("chat.db", check_same_thread=False)
  345. c = conn.cursor()
  346. c.execute(
  347. "INSERT OR REPLACE INTO admins (username, password) VALUES (?, ?)",
  348. (u, p),
  349. )
  350. conn.commit()
  351. conn.close()
  352. print(f"\n{CHATNAME} Server running!")
  353. print(f" → Local: http://0.0.0.0:{port}")
  354. if not (1 <= port <= 65535):
  355. log(f"Error: port {port} is out of range (1-65535)")
  356. sys.exit(2)
  357. serve(app, host="0.0.0.0", port=port, threads=threads)