Просмотр исходного кода

Update train.py to Unsloth FastModel

Sameric 4 месяцев назад
Родитель
Commit
bc8de7d068
1 изменённый файл: 172 добавлено и 197 удалено
  1. 172 197
      train.py

+ 172 - 197
train.py

@@ -1,6 +1,7 @@
 """
-Qwen3-Coder-Next Uncensored Fine-Tuning Script
-QLoRA 4-bit fine-tuning with TRL's SFTTrainer
+Qwen3 Uncensored Fine-Tuning Script (Unsloth)
+QLoRA 4-bit fine-tuning with Unsloth's FastModel + TRL SFTTrainer
+Uses Qwen3-30B-A3B (30B total, 3B active MoE) - fits in ~17.5GB VRAM
 """
 
 import os
@@ -11,18 +12,8 @@ import torch
 import logging
 from pathlib import Path
 from typing import Optional
-from dataclasses import dataclass
 
-from transformers import (
-    AutoModelForCausalLM,
-    AutoTokenizer,
-    BitsAndBytesConfig,
-    TrainerCallback,
-)
-from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
-from trl import SFTTrainer, SFTConfig
-from datasets import load_dataset, Dataset, concatenate_datasets
-from huggingface_hub import HfApi
+from transformers import TrainerCallback
 
 logging.basicConfig(
     level=logging.INFO,
@@ -108,8 +99,9 @@ def prepare_dataset(
     system_prompt: str,
     max_samples: Optional[int] = None,
     custom_dataset_path: Optional[str] = None,
-) -> Dataset:
+):
     """Load and format dataset into chat-template strings."""
+    from datasets import load_dataset
 
     if custom_dataset_path:
         logger.info(f"Loading custom dataset from {custom_dataset_path}")
@@ -274,7 +266,7 @@ def train(
     max_seq_length: int = 512,
     system_prompt: str = "",
 ):
-    """Run the full QLoRA fine-tuning pipeline."""
+    """Run QLoRA fine-tuning using Unsloth FastModel."""
 
     config = load_config()
     write_status("initializing", "Loading configuration...")
@@ -294,27 +286,48 @@ def train(
     hf_token = os.environ.get("HF_TOKEN")
     if not hf_token:
         write_status(
-            "error", "HF_TOKEN secret not set! Add it in Space Settings  Secrets."
+            "error", "HF_TOKEN secret not set! Add it in Space Settings -> Secrets."
         )
         raise ValueError("HF_TOKEN environment variable is required")
 
-    # -----------------------------------------------------------------------
-    # 1. Load tokenizer
-    # -----------------------------------------------------------------------
-    write_status("initializing", "Loading tokenizer...")
+    # -------------------------------------------------------------------
+    # 1. Load model with Unsloth FastModel (4-bit QLoRA)
+    # -------------------------------------------------------------------
+    write_status(
+        "loading_model",
+        "Loading Qwen3-30B-A3B with Unsloth (4-bit)... "
+        "MoE models download full 16-bit then convert to 4-bit on-the-fly.",
+    )
+
+    from unsloth import FastModel
+
     model_name = config["model"]["name"]
-    tokenizer = AutoTokenizer.from_pretrained(
-        model_name,
-        trust_remote_code=config["model"]["trust_remote_code"],
+    logger.info(f"Loading model: {model_name} with Unsloth FastModel")
+
+    model, tokenizer = FastModel.from_pretrained(
+        model_name=model_name,
+        max_seq_length=max_seq_length,
+        load_in_4bit=True,
+        load_in_8bit=False,
+        full_finetuning=False,
         token=hf_token,
     )
+
+    logger.info("Model loaded successfully with Unsloth")
+
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token
     tokenizer.padding_side = "right"
 
-    # -----------------------------------------------------------------------
+    if torch.cuda.is_available():
+        logger.info(
+            f"Post-load VRAM: {torch.cuda.memory_allocated(0) / 1e9:.1f} GB allocated, "
+            f"{torch.cuda.memory_reserved(0) / 1e9:.1f} GB reserved"
+        )
+
+    # -------------------------------------------------------------------
     # 2. Load dataset
-    # -----------------------------------------------------------------------
+    # -------------------------------------------------------------------
     write_status("initializing", "Loading and formatting dataset...")
     dataset = prepare_dataset(
         dataset_name=dataset_choice,
@@ -327,120 +340,74 @@ def train(
     logger.info(f"Formatted dataset: {len(dataset)} samples")
     logger.info(f"Sample:\n{dataset[0]['text'][:500]}...")
 
-    # -----------------------------------------------------------------------
-    # 3. Load model in 4-bit
-    # -----------------------------------------------------------------------
-    write_status(
-        "loading_model",
-        "Loading model in 4-bit quantization... (this takes a while)",
-    )
-
-    q_cfg = config["quantization"]
-    bnb_config = BitsAndBytesConfig(
-        load_in_4bit=q_cfg["load_in_4bit"],
-        bnb_4bit_quant_type=q_cfg["bnb_4bit_quant_type"],
-        bnb_4bit_compute_dtype=getattr(torch, q_cfg["bnb_4bit_compute_dtype"]),
-        bnb_4bit_use_double_quant=q_cfg["bnb_4bit_use_double_quant"],
-    )
-
-    # Pick best available attention: flash_attention_2 > sdpa > eager
-    if torch.cuda.is_available():
-        try:
-            import flash_attn  # noqa: F401
-
-            attn_impl = "flash_attention_2"
-        except ImportError:
-            attn_impl = "sdpa"  # PyTorch native, no extra install needed
-    else:
-        attn_impl = "eager"
-    logger.info(f"Using attention implementation: {attn_impl}")
-
-    # Log transformers version to confirm qwen3_next support
-    import transformers
-
-    logger.info(f"transformers version: {transformers.__version__}")
-
-    # Pre-quantized fallback model for Qwen3-Next architecture
-    PRE_QUANTIZED_FALLBACK = "unsloth/Qwen3-Next-80B-A3B-Instruct-bnb-4bit"
-
-    # Force ALL layers onto GPU 0.  bnb 4-bit layers cannot run on CPU.
-    # With expandable_segments=True (set via PYTORCH_CUDA_ALLOC_CONF env),
-    # PyTorch won't pre-reserve all 80GB upfront, leaving room for activations.
-    try:
-        logger.info(
-            f"Attempting to load {model_name} with on-the-fly 4-bit quantization..."
-        )
-        model = AutoModelForCausalLM.from_pretrained(
-            model_name,
-            quantization_config=bnb_config,
-            device_map={"": 0},
-            trust_remote_code=config["model"]["trust_remote_code"],
-            torch_dtype=getattr(torch, config["model"]["torch_dtype"]),
-            token=hf_token,
-            attn_implementation=attn_impl,
-        )
-        logger.info(f"Successfully loaded {model_name} with 4-bit quantization")
-    except Exception as e:
-        logger.warning(f"On-the-fly 4-bit quantization failed for {model_name}: {e}")
-        logger.info(f"Falling back to pre-quantized model: {PRE_QUANTIZED_FALLBACK}")
-        write_status(
-            "loading_model",
-            f"On-the-fly quantization failed, loading pre-quantized fallback...",
-        )
-        # Pre-quantized model already has bnb 4-bit weights baked in —
-        # do NOT pass quantization_config again, just load directly.
-        model = AutoModelForCausalLM.from_pretrained(
-            PRE_QUANTIZED_FALLBACK,
-            device_map={"": 0},
-            trust_remote_code=True,
-            torch_dtype=torch.bfloat16,
-            token=hf_token,
-            attn_implementation=attn_impl,
-        )
-        logger.info(
-            f"Successfully loaded pre-quantized fallback: {PRE_QUANTIZED_FALLBACK}"
-        )
-
-    model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
-    logger.info("Model loaded and prepared for k-bit training")
-
-    # -----------------------------------------------------------------------
-    # 4. Apply LoRA
-    # -----------------------------------------------------------------------
-    write_status("loading_model", "Applying LoRA adapters...")
+    # -------------------------------------------------------------------
+    # 3. Apply LoRA via Unsloth
+    # -------------------------------------------------------------------
+    write_status("loading_model", "Applying LoRA adapters via Unsloth...")
 
     lora_cfg = config["lora"]
-    target_modules = lora_cfg["target_modules"]
-    # target_modules can be a list of strings or the string "all-linear"
-    if isinstance(target_modules, str) and target_modules != "all-linear":
-        target_modules = [target_modules]
-
-    lora_config = LoraConfig(
+    target_modules = lora_cfg.get(
+        "target_modules",
+        [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+            "o_proj",
+            "gate_proj",
+            "up_proj",
+            "down_proj",
+        ],
+    )
+    # Unsloth get_peft_model expects a list, not "all-linear"
+    if isinstance(target_modules, str) and target_modules == "all-linear":
+        target_modules = [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+            "o_proj",
+            "gate_proj",
+            "up_proj",
+            "down_proj",
+        ]
+
+    model = FastModel.get_peft_model(
+        model,
         r=lora_cfg["r"],
-        lora_alpha=lora_cfg["lora_alpha"],
         target_modules=target_modules,
-        lora_dropout=lora_cfg["lora_dropout"],
-        bias=lora_cfg["bias"],
-        task_type=lora_cfg["task_type"],
+        lora_alpha=lora_cfg["lora_alpha"],
+        lora_dropout=lora_cfg.get("lora_dropout", 0),
+        bias=lora_cfg.get("bias", "none"),
+        use_gradient_checkpointing="unsloth",  # Unsloth optimized
+        random_state=42,
     )
 
-    model = get_peft_model(model, lora_config)
-    trainable, total = model.get_nb_trainable_parameters()
+    # Log trainable params
+    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+    total_params = sum(p.numel() for p in model.parameters())
     logger.info(
-        f"Trainable params: {trainable:,} / {total:,} ({100 * trainable / total:.2f}%)"
+        f"Trainable params: {trainable_params:,} / {total_params:,} "
+        f"({100 * trainable_params / total_params:.2f}%)"
     )
     write_status(
         "loading_model",
-        f"LoRA applied: {trainable:,} trainable params ({100 * trainable / total:.2f}%)",
+        f"LoRA applied: {trainable_params:,} trainable params "
+        f"({100 * trainable_params / total_params:.2f}%)",
     )
 
-    # -----------------------------------------------------------------------
-    # 5. Training arguments
-    # -----------------------------------------------------------------------
+    if torch.cuda.is_available():
+        logger.info(
+            f"Post-LoRA VRAM: {torch.cuda.memory_allocated(0) / 1e9:.1f} GB allocated, "
+            f"{torch.cuda.memory_reserved(0) / 1e9:.1f} GB reserved"
+        )
+
+    # -------------------------------------------------------------------
+    # 4. Training arguments
+    # -------------------------------------------------------------------
+    from trl import SFTTrainer, SFTConfig
+
     t_cfg = config["training"]
     output_dir = t_cfg["output_dir"]
 
-    # Determine hub settings
     push_to_hub = bool(hub_model_id)
     hub_cfg = config.get("hub", {})
 
@@ -450,44 +417,31 @@ def train(
         per_device_train_batch_size=t_cfg["per_device_train_batch_size"],
         gradient_accumulation_steps=t_cfg["gradient_accumulation_steps"],
         learning_rate=t_cfg["learning_rate"],
-        lr_scheduler_type=t_cfg["lr_scheduler_type"],
-        warmup_ratio=t_cfg["warmup_ratio"],
-        weight_decay=t_cfg["weight_decay"],
-        bf16=t_cfg["bf16"],
+        lr_scheduler_type=t_cfg.get("lr_scheduler_type", "cosine"),
+        warmup_ratio=t_cfg.get("warmup_ratio", 0.05),
+        weight_decay=t_cfg.get("weight_decay", 0.01),
+        bf16=t_cfg.get("bf16", True),
         tf32=t_cfg.get("tf32", True),
-        max_grad_norm=t_cfg["max_grad_norm"],
-        logging_steps=t_cfg["logging_steps"],
-        save_strategy=t_cfg["save_strategy"],
-        save_steps=t_cfg["save_steps"],
-        save_total_limit=t_cfg["save_total_limit"],
-        max_length=t_cfg["max_seq_length"],
-        gradient_checkpointing=t_cfg["gradient_checkpointing"],
-        gradient_checkpointing_kwargs=t_cfg.get(
-            "gradient_checkpointing_kwargs", {"use_reentrant": False}
-        ),
-        optim=t_cfg["optim"],
-        report_to=t_cfg.get("report_to", "none")
-        if os.environ.get("WANDB_API_KEY")
-        else "none",
-        seed=t_cfg["seed"],
-        dataloader_num_workers=0,
-        dataloader_pin_memory=False,
-        # packing=False because sdpa attention + packing is unsupported
-        # and causes silent crashes on Qwen3-Next architecture.
-        # flash_attention_2 would fix this but flash-attn is hard to compile
-        # in Docker. Disabling packing is the safest fix.
+        max_grad_norm=t_cfg.get("max_grad_norm", 1.0),
+        logging_steps=t_cfg.get("logging_steps", 5),
+        save_strategy=t_cfg.get("save_strategy", "steps"),
+        save_steps=t_cfg.get("save_steps", 50),
+        save_total_limit=t_cfg.get("save_total_limit", 3),
+        max_length=max_seq_length,
         packing=False,
         dataset_text_field="text",
-        push_to_hub=push_to_hub,
-        hub_model_id=hub_model_id if push_to_hub else None,
-        hub_strategy=hub_cfg.get("hub_strategy", "checkpoint"),
-        hub_private_repo=hub_cfg.get("hub_private_repo", False),
-        hub_token=hf_token,
+        optim="adamw_8bit",
+        report_to="none",
+        seed=t_cfg.get("seed", 42),
+        dataloader_num_workers=0,
+        dataloader_pin_memory=False,
+        # Don't push via SFTTrainer - we use Unsloth's push_to_hub_merged
+        push_to_hub=False,
     )
 
-    # -----------------------------------------------------------------------
-    # 6. Trainer
-    # -----------------------------------------------------------------------
+    # -------------------------------------------------------------------
+    # 5. Trainer
+    # -------------------------------------------------------------------
     total_steps = (
         len(dataset)
         // (t_cfg["per_device_train_batch_size"] * t_cfg["gradient_accumulation_steps"])
@@ -502,12 +456,13 @@ def train(
         callbacks=[StatusCallback(total_steps)],
     )
 
-    # -----------------------------------------------------------------------
-    # 7. Train!
-    # -----------------------------------------------------------------------
+    # -------------------------------------------------------------------
+    # 6. Train!
+    # -------------------------------------------------------------------
     write_status("training", "Starting training...", 0.0)
     logger.info("=" * 60)
-    logger.info("TRAINING STARTED")
+    logger.info("TRAINING STARTED (Unsloth FastModel)")
+    logger.info(f"  Model: {model_name}")
     logger.info(f"  Dataset: {len(dataset)} samples")
     logger.info(f"  Epochs: {t_cfg['num_train_epochs']}")
     logger.info(f"  Batch size: {t_cfg['per_device_train_batch_size']}")
@@ -517,14 +472,11 @@ def train(
     )
     logger.info(f"  LR: {t_cfg['learning_rate']}")
     logger.info(f"  LoRA r={lora_cfg['r']}, alpha={lora_cfg['lora_alpha']}")
-    logger.info(f"  Max seq length: {t_cfg['max_seq_length']}")
+    logger.info(f"  Max seq length: {max_seq_length}")
     logger.info(f"  Total steps: ~{total_steps}")
-    logger.info(f"  Push to hub: {push_to_hub}  {hub_model_id}")
+    logger.info(f"  Push to hub: {push_to_hub} -> {hub_model_id}")
     logger.info("=" * 60)
 
-    import traceback as _tb
-
-    # Free any cached CUDA memory before training starts
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
         torch.cuda.reset_peak_memory_stats()
@@ -533,6 +485,8 @@ def train(
             f"{torch.cuda.memory_reserved(0) / 1e9:.1f} GB reserved"
         )
 
+    import traceback as _tb
+
     try:
         logger.info("Calling trainer.train() ...")
         train_result = trainer.train()
@@ -542,38 +496,55 @@ def train(
         full_tb = _tb.format_exc()
         logger.error(err_msg)
         logger.error(full_tb)
-        # Also write to a persistent crash file
         crash_path = "/home/user/crash.log"
         with open(crash_path, "w") as cf:
             cf.write(f"{err_msg}\n\n{full_tb}")
         write_status("error", err_msg, 0.0, {"traceback": full_tb[:2000]})
         raise
 
-    # -----------------------------------------------------------------------
-    # 8. Save final adapter
-    # -----------------------------------------------------------------------
-    write_status("saving", "Saving final LoRA adapter...")
-    final_adapter_path = os.path.join(output_dir, "final_adapter")
-    trainer.save_model(final_adapter_path)
-    tokenizer.save_pretrained(final_adapter_path)
+    # -------------------------------------------------------------------
+    # 7. Save and push LoRA adapter
+    # -------------------------------------------------------------------
+    write_status("saving", "Saving LoRA adapter...")
+
+    local_lora_path = os.path.join(output_dir, "final_adapter")
+    model.save_pretrained(local_lora_path)
+    tokenizer.save_pretrained(local_lora_path)
+    logger.info(f"LoRA adapter saved locally to {local_lora_path}")
 
-    # Push adapter to Hub
     if push_to_hub and hub_model_id:
         write_status("pushing", f"Pushing LoRA adapter to {hub_model_id}...")
-        api = HfApi(token=hf_token)
-        api.create_repo(
-            hub_model_id, exist_ok=True, private=hub_cfg.get("hub_private_repo", False)
-        )
-        api.upload_folder(
-            folder_path=final_adapter_path,
-            repo_id=hub_model_id,
-            commit_message="Upload QLoRA adapter — Qwen3-Coder-Next uncensored",
-        )
-        logger.info(f"Adapter pushed to https://huggingface.co/{hub_model_id}")
+        try:
+            model.push_to_hub_merged(
+                hub_model_id,
+                tokenizer,
+                save_method="lora",
+                token=hf_token,
+            )
+            logger.info(f"LoRA adapter pushed to https://huggingface.co/{hub_model_id}")
+        except Exception as push_exc:
+            # Fallback: manual upload via HfApi
+            logger.warning(
+                f"push_to_hub_merged failed: {push_exc}, trying manual upload"
+            )
+            from huggingface_hub import HfApi
+
+            api = HfApi(token=hf_token)
+            api.create_repo(
+                hub_model_id,
+                exist_ok=True,
+                private=hub_cfg.get("hub_private_repo", False),
+            )
+            api.upload_folder(
+                folder_path=local_lora_path,
+                repo_id=hub_model_id,
+                commit_message="Upload QLoRA adapter - Qwen3 uncensored (Unsloth)",
+            )
+            logger.info(f"Adapter uploaded via HfApi to {hub_model_id}")
 
     write_status(
         "completed",
-        f"Training complete! Adapter saved to {final_adapter_path}",
+        f"Training complete! Adapter saved to {local_lora_path}",
         1.0,
         {
             "train_loss": round(train_result.metrics.get("train_loss", 0), 4),
@@ -584,7 +555,7 @@ def train(
         },
     )
 
-    return final_adapter_path
+    return local_lora_path
 
 
 # ---------------------------------------------------------------------------
@@ -600,6 +571,7 @@ def abliterate(
     """
     Remove the refusal direction from model weights.
     Based on: https://huggingface.co/blog/mlabonne/abliteration
+    Uses Unsloth FastModel to load the model efficiently.
     """
     write_status("initializing", "Starting abliteration (refusal direction removal)...")
 
@@ -607,18 +579,19 @@ def abliterate(
     config = load_config()
     model_name = config["model"]["name"]
 
-    # Load in bfloat16 (need full weights for abliteration)
+    # Load in bfloat16 (need full weights for abliteration — no 4-bit)
     write_status("loading_model", "Loading model in bfloat16 for abliteration...")
-    model = AutoModelForCausalLM.from_pretrained(
-        model_name,
-        torch_dtype=torch.bfloat16,
-        device_map="auto",
-        trust_remote_code=True,
+
+    from unsloth import FastModel
+
+    model, tokenizer = FastModel.from_pretrained(
+        model_name=model_name,
+        max_seq_length=2048,
+        load_in_4bit=False,
+        load_in_8bit=False,
+        full_finetuning=False,
         token=hf_token,
     )
-    tokenizer = AutoTokenizer.from_pretrained(
-        model_name, trust_remote_code=True, token=hf_token
-    )
 
     # Harmful prompts that trigger refusal
     harmful_prompts = [
@@ -716,12 +689,14 @@ def abliterate(
 
     if hub_model_id:
         write_status("pushing", f"Pushing abliterated model to {hub_model_id}...")
+        from huggingface_hub import HfApi
+
         api = HfApi(token=hf_token)
         api.create_repo(hub_model_id, exist_ok=True)
         api.upload_folder(
             folder_path=output_path,
             repo_id=hub_model_id,
-            commit_message="Upload abliterated Qwen3-Coder-Next (refusal direction removed)",
+            commit_message="Upload abliterated Qwen3 (refusal direction removed)",
         )
 
     write_status(