app.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494
  1. """
  2. Gradio UI for Qwen3-Coder-Next Uncensored Fine-Tuning
  3. Provides a web interface to configure, launch, and monitor training.
  4. """
  5. import os
  6. import json
  7. import time
  8. import threading
  9. from pathlib import Path
  10. import gradio as gr
  11. import yaml
  12. STATUS_FILE = "/tmp/training_status.json"
  13. LOG_FILE = "/tmp/training.log"
  14. # Track background training thread
  15. training_thread = None
  16. training_active = False
  17. def load_config():
  18. with open("config.yaml") as f:
  19. return yaml.safe_load(f)
  20. def read_status():
  21. try:
  22. return json.loads(Path(STATUS_FILE).read_text())
  23. except Exception:
  24. return {
  25. "status": "idle",
  26. "detail": "Ready to start training",
  27. "progress": 0.0,
  28. "metrics": {},
  29. }
  30. def read_logs(num_lines: int = 100):
  31. try:
  32. lines = Path(LOG_FILE).read_text().strip().split("\n")
  33. return "\n".join(lines[-num_lines:])
  34. except Exception:
  35. return "No logs yet."
  36. def get_gpu_info():
  37. try:
  38. import torch
  39. if torch.cuda.is_available():
  40. props = torch.cuda.get_device_properties(0)
  41. mem_total = props.total_mem / (1024**3)
  42. mem_used = torch.cuda.memory_allocated(0) / (1024**3)
  43. mem_reserved = torch.cuda.memory_reserved(0) / (1024**3)
  44. return (
  45. f"GPU: {props.name}\n"
  46. f"VRAM: {mem_total:.1f} GB total | {mem_used:.1f} GB used | {mem_reserved:.1f} GB reserved\n"
  47. f"CUDA: {torch.version.cuda}\n"
  48. f"Compute Capability: {props.major}.{props.minor}"
  49. )
  50. return "No GPU detected"
  51. except Exception as e:
  52. return f"GPU info unavailable: {e}"
  53. # ---------------------------------------------------------------------------
  54. # Training launch
  55. # ---------------------------------------------------------------------------
  56. def start_training(
  57. method,
  58. dataset_choice,
  59. custom_dataset,
  60. hub_model_id,
  61. max_samples,
  62. num_epochs,
  63. learning_rate,
  64. lora_r,
  65. lora_alpha,
  66. batch_size,
  67. grad_accum,
  68. max_seq_length,
  69. system_prompt,
  70. ):
  71. global training_thread, training_active
  72. if training_active:
  73. return "⚠️ Training is already in progress! Wait for it to finish or restart the Space."
  74. if not hub_model_id.strip():
  75. return (
  76. "❌ Hub Model ID is required (e.g., your-username/qwen3-coder-uncensored)"
  77. )
  78. if not os.environ.get("HF_TOKEN"):
  79. return (
  80. "❌ HF_TOKEN secret not set! Go to Space Settings → Secrets → Add HF_TOKEN"
  81. )
  82. # Handle custom dataset upload
  83. custom_path = None
  84. if custom_dataset is not None:
  85. custom_path = (
  86. custom_dataset.name
  87. if hasattr(custom_dataset, "name")
  88. else str(custom_dataset)
  89. )
  90. max_samples_int = int(max_samples) if max_samples and int(max_samples) > 0 else None
  91. def run_training():
  92. global training_active
  93. training_active = True
  94. try:
  95. from train import train as run_train, abliterate
  96. if method == "QLoRA Fine-Tuning":
  97. run_train(
  98. dataset_choice=dataset_choice,
  99. hub_model_id=hub_model_id.strip(),
  100. max_samples=max_samples_int,
  101. custom_dataset_path=custom_path,
  102. num_epochs=int(num_epochs),
  103. learning_rate=float(learning_rate),
  104. lora_r=int(lora_r),
  105. lora_alpha=int(lora_alpha),
  106. batch_size=int(batch_size),
  107. grad_accum=int(grad_accum),
  108. max_seq_length=int(max_seq_length),
  109. system_prompt=system_prompt,
  110. )
  111. elif method == "Abliteration (No Training)":
  112. abliterate(hub_model_id=hub_model_id.strip())
  113. except Exception as e:
  114. status_data = {
  115. "status": "error",
  116. "detail": str(e),
  117. "progress": 0.0,
  118. "metrics": {},
  119. }
  120. Path(STATUS_FILE).write_text(json.dumps(status_data))
  121. finally:
  122. training_active = False
  123. training_thread = threading.Thread(target=run_training, daemon=True)
  124. training_thread.start()
  125. return "🚀 Training launched! Monitor progress below."
  126. def start_merge(hub_model_id_merge):
  127. global training_active
  128. if training_active:
  129. return "⚠️ Another process is running."
  130. if not hub_model_id_merge.strip():
  131. return "❌ Hub Model ID for merged model is required"
  132. def run_merge():
  133. global training_active
  134. training_active = True
  135. try:
  136. from scripts.merge_and_push import merge_and_push
  137. merge_and_push(hub_model_id=hub_model_id_merge.strip())
  138. except Exception as e:
  139. Path(STATUS_FILE).write_text(
  140. json.dumps(
  141. {
  142. "status": "error",
  143. "detail": str(e),
  144. "progress": 0.0,
  145. "metrics": {},
  146. }
  147. )
  148. )
  149. finally:
  150. training_active = False
  151. threading.Thread(target=run_merge, daemon=True).start()
  152. return "🔀 Merge started! This will take a while for an 80B model."
  153. def poll_status():
  154. """Returns updated status for the UI."""
  155. s = read_status()
  156. status_emoji = {
  157. "idle": "⏸️",
  158. "initializing": "⚙️",
  159. "loading_model": "📥",
  160. "training": "🏋️",
  161. "saving": "💾",
  162. "saving_checkpoint": "💾",
  163. "pushing": "🚀",
  164. "merging": "🔀",
  165. "abliterating": "✂️",
  166. "completed": "✅",
  167. "error": "❌",
  168. }.get(s["status"], "❓")
  169. status_text = f"{status_emoji} **{s['status'].upper()}**: {s['detail']}"
  170. progress = s["progress"]
  171. metrics_text = ""
  172. if s.get("metrics"):
  173. m = s["metrics"]
  174. parts = []
  175. if "step" in m:
  176. parts.append(f"Step: {m['step']}/{m.get('total_steps', '?')}")
  177. if "epoch" in m:
  178. parts.append(f"Epoch: {m['epoch']}")
  179. if "loss" in m:
  180. parts.append(f"Loss: {m['loss']}")
  181. if "learning_rate" in m:
  182. parts.append(f"LR: {m['learning_rate']:.2e}")
  183. if "grad_norm" in m:
  184. parts.append(f"Grad Norm: {m['grad_norm']}")
  185. if "train_loss" in m:
  186. parts.append(f"Final Loss: {m['train_loss']}")
  187. if "train_runtime" in m:
  188. parts.append(f"Runtime: {m['train_runtime']}s")
  189. metrics_text = " | ".join(parts)
  190. logs = read_logs(50)
  191. gpu = get_gpu_info()
  192. return status_text, progress, metrics_text, logs, gpu
  193. # ---------------------------------------------------------------------------
  194. # Gradio UI
  195. # ---------------------------------------------------------------------------
  196. def build_ui():
  197. config = load_config()
  198. dataset_names = list(config["datasets"].keys())
  199. with gr.Blocks(
  200. title="Qwen3-Coder-Next Uncensored Fine-Tuner",
  201. theme=gr.themes.Soft(primary_hue="red", secondary_hue="orange"),
  202. css="""
  203. .main-title { text-align: center; margin-bottom: 0; }
  204. .subtitle { text-align: center; color: #888; margin-top: 0; }
  205. """,
  206. ) as app:
  207. gr.Markdown(
  208. "# 🔥 Qwen3-Coder-Next Uncensored Fine-Tuner", elem_classes="main-title"
  209. )
  210. gr.Markdown(
  211. "*QLoRA fine-tuning & abliteration for Qwen3-Coder-Next (80B MoE / 3B active)*",
  212. elem_classes="subtitle",
  213. )
  214. with gr.Tabs():
  215. # ==================================================================
  216. # TAB 1: Training Configuration
  217. # ==================================================================
  218. with gr.Tab("🎯 Train"):
  219. with gr.Row():
  220. with gr.Column(scale=1):
  221. gr.Markdown("### Method")
  222. method = gr.Radio(
  223. ["QLoRA Fine-Tuning", "Abliteration (No Training)"],
  224. value="QLoRA Fine-Tuning",
  225. label="Uncensoring Method",
  226. )
  227. gr.Markdown("### Hub Settings")
  228. hub_model_id = gr.Textbox(
  229. label="Hub Model ID",
  230. placeholder="your-username/qwen3-coder-uncensored",
  231. info="Where to push the trained model on Hugging Face",
  232. )
  233. gr.Markdown("### Dataset")
  234. dataset_choice = gr.Dropdown(
  235. choices=dataset_names,
  236. value=dataset_names[0],
  237. label="Pre-built Dataset",
  238. info="Choose an uncensored dataset",
  239. )
  240. custom_dataset = gr.File(
  241. label="Or Upload Custom Dataset (JSON/CSV/Parquet)",
  242. file_types=[".json", ".jsonl", ".csv", ".parquet"],
  243. )
  244. max_samples = gr.Number(
  245. label="Max Samples (0 = use all)",
  246. value=0,
  247. info="Limit dataset size for faster experiments",
  248. )
  249. gr.Markdown("### System Prompt")
  250. system_prompt = gr.Textbox(
  251. label="System Prompt",
  252. value=config.get("system_prompt", ""),
  253. lines=3,
  254. info="Embedded in every training sample",
  255. )
  256. with gr.Column(scale=1):
  257. gr.Markdown("### Training Hyperparameters")
  258. num_epochs = gr.Slider(1, 10, value=2, step=1, label="Epochs")
  259. learning_rate = gr.Number(label="Learning Rate", value=2e-4)
  260. batch_size = gr.Slider(
  261. 1, 8, value=1, step=1, label="Batch Size per Device"
  262. )
  263. grad_accum = gr.Slider(
  264. 1, 64, value=16, step=1, label="Gradient Accumulation Steps"
  265. )
  266. max_seq_length = gr.Slider(
  267. 512, 8192, value=2048, step=256, label="Max Sequence Length"
  268. )
  269. gr.Markdown("### LoRA Configuration")
  270. lora_r = gr.Slider(
  271. 8, 256, value=64, step=8, label="LoRA Rank (r)"
  272. )
  273. lora_alpha = gr.Slider(
  274. 16, 512, value=128, step=16, label="LoRA Alpha"
  275. )
  276. with gr.Row():
  277. train_btn = gr.Button(
  278. "🚀 Start Training", variant="primary", size="lg"
  279. )
  280. output_msg = gr.Textbox(label="Status", interactive=False)
  281. train_btn.click(
  282. fn=start_training,
  283. inputs=[
  284. method,
  285. dataset_choice,
  286. custom_dataset,
  287. hub_model_id,
  288. max_samples,
  289. num_epochs,
  290. learning_rate,
  291. lora_r,
  292. lora_alpha,
  293. batch_size,
  294. grad_accum,
  295. max_seq_length,
  296. system_prompt,
  297. ],
  298. outputs=output_msg,
  299. )
  300. # ==================================================================
  301. # TAB 2: Monitoring
  302. # ==================================================================
  303. with gr.Tab("📊 Monitor"):
  304. with gr.Row():
  305. status_text = gr.Markdown("⏸️ **IDLE**: Ready to start training")
  306. refresh_btn = gr.Button("🔄 Refresh", size="sm")
  307. progress_bar = gr.Slider(
  308. 0, 1, value=0, label="Progress", interactive=False
  309. )
  310. metrics_display = gr.Textbox(
  311. label="Training Metrics", interactive=False, lines=2
  312. )
  313. gpu_info = gr.Textbox(label="GPU Info", interactive=False, lines=4)
  314. log_display = gr.Textbox(
  315. label="Training Logs (last 50 lines)", interactive=False, lines=20
  316. )
  317. refresh_btn.click(
  318. fn=poll_status,
  319. outputs=[
  320. status_text,
  321. progress_bar,
  322. metrics_display,
  323. log_display,
  324. gpu_info,
  325. ],
  326. )
  327. # Auto-refresh every 10 seconds
  328. app.load(
  329. fn=poll_status,
  330. outputs=[
  331. status_text,
  332. progress_bar,
  333. metrics_display,
  334. log_display,
  335. gpu_info,
  336. ],
  337. every=10,
  338. )
  339. # ==================================================================
  340. # TAB 3: Merge & Push
  341. # ==================================================================
  342. with gr.Tab("🔀 Merge LoRA"):
  343. gr.Markdown("""
  344. ### Merge LoRA Adapter into Base Model
  345. After training completes, use this to merge the LoRA adapter into the base model
  346. and push a standalone model to the Hub.
  347. **⚠️ Warning**: This loads the full 80B model in bfloat16. Requires ~160GB RAM.
  348. For the A100 80GB, this may require CPU offloading.
  349. """)
  350. hub_model_id_merge = gr.Textbox(
  351. label="Hub Model ID for Merged Model",
  352. placeholder="your-username/qwen3-coder-uncensored-merged",
  353. info="This should be DIFFERENT from the adapter model ID",
  354. )
  355. merge_btn = gr.Button("🔀 Merge & Push", variant="primary")
  356. merge_output = gr.Textbox(label="Status", interactive=False)
  357. merge_btn.click(
  358. fn=start_merge,
  359. inputs=[hub_model_id_merge],
  360. outputs=merge_output,
  361. )
  362. # ==================================================================
  363. # TAB 4: Help
  364. # ==================================================================
  365. with gr.Tab("❓ Help"):
  366. gr.Markdown("""
  367. ## Quick Start Guide
  368. ### 1. Set Your HF Token
  369. Go to **Space Settings → Secrets** and add:
  370. - `HF_TOKEN`: Your Hugging Face token with **write** access
  371. - `WANDB_API_KEY` (optional): For training metrics dashboard
  372. ### 2. Choose Your Method
  373. | Method | Pros | Cons |
  374. |--------|------|------|
  375. | **QLoRA Fine-Tuning** | More control, better results, uses dataset | Takes hours, needs GPU |
  376. | **Abliteration** | Fast (~30 min), no dataset needed | Less precise, may affect model quality |
  377. ### 3. QLoRA Fine-Tuning
  378. 1. Select a dataset (or upload your own)
  379. 2. Set your Hub Model ID (e.g., `your-username/qwen3-uncensored`)
  380. 3. Adjust hyperparameters if needed
  381. 4. Click **Start Training**
  382. 5. Monitor progress in the **Monitor** tab
  383. 6. When done, optionally merge LoRA in the **Merge** tab
  384. ### 4. Custom Dataset Format
  385. Your JSON/JSONL should have these fields:
  386. ```json
  387. {
  388. "instruction": "User's question/request",
  389. "output": "Uncensored response"
  390. }
  391. ```
  392. Also supported fields: `prompt`/`response`, `input`/`answer`, `user`/`assistant`
  393. ### 5. Recommended Settings
  394. - **First run**: Use `max_samples=1000` to test the pipeline quickly
  395. - **Full training**: Set `max_samples=0` (all data), `epochs=2-3`
  396. - **Better quality**: Increase `lora_r` to 128, `lora_alpha` to 256
  397. - **Faster training**: Decrease `max_seq_length` to 1024
  398. ### 6. After Training
  399. The LoRA adapter is automatically pushed to your Hub repo.
  400. You can:
  401. - **Use the adapter directly** with PEFT (lightweight)
  402. - **Merge into base model** using the Merge tab (standalone model)
  403. ### Using the Adapter
  404. ```python
  405. from transformers import AutoModelForCausalLM, AutoTokenizer
  406. from peft import PeftModel
  407. base = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-Coder-Next", device_map="auto")
  408. model = PeftModel.from_pretrained(base, "your-username/qwen3-uncensored")
  409. tokenizer = AutoTokenizer.from_pretrained("your-username/qwen3-uncensored")
  410. ```
  411. """)
  412. return app
  413. if __name__ == "__main__":
  414. app = build_ui()
  415. app.launch(
  416. server_name="0.0.0.0",
  417. server_port=7860,
  418. share=False,
  419. )