train.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705
  1. """
  2. Qwen3-Coder-Next Uncensored Fine-Tuning Script
  3. QLoRA 4-bit fine-tuning with TRL's SFTTrainer
  4. """
  5. import os
  6. import sys
  7. import json
  8. import yaml
  9. import torch
  10. import logging
  11. from pathlib import Path
  12. from typing import Optional
  13. from dataclasses import dataclass
  14. from transformers import (
  15. AutoModelForCausalLM,
  16. AutoTokenizer,
  17. BitsAndBytesConfig,
  18. TrainerCallback,
  19. )
  20. from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
  21. from trl import SFTTrainer, SFTConfig
  22. from datasets import load_dataset, Dataset, concatenate_datasets
  23. from huggingface_hub import HfApi
  24. logging.basicConfig(
  25. level=logging.INFO,
  26. format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
  27. handlers=[
  28. logging.StreamHandler(sys.stdout),
  29. logging.FileHandler("/home/user/training.log", mode="a"),
  30. ],
  31. )
  32. logger = logging.getLogger(__name__)
  33. # ---------------------------------------------------------------------------
  34. # Status file for Gradio UI to poll
  35. # ---------------------------------------------------------------------------
  36. STATUS_FILE = "/home/user/training_status.json"
  37. def write_status(
  38. status: str, detail: str = "", progress: float = 0.0, metrics: dict | None = None
  39. ):
  40. data = {
  41. "status": status,
  42. "detail": detail,
  43. "progress": progress,
  44. "metrics": metrics or {},
  45. }
  46. Path(STATUS_FILE).write_text(json.dumps(data))
  47. class StatusCallback(TrainerCallback):
  48. """Streams training metrics back to the Gradio UI via a JSON status file."""
  49. def __init__(self, total_steps: int):
  50. self.total_steps = max(total_steps, 1)
  51. def on_log(self, args, state, control, logs=None, **kwargs):
  52. if logs is None:
  53. return
  54. progress = min(state.global_step / self.total_steps, 1.0)
  55. metrics = {
  56. "step": state.global_step,
  57. "total_steps": self.total_steps,
  58. "epoch": round(state.epoch or 0, 2),
  59. "loss": round(logs.get("loss", 0), 4),
  60. "learning_rate": logs.get("learning_rate", 0),
  61. "grad_norm": round(logs.get("grad_norm", 0), 4),
  62. }
  63. write_status(
  64. "training",
  65. f"Step {state.global_step}/{self.total_steps}",
  66. progress,
  67. metrics,
  68. )
  69. def on_save(self, args, state, control, **kwargs):
  70. write_status(
  71. "saving_checkpoint", f"Saved checkpoint at step {state.global_step}"
  72. )
  73. def on_train_end(self, args, state, control, **kwargs):
  74. write_status("completed", "Training finished!", 1.0)
  75. # ---------------------------------------------------------------------------
  76. # Config helpers
  77. # ---------------------------------------------------------------------------
  78. def load_config(config_path: str = "config.yaml") -> dict:
  79. with open(config_path) as f:
  80. return yaml.safe_load(f)
  81. # ---------------------------------------------------------------------------
  82. # Dataset preparation
  83. # ---------------------------------------------------------------------------
  84. def prepare_dataset(
  85. dataset_name: str,
  86. config: dict,
  87. tokenizer,
  88. system_prompt: str,
  89. max_samples: Optional[int] = None,
  90. custom_dataset_path: Optional[str] = None,
  91. ) -> Dataset:
  92. """Load and format dataset into chat-template strings."""
  93. if custom_dataset_path:
  94. logger.info(f"Loading custom dataset from {custom_dataset_path}")
  95. if custom_dataset_path.endswith(".json") or custom_dataset_path.endswith(
  96. ".jsonl"
  97. ):
  98. ds = load_dataset("json", data_files=custom_dataset_path, split="train")
  99. elif custom_dataset_path.endswith(".csv"):
  100. ds = load_dataset("csv", data_files=custom_dataset_path, split="train")
  101. elif custom_dataset_path.endswith(".parquet"):
  102. ds = load_dataset("parquet", data_files=custom_dataset_path, split="train")
  103. else:
  104. ds = load_dataset(custom_dataset_path, split="train")
  105. # Auto-detect fields
  106. cols = ds.column_names
  107. # Check if it's a conversations-format dataset
  108. if "conversations" in cols:
  109. ds_format = "conversations"
  110. conversations_field = "conversations"
  111. human_key = "human"
  112. assistant_key = "gpt"
  113. instruction_field = None
  114. output_field = None
  115. system_field = None
  116. else:
  117. ds_format = "flat"
  118. conversations_field = None
  119. human_key = None
  120. assistant_key = None
  121. instruction_field = next(
  122. (
  123. c
  124. for c in ["instruction", "prompt", "input", "question", "user"]
  125. if c in cols
  126. ),
  127. cols[0],
  128. )
  129. output_field = next(
  130. (
  131. c
  132. for c in ["output", "response", "answer", "chosen", "assistant"]
  133. if c in cols
  134. ),
  135. cols[1] if len(cols) > 1 else cols[0],
  136. )
  137. system_field = "system" if "system" in cols else None
  138. else:
  139. ds_cfg = config["datasets"].get(dataset_name)
  140. if ds_cfg is None:
  141. raise ValueError(
  142. f"Unknown dataset: {dataset_name}. Available: {list(config['datasets'].keys())}"
  143. )
  144. logger.info(f"Loading dataset: {ds_cfg['name']}")
  145. ds = load_dataset(ds_cfg["name"], split=ds_cfg["split"])
  146. ds_format = ds_cfg.get("format", "flat")
  147. if ds_format == "conversations":
  148. conversations_field = ds_cfg.get("conversations_field", "conversations")
  149. human_key = ds_cfg.get("human_key", "human")
  150. assistant_key = ds_cfg.get("assistant_key", "gpt")
  151. instruction_field = None
  152. output_field = None
  153. system_field = None
  154. else:
  155. conversations_field = None
  156. human_key = None
  157. assistant_key = None
  158. instruction_field = ds_cfg["instruction_field"]
  159. output_field = ds_cfg["output_field"]
  160. system_field = ds_cfg.get("system_field")
  161. if max_samples and max_samples < len(ds):
  162. ds = ds.shuffle(seed=42).select(range(max_samples))
  163. logger.info(f"Dataset loaded: {len(ds)} samples (format: {ds_format})")
  164. if ds_format == "conversations":
  165. def format_chat(example):
  166. """Convert conversations [{from, value}] into chat template."""
  167. convos = example[conversations_field]
  168. messages = []
  169. # Prepend system prompt
  170. if system_prompt:
  171. messages.append({"role": "system", "content": system_prompt})
  172. for turn in convos:
  173. role_key = turn.get("from", "")
  174. value = turn.get("value", "")
  175. if role_key == human_key:
  176. messages.append({"role": "user", "content": value})
  177. elif role_key == assistant_key:
  178. messages.append({"role": "assistant", "content": value})
  179. # Skip samples with no user/assistant turns
  180. if len(messages) < 2 or not any(m["role"] == "assistant" for m in messages):
  181. return {"text": ""}
  182. text = tokenizer.apply_chat_template(
  183. messages, tokenize=False, add_generation_prompt=False
  184. )
  185. return {"text": text}
  186. else:
  187. def format_chat(example):
  188. messages = []
  189. # System prompt
  190. if system_field and example.get(system_field):
  191. messages.append({"role": "system", "content": example[system_field]})
  192. elif system_prompt:
  193. messages.append({"role": "system", "content": system_prompt})
  194. # User message
  195. messages.append(
  196. {"role": "user", "content": str(example[instruction_field])}
  197. )
  198. # Assistant response
  199. messages.append(
  200. {"role": "assistant", "content": str(example[output_field])}
  201. )
  202. text = tokenizer.apply_chat_template(
  203. messages, tokenize=False, add_generation_prompt=False
  204. )
  205. return {"text": text}
  206. ds = ds.map(
  207. format_chat,
  208. remove_columns=ds.column_names,
  209. num_proc=4,
  210. desc="Formatting dataset",
  211. )
  212. # Filter out empty texts (e.g. conversations with no assistant turn)
  213. ds = ds.filter(lambda x: len(x["text"].strip()) > 0)
  214. logger.info(f"After filtering: {len(ds)} samples")
  215. return ds
  216. # ---------------------------------------------------------------------------
  217. # Main training function
  218. # ---------------------------------------------------------------------------
  219. def train(
  220. dataset_choice: str = "wizard_vicuna",
  221. hub_model_id: str = "",
  222. max_samples: Optional[int] = None,
  223. custom_dataset_path: Optional[str] = None,
  224. num_epochs: int = 2,
  225. learning_rate: float = 2e-4,
  226. lora_r: int = 64,
  227. lora_alpha: int = 128,
  228. batch_size: int = 1,
  229. grad_accum: int = 16,
  230. max_seq_length: int = 2048,
  231. system_prompt: str = "",
  232. ):
  233. """Run the full QLoRA fine-tuning pipeline."""
  234. config = load_config()
  235. write_status("initializing", "Loading configuration...")
  236. # Override config with function params
  237. config["training"]["num_train_epochs"] = num_epochs
  238. config["training"]["learning_rate"] = learning_rate
  239. config["training"]["per_device_train_batch_size"] = batch_size
  240. config["training"]["gradient_accumulation_steps"] = grad_accum
  241. config["training"]["max_seq_length"] = max_seq_length
  242. config["lora"]["r"] = lora_r
  243. config["lora"]["lora_alpha"] = lora_alpha
  244. if not system_prompt:
  245. system_prompt = config.get("system_prompt", "")
  246. hf_token = os.environ.get("HF_TOKEN")
  247. if not hf_token:
  248. write_status(
  249. "error", "HF_TOKEN secret not set! Add it in Space Settings → Secrets."
  250. )
  251. raise ValueError("HF_TOKEN environment variable is required")
  252. # -----------------------------------------------------------------------
  253. # 1. Load tokenizer
  254. # -----------------------------------------------------------------------
  255. write_status("initializing", "Loading tokenizer...")
  256. model_name = config["model"]["name"]
  257. tokenizer = AutoTokenizer.from_pretrained(
  258. model_name,
  259. trust_remote_code=config["model"]["trust_remote_code"],
  260. token=hf_token,
  261. )
  262. if tokenizer.pad_token is None:
  263. tokenizer.pad_token = tokenizer.eos_token
  264. tokenizer.padding_side = "right"
  265. # -----------------------------------------------------------------------
  266. # 2. Load dataset
  267. # -----------------------------------------------------------------------
  268. write_status("initializing", "Loading and formatting dataset...")
  269. dataset = prepare_dataset(
  270. dataset_name=dataset_choice,
  271. config=config,
  272. tokenizer=tokenizer,
  273. system_prompt=system_prompt,
  274. max_samples=max_samples,
  275. custom_dataset_path=custom_dataset_path,
  276. )
  277. logger.info(f"Formatted dataset: {len(dataset)} samples")
  278. logger.info(f"Sample:\n{dataset[0]['text'][:500]}...")
  279. # -----------------------------------------------------------------------
  280. # 3. Load model in 4-bit
  281. # -----------------------------------------------------------------------
  282. write_status(
  283. "loading_model",
  284. "Loading model in 4-bit quantization... (this takes a while)",
  285. )
  286. q_cfg = config["quantization"]
  287. bnb_config = BitsAndBytesConfig(
  288. load_in_4bit=q_cfg["load_in_4bit"],
  289. bnb_4bit_quant_type=q_cfg["bnb_4bit_quant_type"],
  290. bnb_4bit_compute_dtype=getattr(torch, q_cfg["bnb_4bit_compute_dtype"]),
  291. bnb_4bit_use_double_quant=q_cfg["bnb_4bit_use_double_quant"],
  292. )
  293. # Pick best available attention: flash_attention_2 > sdpa > eager
  294. if torch.cuda.is_available():
  295. try:
  296. import flash_attn # noqa: F401
  297. attn_impl = "flash_attention_2"
  298. except ImportError:
  299. attn_impl = "sdpa" # PyTorch native, no extra install needed
  300. else:
  301. attn_impl = "eager"
  302. logger.info(f"Using attention implementation: {attn_impl}")
  303. # Log transformers version to confirm qwen3_next support
  304. import transformers
  305. logger.info(f"transformers version: {transformers.__version__}")
  306. # Pre-quantized fallback model for Qwen3-Next architecture
  307. PRE_QUANTIZED_FALLBACK = "unsloth/Qwen3-Next-80B-A3B-Instruct-bnb-4bit"
  308. # Force ALL layers onto GPU 0. device_map="auto" is too conservative
  309. # with large MoE models and offloads to CPU where bnb 4-bit can't run.
  310. # 80B params in 4-bit ≈ 40GB — fits comfortably on A100 80GB.
  311. try:
  312. logger.info(
  313. f"Attempting to load {model_name} with on-the-fly 4-bit quantization..."
  314. )
  315. model = AutoModelForCausalLM.from_pretrained(
  316. model_name,
  317. quantization_config=bnb_config,
  318. device_map={"": 0},
  319. trust_remote_code=config["model"]["trust_remote_code"],
  320. torch_dtype=getattr(torch, config["model"]["torch_dtype"]),
  321. token=hf_token,
  322. attn_implementation=attn_impl,
  323. )
  324. logger.info(f"Successfully loaded {model_name} with 4-bit quantization")
  325. except Exception as e:
  326. logger.warning(f"On-the-fly 4-bit quantization failed for {model_name}: {e}")
  327. logger.info(f"Falling back to pre-quantized model: {PRE_QUANTIZED_FALLBACK}")
  328. write_status(
  329. "loading_model",
  330. f"On-the-fly quantization failed, loading pre-quantized fallback...",
  331. )
  332. # Pre-quantized model already has bnb 4-bit weights baked in —
  333. # do NOT pass quantization_config again, just load directly.
  334. model = AutoModelForCausalLM.from_pretrained(
  335. PRE_QUANTIZED_FALLBACK,
  336. device_map={"": 0},
  337. trust_remote_code=True,
  338. torch_dtype=torch.bfloat16,
  339. token=hf_token,
  340. attn_implementation=attn_impl,
  341. )
  342. logger.info(
  343. f"Successfully loaded pre-quantized fallback: {PRE_QUANTIZED_FALLBACK}"
  344. )
  345. model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
  346. logger.info("Model loaded and prepared for k-bit training")
  347. # -----------------------------------------------------------------------
  348. # 4. Apply LoRA
  349. # -----------------------------------------------------------------------
  350. write_status("loading_model", "Applying LoRA adapters...")
  351. lora_cfg = config["lora"]
  352. target_modules = lora_cfg["target_modules"]
  353. # target_modules can be a list of strings or the string "all-linear"
  354. if isinstance(target_modules, str) and target_modules != "all-linear":
  355. target_modules = [target_modules]
  356. lora_config = LoraConfig(
  357. r=lora_cfg["r"],
  358. lora_alpha=lora_cfg["lora_alpha"],
  359. target_modules=target_modules,
  360. lora_dropout=lora_cfg["lora_dropout"],
  361. bias=lora_cfg["bias"],
  362. task_type=lora_cfg["task_type"],
  363. )
  364. model = get_peft_model(model, lora_config)
  365. trainable, total = model.get_nb_trainable_parameters()
  366. logger.info(
  367. f"Trainable params: {trainable:,} / {total:,} ({100 * trainable / total:.2f}%)"
  368. )
  369. write_status(
  370. "loading_model",
  371. f"LoRA applied: {trainable:,} trainable params ({100 * trainable / total:.2f}%)",
  372. )
  373. # -----------------------------------------------------------------------
  374. # 5. Training arguments
  375. # -----------------------------------------------------------------------
  376. t_cfg = config["training"]
  377. output_dir = t_cfg["output_dir"]
  378. # Determine hub settings
  379. push_to_hub = bool(hub_model_id)
  380. hub_cfg = config.get("hub", {})
  381. training_args = SFTConfig(
  382. output_dir=output_dir,
  383. num_train_epochs=t_cfg["num_train_epochs"],
  384. per_device_train_batch_size=t_cfg["per_device_train_batch_size"],
  385. gradient_accumulation_steps=t_cfg["gradient_accumulation_steps"],
  386. learning_rate=t_cfg["learning_rate"],
  387. lr_scheduler_type=t_cfg["lr_scheduler_type"],
  388. warmup_ratio=t_cfg["warmup_ratio"],
  389. weight_decay=t_cfg["weight_decay"],
  390. bf16=t_cfg["bf16"],
  391. tf32=t_cfg.get("tf32", True),
  392. max_grad_norm=t_cfg["max_grad_norm"],
  393. logging_steps=t_cfg["logging_steps"],
  394. save_strategy=t_cfg["save_strategy"],
  395. save_steps=t_cfg["save_steps"],
  396. save_total_limit=t_cfg["save_total_limit"],
  397. max_seq_length=t_cfg["max_seq_length"],
  398. gradient_checkpointing=t_cfg["gradient_checkpointing"],
  399. gradient_checkpointing_kwargs=t_cfg.get(
  400. "gradient_checkpointing_kwargs", {"use_reentrant": False}
  401. ),
  402. optim=t_cfg["optim"],
  403. report_to=t_cfg.get("report_to", "none")
  404. if os.environ.get("WANDB_API_KEY")
  405. else "none",
  406. seed=t_cfg["seed"],
  407. dataloader_num_workers=t_cfg.get("dataloader_num_workers", 4),
  408. dataloader_pin_memory=t_cfg.get("dataloader_pin_memory", True),
  409. packing=t_cfg.get("packing", True),
  410. dataset_text_field="text",
  411. push_to_hub=push_to_hub,
  412. hub_model_id=hub_model_id if push_to_hub else None,
  413. hub_strategy=hub_cfg.get("hub_strategy", "checkpoint"),
  414. hub_private_repo=hub_cfg.get("hub_private_repo", False),
  415. hub_token=hf_token,
  416. )
  417. # -----------------------------------------------------------------------
  418. # 6. Trainer
  419. # -----------------------------------------------------------------------
  420. total_steps = (
  421. len(dataset)
  422. // (t_cfg["per_device_train_batch_size"] * t_cfg["gradient_accumulation_steps"])
  423. * t_cfg["num_train_epochs"]
  424. )
  425. trainer = SFTTrainer(
  426. model=model,
  427. args=training_args,
  428. train_dataset=dataset,
  429. processing_class=tokenizer,
  430. callbacks=[StatusCallback(total_steps)],
  431. )
  432. # -----------------------------------------------------------------------
  433. # 7. Train!
  434. # -----------------------------------------------------------------------
  435. write_status("training", "Starting training...", 0.0)
  436. logger.info("=" * 60)
  437. logger.info("TRAINING STARTED")
  438. logger.info(f" Dataset: {len(dataset)} samples")
  439. logger.info(f" Epochs: {t_cfg['num_train_epochs']}")
  440. logger.info(f" Batch size: {t_cfg['per_device_train_batch_size']}")
  441. logger.info(f" Grad accum: {t_cfg['gradient_accumulation_steps']}")
  442. logger.info(
  443. f" Effective batch: {t_cfg['per_device_train_batch_size'] * t_cfg['gradient_accumulation_steps']}"
  444. )
  445. logger.info(f" LR: {t_cfg['learning_rate']}")
  446. logger.info(f" LoRA r={lora_cfg['r']}, alpha={lora_cfg['lora_alpha']}")
  447. logger.info(f" Max seq length: {t_cfg['max_seq_length']}")
  448. logger.info(f" Total steps: ~{total_steps}")
  449. logger.info(f" Push to hub: {push_to_hub} → {hub_model_id}")
  450. logger.info("=" * 60)
  451. train_result = trainer.train()
  452. # -----------------------------------------------------------------------
  453. # 8. Save final adapter
  454. # -----------------------------------------------------------------------
  455. write_status("saving", "Saving final LoRA adapter...")
  456. final_adapter_path = os.path.join(output_dir, "final_adapter")
  457. trainer.save_model(final_adapter_path)
  458. tokenizer.save_pretrained(final_adapter_path)
  459. # Push adapter to Hub
  460. if push_to_hub and hub_model_id:
  461. write_status("pushing", f"Pushing LoRA adapter to {hub_model_id}...")
  462. api = HfApi(token=hf_token)
  463. api.create_repo(
  464. hub_model_id, exist_ok=True, private=hub_cfg.get("hub_private_repo", False)
  465. )
  466. api.upload_folder(
  467. folder_path=final_adapter_path,
  468. repo_id=hub_model_id,
  469. commit_message="Upload QLoRA adapter — Qwen3-Coder-Next uncensored",
  470. )
  471. logger.info(f"Adapter pushed to https://huggingface.co/{hub_model_id}")
  472. write_status(
  473. "completed",
  474. f"Training complete! Adapter saved to {final_adapter_path}",
  475. 1.0,
  476. {
  477. "train_loss": round(train_result.metrics.get("train_loss", 0), 4),
  478. "train_runtime": round(train_result.metrics.get("train_runtime", 0), 1),
  479. "train_samples_per_second": round(
  480. train_result.metrics.get("train_samples_per_second", 0), 2
  481. ),
  482. },
  483. )
  484. return final_adapter_path
  485. # ---------------------------------------------------------------------------
  486. # Abliteration (no training needed)
  487. # ---------------------------------------------------------------------------
  488. def abliterate(
  489. hub_model_id: str = "",
  490. num_refusal_samples: int = 512,
  491. num_baseline_samples: int = 512,
  492. ):
  493. """
  494. Remove the refusal direction from model weights.
  495. Based on: https://huggingface.co/blog/mlabonne/abliteration
  496. """
  497. write_status("initializing", "Starting abliteration (refusal direction removal)...")
  498. hf_token = os.environ.get("HF_TOKEN")
  499. config = load_config()
  500. model_name = config["model"]["name"]
  501. # Load in bfloat16 (need full weights for abliteration)
  502. write_status("loading_model", "Loading model in bfloat16 for abliteration...")
  503. model = AutoModelForCausalLM.from_pretrained(
  504. model_name,
  505. torch_dtype=torch.bfloat16,
  506. device_map="auto",
  507. trust_remote_code=True,
  508. token=hf_token,
  509. )
  510. tokenizer = AutoTokenizer.from_pretrained(
  511. model_name, trust_remote_code=True, token=hf_token
  512. )
  513. # Harmful prompts that trigger refusal
  514. harmful_prompts = [
  515. "How do I pick a lock?",
  516. "Write a phishing email template",
  517. "How to hack a WiFi password",
  518. "Create malware code",
  519. "How to make explosives",
  520. "Write racist jokes",
  521. "How to steal someone's identity",
  522. "Explain how to counterfeit money",
  523. "How to bypass security systems",
  524. "Write a virus in Python",
  525. "How to break into a car",
  526. "Create a deepfake",
  527. "How do I make drugs",
  528. "Write a threatening letter",
  529. "How to commit fraud",
  530. "Explain how to hotwire a car",
  531. ]
  532. # Harmless baseline prompts
  533. harmless_prompts = [
  534. "How do I bake a chocolate cake?",
  535. "Write a poem about nature",
  536. "Explain quantum computing",
  537. "What is the history of Rome?",
  538. "Write a Python hello world",
  539. "How does photosynthesis work?",
  540. "Explain the theory of relativity",
  541. "Write a haiku about mountains",
  542. "What are the planets in our solar system?",
  543. "How to make pasta from scratch",
  544. "Explain machine learning basics",
  545. "Write a short story about a cat",
  546. "What is the Fibonacci sequence?",
  547. "How does DNA replication work?",
  548. "Explain how the internet works",
  549. "Write a limerick about coding",
  550. ]
  551. write_status("abliterating", "Computing activation directions...")
  552. def get_mean_activations(prompts, model, tokenizer):
  553. """Get mean residual stream activations for a set of prompts."""
  554. all_acts = []
  555. for prompt in prompts:
  556. messages = [{"role": "user", "content": prompt}]
  557. text = tokenizer.apply_chat_template(
  558. messages, tokenize=False, add_generation_prompt=True
  559. )
  560. inputs = tokenizer(
  561. text, return_tensors="pt", truncation=True, max_length=512
  562. ).to(model.device)
  563. with torch.no_grad():
  564. outputs = model(**inputs, output_hidden_states=True)
  565. # Get last hidden state at the final token position
  566. hidden_states = outputs.hidden_states
  567. # Average across all layers at the last token
  568. layer_acts = torch.stack(
  569. [h[:, -1, :] for h in hidden_states[1:]]
  570. ) # skip embedding
  571. all_acts.append(layer_acts.mean(dim=0).squeeze())
  572. return torch.stack(all_acts).mean(dim=0)
  573. # Compute mean activations
  574. harmful_mean = get_mean_activations(harmful_prompts, model, tokenizer)
  575. harmless_mean = get_mean_activations(harmless_prompts, model, tokenizer)
  576. # Refusal direction = difference
  577. refusal_dir = harmful_mean - harmless_mean
  578. refusal_dir = refusal_dir / refusal_dir.norm()
  579. write_status("abliterating", "Removing refusal direction from model weights...")
  580. # Remove refusal direction from all layers
  581. for name, param in model.named_parameters():
  582. if "weight" in name and param.ndim == 2:
  583. # Project out the refusal direction
  584. proj = torch.outer(
  585. refusal_dir.to(param.device).to(param.dtype),
  586. refusal_dir.to(param.device).to(param.dtype),
  587. )
  588. if param.shape[0] == proj.shape[0]:
  589. param.data -= param.data @ proj
  590. # Save and push
  591. output_path = "/home/user/merged"
  592. write_status("saving", "Saving abliterated model...")
  593. model.save_pretrained(output_path, safe_serialization=True)
  594. tokenizer.save_pretrained(output_path)
  595. if hub_model_id:
  596. write_status("pushing", f"Pushing abliterated model to {hub_model_id}...")
  597. api = HfApi(token=hf_token)
  598. api.create_repo(hub_model_id, exist_ok=True)
  599. api.upload_folder(
  600. folder_path=output_path,
  601. repo_id=hub_model_id,
  602. commit_message="Upload abliterated Qwen3-Coder-Next (refusal direction removed)",
  603. )
  604. write_status(
  605. "completed", f"Abliteration complete! Model saved to {output_path}", 1.0
  606. )
  607. return output_path
  608. if __name__ == "__main__":
  609. train()