# main.py
import argparse
import json
import queue
import socket
import sqlite3
import sys
import threading
from datetime import datetime

from flask import Flask, request, send_from_directory
from waitress import serve
  10. # default port if none provided via CLI
  11. PORT = 2004
  12. custom_address = None
  13. verbose = False
  14. def get_local_ip():
  15. """Get the local IP address of the machine."""
  16. try:
  17. s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
  18. s.connect(("8.8.8.8", 80)) # doesn't actually send data
  19. local_ip = s.getsockname()[0]
  20. s.close()
  21. return local_ip
  22. except Exception:
  23. return "localhost"
  24. app = Flask(__name__, static_folder="public", static_url_path="")
  25. db_lock = threading.Lock()
  26. def init_db():
  27. with db_lock:
  28. conn = sqlite3.connect('chat.db', check_same_thread=False)
  29. c = conn.cursor()
  30. c.execute('''
  31. CREATE TABLE IF NOT EXISTS messages (
  32. id INTEGER PRIMARY KEY AUTOINCREMENT,
  33. payload TEXT
  34. )
  35. ''')
  36. conn.commit()
  37. conn.close()
  38. init_db()
  39. # each client connection has its own queue
  40. clients = []
  41. @app.route("/")
  42. @app.route("/<path:filename>")
  43. def static_files(filename="index.html"):
  44. if "u" in request.args:
  45. username = request.args.get("u")
  46. log(f"{request.remote_addr} connected with username: {username}")
  47. send_new_user_message(username)
  48. return send_from_directory("public", filename)
  49. else:
  50. log(
  51. f"{request.remote_addr} connected with no username; redirecting to login page"
  52. )
  53. return send_from_directory(
  54. "public", filename if filename != "index.html" else "login.html"
  55. )
  56. # POST: receives a message from one client and forwards it to all other connections
  57. @app.route("/api/messages", methods=["POST"])
  58. def post_message():
  59. payload = request.get_data(as_text=True)
  60. log(f"Message received by {request.remote_addr}: {payload}")
  61. # Save to sqlite
  62. with db_lock:
  63. conn = sqlite3.connect('chat.db', check_same_thread=False)
  64. c = conn.cursor()
  65. c.execute('INSERT INTO messages (payload) VALUES (?)', (payload,))
  66. # Prune if over 500
  67. c.execute('SELECT COUNT(*) FROM messages')
  68. count = c.fetchone()[0]
  69. if count > 500:
  70. excess = count - 500
  71. c.execute('DELETE FROM messages WHERE id IN (SELECT id FROM messages ORDER BY id ASC LIMIT ?)', (excess,))
  72. conn.commit()
  73. conn.close()
  74. for q in clients[:]:
  75. try:
  76. q.put(payload)
  77. log(
  78. f"Message from {request.remote_addr} forwarded to {q.qsize()} listener(s)"
  79. )
  80. except:
  81. clients.remove(q)
  82. return "", 204
  83. # GET: all clients listen here, with long-polling
  84. @app.route("/api/messages", methods=["GET"])
  85. def get_messages():
  86. q = queue.Queue()
  87. clients.append(q)
  88. try:
  89. # wait up to 30 seconds for a message
  90. msg = q.get(timeout=30)
  91. return msg, 200
  92. except queue.Empty:
  93. return "", 204 # no message, client retries
  94. finally:
  95. clients.remove(q) # clean up client queue on disconnect
  96. @app.route("/api/backlog", methods=["GET"])
  97. def get_backlog():
  98. log(f"Backlog requested by {request.remote_addr}")
  99. with db_lock:
  100. conn = sqlite3.connect('chat.db', check_same_thread=False)
  101. c = conn.cursor()
  102. c.execute('SELECT payload FROM messages ORDER BY id ASC')
  103. rows = c.fetchall()
  104. conn.close()
  105. # return a JSON array of JSON strings
  106. return "[" + ",".join(row[0] for row in rows) + "]", 200
  107. @app.route("/api/room/details", methods=["GET"])
  108. def get_room_details():
  109. log(f"Room details requested by {request.remote_addr}")
  110. return {
  111. "serverIP": custom_address or get_local_ip(),
  112. "port": PORT,
  113. }, 200
  114. def send_new_user_message(username):
  115. welcome_message = (
  116. f'{{"type": "system", "content": "Now entering room: {username}"}}'
  117. )
  118. for q in clients[:]:
  119. try:
  120. q.put(welcome_message)
  121. except:
  122. clients.remove(q)
  123. def log(msg):
  124. if verbose:
  125. timestamp = "[{:%Y-%m-%d %H:%M:%S}]".format(datetime.now())
  126. print(f"{timestamp}: {msg}")
  127. if __name__ == "__main__":
  128. import socket
  129. parser = argparse.ArgumentParser(description="run pctochat web server")
  130. parser.add_argument(
  131. "--port",
  132. "-p",
  133. type=int,
  134. help="port to listen on (default: %(default)s)",
  135. default=PORT,
  136. )
  137. parser.add_argument(
  138. "--server",
  139. "-s",
  140. action="store_true",
  141. help="run server in headless mode without opening a browser",
  142. )
  143. parser.add_argument(
  144. "--threads",
  145. "-t",
  146. type=int,
  147. help="number of threads to use (default: %(default)s)",
  148. default=16,
  149. )
  150. parser.add_argument(
  151. "--address",
  152. "-a",
  153. type=str,
  154. help="address displayed to users in browser",
  155. default="0.0.0.0",
  156. )
  157. parser.add_argument(
  158. "--verbose", "-v", action="store_true", help="enable verbose logging"
  159. )
  160. args = parser.parse_args()
  161. port = args.port or PORT
  162. open_browser = not args.server
  163. threads = args.threads
  164. custom_address = args.address
  165. verbose = args.verbose
  166. # local_ip = get_local_ip()
  167. print(f"\nServer running!")
  168. print(f" → Local: http://127.0.0.1:{port}")
  169. # print(f" → Network: http://{local_ip}:{port}\n")
  170. if not (1 <= port <= 65535):
  171. log(f"Error: port {port} is out of range (1-65535)")
  172. sys.exit(2)
  173. serve(app, host="0.0.0.0", port=port, threads=threads)