lainlives 3 днів тому
батько
коміт
d8e7670993
6 змінених файлів з 329 додано та 62 видалено
  1. 6 3
      README.md
  2. 237 45
      main.py
  3. 2 1
      public/index.html
  4. 9 0
      public/js/canvas.js
  5. 15 7
      public/js/service.js
  6. 60 6
      public/login.html

+ 6 - 3
README.md

@@ -9,12 +9,15 @@ Yet Another PictoChat Clone
 - Landscape and Portrait view
 - Transparent background/floating OBS source mode
 - All the colors
-- Crops white pixels off top/bottom maximizing chat backlog
+- Crops unpainted pixels off top/bottom maximizing chat backlog
 - No usercount limits
+- Basic username reservation
+- Admin accounts, specified at server launch, they can remove users
+- Can download the entire current database as a zip
+
 
 ## Missing Features I would like to implement
-- Anonymous auth, currently all users can sign in with the same username...
-- No ability to kick users 
+
 - Canvas size customizations at server creation.
 
 

+ 237 - 45
main.py

@@ -1,12 +1,24 @@
 import argparse
+import base64
+import io
+import json
 import queue
+import secrets
 import socket
-import sys
 import sqlite3
+import sys
 import threading
+import zipfile
 from datetime import datetime
 
-from flask import Flask, request, send_from_directory
+from flask import (
+    Flask,
+    make_response,
+    redirect,
+    request,
+    send_file,
+    send_from_directory,
+)
 from waitress import serve
 
 # default port if none provided via CLI
@@ -14,37 +26,40 @@ PORT = 2004
 
 custom_address = None
 verbose = False
-
-
-def get_local_ip():
-    """Get the local IP address of the machine."""
-    try:
-        s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
-        s.connect(("8.8.8.8", 80))  # doesn't actually send data
-        local_ip = s.getsockname()[0]
-        s.close()
-        return local_ip
-    except Exception:
-        return "localhost"
-
+CHATNAME = "lainchat"
 
 app = Flask(__name__, static_folder="public", static_url_path="")
 
 db_lock = threading.Lock()
 
+
 def init_db():
     with db_lock:
-        conn = sqlite3.connect('chat.db', check_same_thread=False)
+        conn = sqlite3.connect("chat.db", check_same_thread=False)
         c = conn.cursor()
-        c.execute('''
+        c.execute("""
             CREATE TABLE IF NOT EXISTS messages (
                 id INTEGER PRIMARY KEY AUTOINCREMENT,
                 payload TEXT
             )
-        ''')
+        """)
+        c.execute("""
+            CREATE TABLE IF NOT EXISTS users (
+                username TEXT PRIMARY KEY,
+                session_id TEXT,
+                last_active TIMESTAMP
+            )
+        """)
+        c.execute("""
+            CREATE TABLE IF NOT EXISTS admins (
+                username TEXT PRIMARY KEY,
+                password TEXT
+            )
+        """)
         conn.commit()
         conn.close()
 
+
 init_db()
 
 # each client connection has its own queue
@@ -54,39 +69,151 @@ clients = []
 @app.route("/")
 @app.route("/<path:filename>")
 def static_files(filename="index.html"):
-    if "u" in request.args:
+    if "u" in request.args and filename == "index.html":
         username = request.args.get("u")
-        log(f"{request.remote_addr} connected with username: {username}")
-        send_new_user_message(username)
+        session_id = request.cookies.get("session_id")
+
+        with db_lock:
+            conn = sqlite3.connect("chat.db", check_same_thread=False)
+            c = conn.cursor()
+            c.execute("SELECT session_id FROM users WHERE username = ?", (username,))
+            row = c.fetchone()
+
+            if row and row[0] == session_id:
+                c.execute(
+                    "UPDATE users SET last_active = datetime('now') WHERE username = ?",
+                    (username,),
+                )
+                conn.commit()
+                conn.close()
+                log(f"{request.remote_addr} connected with username: {username}")
+                # We can't easily know if they just joined or reloaded, but let's send a welcome
+                # Actually, let's just let them load the page
+                return send_from_directory("public", filename)
+            else:
+                conn.close()
+                return redirect("/login.html")
+    elif filename != "index.html":
         return send_from_directory("public", filename)
     else:
-        log(
-            f"{request.remote_addr} connected with no username; redirecting to login page"
-        )
-        return send_from_directory(
-            "public", filename if filename != "index.html" else "login.html"
+        return redirect("/login.html")
+
+
+@app.route("/api/check_username", methods=["GET"])
+def check_username():
+    username = request.args.get("u", "").strip()
+    with db_lock:
+        conn = sqlite3.connect("chat.db", check_same_thread=False)
+        c = conn.cursor()
+        c.execute("SELECT 1 FROM admins WHERE username = ?", (username,))
+        is_admin = c.fetchone() is not None
+        conn.close()
+        return {"status": "admin" if is_admin else "user"}
+
+
+@app.route("/api/login", methods=["POST"])
+def api_login():
+    data = request.json
+    username = data.get("u", "").strip()
+    password = data.get("password", "")
+
+    if not username:
+        return {"error": "Invalid username"}, 400
+
+    with db_lock:
+        conn = sqlite3.connect("chat.db", check_same_thread=False)
+        c = conn.cursor()
+        c.execute("DELETE FROM users WHERE last_active < datetime('now', '-1 hour')")
+
+        c.execute("SELECT password FROM admins WHERE username = ?", (username,))
+        admin_row = c.fetchone()
+
+        if admin_row:
+            if admin_row[0] != password:
+                conn.close()
+                return {"error": "Invalid password for admin"}, 401
+        else:
+            session_id = request.cookies.get("session_id")
+            c.execute("SELECT session_id FROM users WHERE username = ?", (username,))
+            user_row = c.fetchone()
+            if user_row and user_row[0] != session_id:
+                conn.close()
+                return {"error": "Username reserved by another user"}, 403
+
+        new_session_id = secrets.token_hex(16)
+        c.execute(
+            "INSERT OR REPLACE INTO users (username, session_id, last_active) VALUES (?, ?, datetime('now'))",
+            (username, new_session_id),
         )
+        conn.commit()
+        conn.close()
+
+        resp = make_response({"success": True})
+        resp.set_cookie("session_id", new_session_id)
+
+        send_new_user_message(username)
+        return resp
 
 
 # POST: receives a message from one client and forwards it to all other connections
 @app.route("/api/messages", methods=["POST"])
 def post_message():
+    data = request.json
+    username = data.get("from")
+    text = data.get("text", "")
+    session_id = request.cookies.get("session_id")
+
+    with db_lock:
+        conn = sqlite3.connect("chat.db", check_same_thread=False)
+        c = conn.cursor()
+        c.execute("SELECT session_id FROM users WHERE username = ?", (username,))
+        row = c.fetchone()
+        if not row or row[0] != session_id:
+            conn.close()
+            return {"error": "Unauthorized"}, 401
+
+        c.execute(
+            "UPDATE users SET last_active = datetime('now') WHERE username = ?",
+            (username,),
+        )
+
+        if text.startswith("/kick "):
+            target = text.split(" ")[1].strip()
+            c.execute("SELECT 1 FROM admins WHERE username = ?", (username,))
+            if c.fetchone():
+                c.execute("DELETE FROM users WHERE username = ?", (target,))
+                conn.commit()
+                disconnect_msg = f'{{"type": "system", "content": "{target} was disconnected by an admin."}}'
+                for q in clients[:]:
+                    try:
+                        q.put(disconnect_msg)
+                    except:
+                        pass
+                conn.close()
+                return "", 204
+
+        conn.commit()
+        conn.close()
+
     payload = request.get_data(as_text=True)
     log(f"Message received by {request.remote_addr}: {payload}")
-    
+
     # Save to sqlite
     with db_lock:
-        conn = sqlite3.connect('chat.db', check_same_thread=False)
+        conn = sqlite3.connect("chat.db", check_same_thread=False)
         c = conn.cursor()
-        c.execute('INSERT INTO messages (payload) VALUES (?)', (payload,))
-        
+        c.execute("INSERT INTO messages (payload) VALUES (?)", (payload,))
+
         # Prune if over 500
-        c.execute('SELECT COUNT(*) FROM messages')
+        c.execute("SELECT COUNT(*) FROM messages")
         count = c.fetchone()[0]
         if count > 500:
             excess = count - 500
-            c.execute('DELETE FROM messages WHERE id IN (SELECT id FROM messages ORDER BY id ASC LIMIT ?)', (excess,))
-        
+            c.execute(
+                "DELETE FROM messages WHERE id IN (SELECT id FROM messages ORDER BY id ASC LIMIT ?)",
+                (excess,),
+            )
+
         conn.commit()
         conn.close()
 
@@ -104,6 +231,16 @@ def post_message():
 # GET: all clients listen here, with long-polling
 @app.route("/api/messages", methods=["GET"])
 def get_messages():
+    session_id = request.cookies.get("session_id")
+    with db_lock:
+        conn = sqlite3.connect("chat.db", check_same_thread=False)
+        c = conn.cursor()
+        c.execute("SELECT 1 FROM users WHERE session_id = ?", (session_id,))
+        if not c.fetchone():
+            conn.close()
+            return {"error": "Unauthorized"}, 401
+        conn.close()
+
     q = queue.Queue()
     clients.append(q)
     try:
@@ -115,13 +252,53 @@ def get_messages():
     finally:
         clients.remove(q)  # clean up client queue on disconnect
 
+
+@app.route("/api/download_images", methods=["GET"])
+def download_images():
+    session_id = request.cookies.get("session_id")
+    with db_lock:
+        conn = sqlite3.connect("chat.db", check_same_thread=False)
+        c = conn.cursor()
+        c.execute("SELECT 1 FROM users WHERE session_id = ?", (session_id,))
+        if not c.fetchone():
+            conn.close()
+            return {"error": "Unauthorized"}, 401
+
+        c.execute("SELECT payload FROM messages ORDER BY id DESC")
+        rows = c.fetchall()
+        conn.close()
+
+    memory_file = io.BytesIO()
+    with zipfile.ZipFile(memory_file, "w", zipfile.ZIP_DEFLATED) as zf:
+        index = 0
+        for row in rows:
+            try:
+                data = json.loads(row[0])
+                doodle = data.get("doodle", "")
+                if doodle.startswith("data:image/png;base64,"):
+                    b64_data = doodle.split(",")[1]
+                    img_bytes = base64.b64decode(b64_data)
+                    zf.writestr(f"{index}.png", img_bytes)
+                    index += 1
+            except Exception:
+                pass
+
+    memory_file.seek(0)
+    return send_file(
+        memory_file,
+        mimetype="application/zip",
+        as_attachment=True,
+        download_name=f"{CHATNAME}_archive.zip",
+    )
+
+
 @app.route("/api/backlog", methods=["GET"])
 def get_backlog():
     log(f"Backlog requested by {request.remote_addr}")
     with db_lock:
-        conn = sqlite3.connect('chat.db', check_same_thread=False)
+        conn = sqlite3.connect("chat.db", check_same_thread=False)
         c = conn.cursor()
-        c.execute('SELECT payload FROM messages ORDER BY id ASC')
+        c.execute("SELECT payload FROM messages ORDER BY id ASC")
         rows = c.fetchall()
         conn.close()
     # return a JSON array of JSON strings
@@ -132,7 +309,7 @@ def get_backlog():
 def get_room_details():
     log(f"Room details requested by {request.remote_addr}")
     return {
-        "serverIP": custom_address or get_local_ip(),
+        "serverIP": custom_address,
         "port": PORT,
     }, 200
 
@@ -155,9 +332,7 @@ def log(msg):
 
 
 if __name__ == "__main__":
-    import socket
-
-    parser = argparse.ArgumentParser(description="run pctochat web server")
+    parser = argparse.ArgumentParser(description="run web server")
     parser.add_argument(
         "--port",
         "-p",
@@ -176,7 +351,7 @@ if __name__ == "__main__":
         "-t",
         type=int,
         help="number of threads to use (default: %(default)s)",
-        default=16,
+        default=32,
     )
     parser.add_argument(
         "--address",
@@ -188,6 +363,13 @@ if __name__ == "__main__":
     parser.add_argument(
         "--verbose", "-v", action="store_true", help="enable verbose logging"
     )
+    parser.add_argument(
+        "--admin",
+        "-A",
+        action="append",
+        help="Declare admin username:password",
+        default=[],
+    )
     args = parser.parse_args()
 
     port = args.port or PORT
@@ -196,11 +378,21 @@ if __name__ == "__main__":
     custom_address = args.address
     verbose = args.verbose
 
-    # local_ip = get_local_ip()
-
-    print(f"\nServer running!")
-    print(f" → Local:   http://127.0.0.1:{port}")
-    # print(f" → Network: http://{local_ip}:{port}\n")
+    for admin in args.admin:
+        if ":" in admin:
+            u, p = admin.split(":", 1)
+            with db_lock:
+                conn = sqlite3.connect("chat.db", check_same_thread=False)
+                c = conn.cursor()
+                c.execute(
+                    "INSERT OR REPLACE INTO admins (username, password) VALUES (?, ?)",
+                    (u, p),
+                )
+                conn.commit()
+                conn.close()
+
+    print(f"\n{CHATNAME} Server running!")
+    print(f" → Local:   http://0.0.0.0:{port}")
 
     if not (1 <= port <= 65535):
         log(f"Error: port {port} is out of range (1-65535)")

+ 2 - 1
public/index.html

@@ -3,7 +3,7 @@
 <head>
     <meta charset="UTF-8">
     <meta name="viewport" content="width=device-width, initial-scale=1.0">
-    <title>Modern PictoChat</title>
+    <title> </title>
     <link rel="stylesheet" href="styles.css">
     <link href="https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600&display=swap" rel="stylesheet">
 </head>
@@ -52,6 +52,7 @@
                 </div>
                 <div id="form-button-row">
                     <button id="theme-btn" title="Toggle Light/Dark Theme">☀</button>
+                    <button id="download-btn" title="Download Image Archive">💾</button>
                     <button id="expand-btn">expand ⤢</button>
                     <button id="clear">clear</button>
                     <button id="copy">copy</button>

+ 9 - 0
public/js/canvas.js

@@ -231,3 +231,12 @@ if (themeBtn) {
     }
   });
 }
+
+// Download Archive Button
+const downloadBtn = document.getElementById("download-btn");
+if (downloadBtn) {
+  downloadBtn.addEventListener("click", (e) => {
+    e.preventDefault();
+    window.location.href = "/api/download_images";
+  });
+}

+ 15 - 7
public/js/service.js

@@ -70,13 +70,17 @@ function send() {
     color: userColor,
     timestamp: new Date().toISOString(),
   };
-  fetch("/api/messages", {
-    method: "POST",
-    headers: {
-      "Content-Type": "application/json",
-    },
-    body: JSON.stringify(payload),
-  });
+    fetch("/api/messages", {
+      method: "POST",
+      headers: {
+        "Content-Type": "application/json",
+      },
+      body: JSON.stringify(payload),
+    }).then(response => {
+      if (response.status === 401) {
+        window.location.href = "/";
+      }
+    });
   // clear canvas after sending
   canvas.width = canvas.width;
   window.clearCurrentText();
@@ -183,6 +187,10 @@ function handleMessage(msg) {
 async function listen() {
   try {
     const response = await fetch("/api/messages");
+    if (response.status === 401) {
+      window.location.href = "/";
+      return;
+    }
     if (response.status === 200) {
       const msg = await response.text();
       handleMessage(msg);

+ 60 - 6
public/login.html

@@ -1,23 +1,77 @@
 <!DOCTYPE html>
 <html lang="en">
-
 <head>
     <meta charset="UTF-8">
     <meta name="viewport" content="width=device-width, initial-scale=1.0">
-    <title>pctochat</title>
+    <title>Modern PictoChat</title>
     <link rel="stylesheet" href="styles.css">
+    <link href="https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600&display=swap" rel="stylesheet">
 </head>
-
 <body>
     <div class="login container">
         <h1>join room</h1>
         <span id="join-info">other people can join</span>
-        <form action="/" method="get">
-            <input type="text" id="username" name="u" placeholder="your name" required>
+        <form id="login-form">
+            <input type="text" id="username" name="u" placeholder="your name" required autocomplete="off">
+            <input type="password" id="password" name="p" placeholder="password" style="display: none;">
             <button type="submit">join</button>
         </form>
+        <span id="login-error" style="color: #ff5555; margin-top: 10px; font-size: 14px; text-align: center;"></span>
     </div>
+    <script>
+        const form = document.getElementById("login-form");
+        const userInp = document.getElementById("username");
+        const passInp = document.getElementById("password");
+        const errSpan = document.getElementById("login-error");
+        
+        let checkedUsername = "";
+        let isAdmin = false;
+        
+        form.addEventListener("submit", async (e) => {
+            e.preventDefault();
+            errSpan.innerText = "";
+            const u = userInp.value.trim();
+            if (!u) return;
+            
+            if (u !== checkedUsername) {
+                const res = await fetch("/api/check_username?u=" + encodeURIComponent(u));
+                const data = await res.json();
+                checkedUsername = u;
+                isAdmin = data.status === "admin";
+            }
+            
+            if (isAdmin && passInp.style.display === "none") {
+                passInp.style.display = "block";
+                passInp.required = true;
+                passInp.focus();
+                return;
+            }
+            
+            const payload = { u: u };
+            if (isAdmin) payload.password = passInp.value;
+            
+            const loginRes = await fetch("/api/login", {
+                method: "POST",
+                headers: { "Content-Type": "application/json" },
+                body: JSON.stringify(payload)
+            });
+            const loginData = await loginRes.json();
+            if (loginRes.ok) {
+                window.location.href = "/?u=" + encodeURIComponent(u);
+            } else {
+                errSpan.innerText = loginData.error;
+            }
+        });
+        
+        userInp.addEventListener("input", () => {
+            if (userInp.value.trim() !== checkedUsername) {
+                passInp.style.display = "none";
+                passInp.required = false;
+                passInp.value = "";
+                errSpan.innerText = "";
+            }
+        });
+    </script>
     <script src="js/room_info.js"></script>
 </body>
-
 </html>