Przeglądaj źródła

fix: aggressive OOM prevention - lora_r=16, seq_len=512, max_samples=5000, clear cache

Sameric 4 miesięcy temu
rodzic
commit
bac3cf12ad
1 zmieniony plik: 16 dodań i 7 usunięć
  1. 16 7
      train.py

+ 16 - 7
train.py

@@ -263,15 +263,15 @@ def prepare_dataset(
 def train(
     dataset_choice: str = "wizard_vicuna",
     hub_model_id: str = "",
-    max_samples: Optional[int] = None,
+    max_samples: Optional[int] = 5000,
     custom_dataset_path: Optional[str] = None,
     num_epochs: int = 2,
     learning_rate: float = 2e-4,
-    lora_r: int = 64,
-    lora_alpha: int = 128,
+    lora_r: int = 16,
+    lora_alpha: int = 32,
     batch_size: int = 1,
-    grad_accum: int = 16,
-    max_seq_length: int = 1024,
+    grad_accum: int = 8,
+    max_seq_length: int = 512,
     system_prompt: str = "",
 ):
     """Run the full QLoRA fine-tuning pipeline."""
@@ -470,8 +470,8 @@ def train(
         if os.environ.get("WANDB_API_KEY")
         else "none",
         seed=t_cfg["seed"],
-        dataloader_num_workers=t_cfg.get("dataloader_num_workers", 4),
-        dataloader_pin_memory=t_cfg.get("dataloader_pin_memory", True),
+        dataloader_num_workers=0,
+        dataloader_pin_memory=False,
         # packing=False because sdpa attention + packing is unsupported
         # and causes silent crashes on Qwen3-Next architecture.
         # flash_attention_2 would fix this but flash-attn is hard to compile
@@ -524,6 +524,15 @@ def train(
 
     import traceback as _tb
 
+    # Free any cached CUDA memory before training starts
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+        torch.cuda.reset_peak_memory_stats()
+        logger.info(
+            f"Pre-train VRAM: {torch.cuda.memory_allocated(0) / 1e9:.1f} GB allocated, "
+            f"{torch.cuda.memory_reserved(0) / 1e9:.1f} GB reserved"
+        )
+
     try:
         logger.info("Calling trainer.train() ...")
         train_result = trainer.train()