app.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487
  1. """
  2. Gradio UI for Qwen3-Coder-Next Uncensored Fine-Tuning
  3. Provides a web interface to configure, launch, and monitor training.
  4. """
  5. import os
  6. import json
  7. import time
  8. import threading
  9. from pathlib import Path
  10. import gradio as gr
  11. import yaml
# JSON status snapshot. read_status() parses it, and the background workers in
# this module overwrite it with an "error" record when a job raises.
# NOTE(review): presumably the training/merge code also writes progress here — confirm.
STATUS_FILE = "/home/user/training_status.json"
# Plain-text log that read_logs() tails. Nothing in the visible code writes it.
LOG_FILE = "/home/user/training.log"
# Track background training thread
training_thread = None
# True while a background training or merge job is running; start_training()
# and start_merge() check it to refuse a second concurrent job.
training_active = False
  17. def load_config():
  18. with open("config.yaml") as f:
  19. return yaml.safe_load(f)
  20. def read_status():
  21. try:
  22. return json.loads(Path(STATUS_FILE).read_text())
  23. except Exception:
  24. return {
  25. "status": "idle",
  26. "detail": "Ready to start training",
  27. "progress": 0.0,
  28. "metrics": {},
  29. }
  30. def read_logs(num_lines: int = 100):
  31. try:
  32. lines = Path(LOG_FILE).read_text().strip().split("\n")
  33. return "\n".join(lines[-num_lines:])
  34. except Exception:
  35. return "No logs yet."
  36. def get_gpu_info():
  37. try:
  38. import torch
  39. if torch.cuda.is_available():
  40. props = torch.cuda.get_device_properties(0)
  41. mem_total = props.total_mem / (1024**3)
  42. mem_used = torch.cuda.memory_allocated(0) / (1024**3)
  43. mem_reserved = torch.cuda.memory_reserved(0) / (1024**3)
  44. return (
  45. f"GPU: {props.name}\n"
  46. f"VRAM: {mem_total:.1f} GB total | {mem_used:.1f} GB used | {mem_reserved:.1f} GB reserved\n"
  47. f"CUDA: {torch.version.cuda}\n"
  48. f"Compute Capability: {props.major}.{props.minor}"
  49. )
  50. return "No GPU detected"
  51. except Exception as e:
  52. return f"GPU info unavailable: {e}"
  53. # ---------------------------------------------------------------------------
  54. # Training launch
  55. # ---------------------------------------------------------------------------
  56. def start_training(
  57. method,
  58. dataset_choice,
  59. custom_dataset,
  60. hub_model_id,
  61. max_samples,
  62. num_epochs,
  63. learning_rate,
  64. lora_r,
  65. lora_alpha,
  66. batch_size,
  67. grad_accum,
  68. max_seq_length,
  69. system_prompt,
  70. ):
  71. global training_thread, training_active
  72. if training_active:
  73. return "⚠️ Training is already in progress! Wait for it to finish or restart the Space."
  74. if not hub_model_id.strip():
  75. return (
  76. "❌ Hub Model ID is required (e.g., your-username/qwen3-coder-uncensored)"
  77. )
  78. if not os.environ.get("HF_TOKEN"):
  79. return (
  80. "❌ HF_TOKEN secret not set! Go to Space Settings → Secrets → Add HF_TOKEN"
  81. )
  82. # Handle custom dataset upload
  83. custom_path = None
  84. if custom_dataset is not None:
  85. custom_path = (
  86. custom_dataset.name
  87. if hasattr(custom_dataset, "name")
  88. else str(custom_dataset)
  89. )
  90. max_samples_int = int(max_samples) if max_samples and int(max_samples) > 0 else None
  91. def run_training():
  92. global training_active
  93. training_active = True
  94. try:
  95. from train import train as run_train, abliterate
  96. if method == "QLoRA Fine-Tuning":
  97. run_train(
  98. dataset_choice=dataset_choice,
  99. hub_model_id=hub_model_id.strip(),
  100. max_samples=max_samples_int,
  101. custom_dataset_path=custom_path,
  102. num_epochs=int(num_epochs),
  103. learning_rate=float(learning_rate),
  104. lora_r=int(lora_r),
  105. lora_alpha=int(lora_alpha),
  106. batch_size=int(batch_size),
  107. grad_accum=int(grad_accum),
  108. max_seq_length=int(max_seq_length),
  109. system_prompt=system_prompt,
  110. )
  111. elif method == "Abliteration (No Training)":
  112. abliterate(hub_model_id=hub_model_id.strip())
  113. except Exception as e:
  114. status_data = {
  115. "status": "error",
  116. "detail": str(e),
  117. "progress": 0.0,
  118. "metrics": {},
  119. }
  120. Path(STATUS_FILE).write_text(json.dumps(status_data))
  121. finally:
  122. training_active = False
  123. training_thread = threading.Thread(target=run_training, daemon=True)
  124. training_thread.start()
  125. return "🚀 Training launched! Monitor progress below."
  126. def start_merge(hub_model_id_merge):
  127. global training_active
  128. if training_active:
  129. return "⚠️ Another process is running."
  130. if not hub_model_id_merge.strip():
  131. return "❌ Hub Model ID for merged model is required"
  132. def run_merge():
  133. global training_active
  134. training_active = True
  135. try:
  136. from scripts.merge_and_push import merge_and_push
  137. merge_and_push(hub_model_id=hub_model_id_merge.strip())
  138. except Exception as e:
  139. Path(STATUS_FILE).write_text(
  140. json.dumps(
  141. {
  142. "status": "error",
  143. "detail": str(e),
  144. "progress": 0.0,
  145. "metrics": {},
  146. }
  147. )
  148. )
  149. finally:
  150. training_active = False
  151. threading.Thread(target=run_merge, daemon=True).start()
  152. return "🔀 Merge started! This will take a while for an 80B model."
  153. def poll_status():
  154. """Returns updated status for the UI."""
  155. s = read_status()
  156. status_emoji = {
  157. "idle": "⏸️",
  158. "initializing": "⚙️",
  159. "loading_model": "📥",
  160. "training": "🏋️",
  161. "saving": "💾",
  162. "saving_checkpoint": "💾",
  163. "pushing": "🚀",
  164. "merging": "🔀",
  165. "abliterating": "✂️",
  166. "completed": "✅",
  167. "error": "❌",
  168. }.get(s["status"], "❓")
  169. status_text = f"{status_emoji} **{s['status'].upper()}**: {s['detail']}"
  170. progress = s["progress"]
  171. metrics_text = ""
  172. if s.get("metrics"):
  173. m = s["metrics"]
  174. parts = []
  175. if "step" in m:
  176. parts.append(f"Step: {m['step']}/{m.get('total_steps', '?')}")
  177. if "epoch" in m:
  178. parts.append(f"Epoch: {m['epoch']}")
  179. if "loss" in m:
  180. parts.append(f"Loss: {m['loss']}")
  181. if "learning_rate" in m:
  182. parts.append(f"LR: {m['learning_rate']:.2e}")
  183. if "grad_norm" in m:
  184. parts.append(f"Grad Norm: {m['grad_norm']}")
  185. if "train_loss" in m:
  186. parts.append(f"Final Loss: {m['train_loss']}")
  187. if "train_runtime" in m:
  188. parts.append(f"Runtime: {m['train_runtime']}s")
  189. metrics_text = " | ".join(parts)
  190. logs = read_logs(50)
  191. gpu = get_gpu_info()
  192. return status_text, progress, metrics_text, logs, gpu
  193. # ---------------------------------------------------------------------------
  194. # Gradio UI
  195. # ---------------------------------------------------------------------------
  196. def build_ui():
  197. config = load_config()
  198. dataset_names = list(config["datasets"].keys())
  199. with gr.Blocks(
  200. title="Qwen3-Coder-Next Uncensored Fine-Tuner",
  201. ) as app:
  202. gr.Markdown("# Qwen3-Coder-Next Uncensored Fine-Tuner")
  203. gr.Markdown(
  204. "*QLoRA fine-tuning & abliteration for Qwen3-Coder-Next (80B MoE / 3B active)*",
  205. )
  206. with gr.Tabs():
  207. # ==================================================================
  208. # TAB 1: Training Configuration
  209. # ==================================================================
  210. with gr.Tab("🎯 Train"):
  211. with gr.Row():
  212. with gr.Column(scale=1):
  213. gr.Markdown("### Method")
  214. method = gr.Radio(
  215. ["QLoRA Fine-Tuning", "Abliteration (No Training)"],
  216. value="QLoRA Fine-Tuning",
  217. label="Uncensoring Method",
  218. )
  219. gr.Markdown("### Hub Settings")
  220. hub_model_id = gr.Textbox(
  221. label="Hub Model ID",
  222. placeholder="your-username/qwen3-coder-uncensored",
  223. info="Where to push the trained model on Hugging Face",
  224. )
  225. gr.Markdown("### Dataset")
  226. dataset_choice = gr.Dropdown(
  227. choices=dataset_names,
  228. value=dataset_names[0],
  229. label="Pre-built Dataset",
  230. info="Choose an uncensored dataset",
  231. )
  232. custom_dataset = gr.File(
  233. label="Or Upload Custom Dataset (JSON/CSV/Parquet)",
  234. file_types=[".json", ".jsonl", ".csv", ".parquet"],
  235. )
  236. max_samples = gr.Number(
  237. label="Max Samples (0 = use all)",
  238. value=0,
  239. info="Limit dataset size for faster experiments",
  240. )
  241. gr.Markdown("### System Prompt")
  242. system_prompt = gr.Textbox(
  243. label="System Prompt",
  244. value=config.get("system_prompt", ""),
  245. lines=3,
  246. info="Embedded in every training sample",
  247. )
  248. with gr.Column(scale=1):
  249. gr.Markdown("### Training Hyperparameters")
  250. num_epochs = gr.Slider(1, 10, value=2, step=1, label="Epochs")
  251. learning_rate = gr.Number(label="Learning Rate", value=2e-4)
  252. batch_size = gr.Slider(
  253. 1, 8, value=1, step=1, label="Batch Size per Device"
  254. )
  255. grad_accum = gr.Slider(
  256. 1, 64, value=16, step=1, label="Gradient Accumulation Steps"
  257. )
  258. max_seq_length = gr.Slider(
  259. 512, 8192, value=2048, step=256, label="Max Sequence Length"
  260. )
  261. gr.Markdown("### LoRA Configuration")
  262. lora_r = gr.Slider(
  263. 8, 256, value=64, step=8, label="LoRA Rank (r)"
  264. )
  265. lora_alpha = gr.Slider(
  266. 16, 512, value=128, step=16, label="LoRA Alpha"
  267. )
  268. with gr.Row():
  269. train_btn = gr.Button(
  270. "🚀 Start Training", variant="primary", size="lg"
  271. )
  272. output_msg = gr.Textbox(label="Status", interactive=False)
  273. train_btn.click(
  274. fn=start_training,
  275. inputs=[
  276. method,
  277. dataset_choice,
  278. custom_dataset,
  279. hub_model_id,
  280. max_samples,
  281. num_epochs,
  282. learning_rate,
  283. lora_r,
  284. lora_alpha,
  285. batch_size,
  286. grad_accum,
  287. max_seq_length,
  288. system_prompt,
  289. ],
  290. outputs=output_msg,
  291. )
  292. # ==================================================================
  293. # TAB 2: Monitoring
  294. # ==================================================================
  295. with gr.Tab("📊 Monitor"):
  296. with gr.Row():
  297. status_text = gr.Markdown("⏸️ **IDLE**: Ready to start training")
  298. refresh_btn = gr.Button("🔄 Refresh", size="sm")
  299. progress_bar = gr.Slider(
  300. 0, 1, value=0, label="Progress", interactive=False
  301. )
  302. metrics_display = gr.Textbox(
  303. label="Training Metrics", interactive=False, lines=2
  304. )
  305. gpu_info = gr.Textbox(label="GPU Info", interactive=False, lines=4)
  306. log_display = gr.Textbox(
  307. label="Training Logs (last 50 lines)", interactive=False, lines=20
  308. )
  309. refresh_btn.click(
  310. fn=poll_status,
  311. outputs=[
  312. status_text,
  313. progress_bar,
  314. metrics_display,
  315. log_display,
  316. gpu_info,
  317. ],
  318. )
  319. # Auto-refresh every 10 seconds using Timer
  320. timer = gr.Timer(value=10)
  321. timer.tick(
  322. fn=poll_status,
  323. outputs=[
  324. status_text,
  325. progress_bar,
  326. metrics_display,
  327. log_display,
  328. gpu_info,
  329. ],
  330. )
  331. # ==================================================================
  332. # TAB 3: Merge & Push
  333. # ==================================================================
  334. with gr.Tab("🔀 Merge LoRA"):
  335. gr.Markdown("""
  336. ### Merge LoRA Adapter into Base Model
  337. After training completes, use this to merge the LoRA adapter into the base model
  338. and push a standalone model to the Hub.
  339. **⚠️ Warning**: This loads the full 80B model in bfloat16. Requires ~160GB RAM.
  340. For the A100 80GB, this may require CPU offloading.
  341. """)
  342. hub_model_id_merge = gr.Textbox(
  343. label="Hub Model ID for Merged Model",
  344. placeholder="your-username/qwen3-coder-uncensored-merged",
  345. info="This should be DIFFERENT from the adapter model ID",
  346. )
  347. merge_btn = gr.Button("🔀 Merge & Push", variant="primary")
  348. merge_output = gr.Textbox(label="Status", interactive=False)
  349. merge_btn.click(
  350. fn=start_merge,
  351. inputs=[hub_model_id_merge],
  352. outputs=merge_output,
  353. )
  354. # ==================================================================
  355. # TAB 4: Help
  356. # ==================================================================
  357. with gr.Tab("❓ Help"):
  358. gr.Markdown("""
  359. ## Quick Start Guide
  360. ### 1. Set Your HF Token
  361. Go to **Space Settings → Secrets** and add:
  362. - `HF_TOKEN`: Your Hugging Face token with **write** access
  363. - `WANDB_API_KEY` (optional): For training metrics dashboard
  364. ### 2. Choose Your Method
  365. | Method | Pros | Cons |
  366. |--------|------|------|
  367. | **QLoRA Fine-Tuning** | More control, better results, uses dataset | Takes hours, needs GPU |
  368. | **Abliteration** | Fast (~30 min), no dataset needed | Less precise, may affect model quality |
  369. ### 3. QLoRA Fine-Tuning
  370. 1. Select a dataset (or upload your own)
  371. 2. Set your Hub Model ID (e.g., `your-username/qwen3-uncensored`)
  372. 3. Adjust hyperparameters if needed
  373. 4. Click **Start Training**
  374. 5. Monitor progress in the **Monitor** tab
  375. 6. When done, optionally merge LoRA in the **Merge** tab
  376. ### 4. Custom Dataset Format
  377. Your JSON/JSONL should have these fields:
  378. ```json
  379. {
  380. "instruction": "User's question/request",
  381. "output": "Uncensored response"
  382. }
  383. ```
  384. Also supported fields: `prompt`/`response`, `input`/`answer`, `user`/`assistant`
  385. ### 5. Recommended Settings
  386. - **First run**: Use `max_samples=1000` to test the pipeline quickly
  387. - **Full training**: Set `max_samples=0` (all data), `epochs=2-3`
  388. - **Better quality**: Increase `lora_r` to 128, `lora_alpha` to 256
  389. - **Faster training**: Decrease `max_seq_length` to 1024
  390. ### 6. After Training
  391. The LoRA adapter is automatically pushed to your Hub repo.
  392. You can:
  393. - **Use the adapter directly** with PEFT (lightweight)
  394. - **Merge into base model** using the Merge tab (standalone model)
  395. ### Using the Adapter
  396. ```python
  397. from transformers import AutoModelForCausalLM, AutoTokenizer
  398. from peft import PeftModel
  399. base = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-Coder-Next", device_map="auto")
  400. model = PeftModel.from_pretrained(base, "your-username/qwen3-uncensored")
  401. tokenizer = AutoTokenizer.from_pretrained("your-username/qwen3-uncensored")
  402. ```
  403. """)
  404. return app
# Script entry point: build the Blocks app and start the web server.
if __name__ == "__main__":
    app = build_ui()
    app.launch(
        server_name="0.0.0.0",  # bind on all network interfaces, not just localhost
        server_port=7860,
        share=False,
        # NOTE(review): passing `theme` to launch() depends on the installed
        # Gradio version accepting it there — confirm against the pinned version.
        theme=gr.themes.Soft(primary_hue="red", secondary_hue="orange"),
    )