|
|
@@ -263,15 +263,15 @@ def prepare_dataset(
|
|
|
def train(
|
|
|
dataset_choice: str = "wizard_vicuna",
|
|
|
hub_model_id: str = "",
|
|
|
- max_samples: Optional[int] = None,
|
|
|
+ max_samples: Optional[int] = 5000,
|
|
|
custom_dataset_path: Optional[str] = None,
|
|
|
num_epochs: int = 2,
|
|
|
learning_rate: float = 2e-4,
|
|
|
- lora_r: int = 64,
|
|
|
- lora_alpha: int = 128,
|
|
|
+ lora_r: int = 16,
|
|
|
+ lora_alpha: int = 32,
|
|
|
batch_size: int = 1,
|
|
|
- grad_accum: int = 16,
|
|
|
- max_seq_length: int = 1024,
|
|
|
+ grad_accum: int = 8,
|
|
|
+ max_seq_length: int = 512,
|
|
|
system_prompt: str = "",
|
|
|
):
|
|
|
"""Run the full QLoRA fine-tuning pipeline."""
|
|
|
@@ -470,8 +470,8 @@ def train(
|
|
|
if os.environ.get("WANDB_API_KEY")
|
|
|
else "none",
|
|
|
seed=t_cfg["seed"],
|
|
|
- dataloader_num_workers=t_cfg.get("dataloader_num_workers", 4),
|
|
|
- dataloader_pin_memory=t_cfg.get("dataloader_pin_memory", True),
|
|
|
+ dataloader_num_workers=0,
|
|
|
+ dataloader_pin_memory=False,
|
|
|
# packing=False because sdpa attention + packing is unsupported
|
|
|
# and causes silent crashes on Qwen3-Next architecture.
|
|
|
# flash_attention_2 would fix this but flash-attn is hard to compile
|
|
|
@@ -524,6 +524,15 @@ def train(
|
|
|
|
|
|
import traceback as _tb
|
|
|
|
|
|
+ # Free any cached CUDA memory before training starts
|
|
|
+ if torch.cuda.is_available():
|
|
|
+ torch.cuda.empty_cache()
|
|
|
+ torch.cuda.reset_peak_memory_stats()
|
|
|
+ logger.info(
|
|
|
+ f"Pre-train VRAM: {torch.cuda.memory_allocated(0) / 1e9:.1f} GB allocated, "
|
|
|
+ f"{torch.cuda.memory_reserved(0) / 1e9:.1f} GB reserved"
|
|
|
+ )
|
|
|
+
|
|
|
try:
|
|
|
logger.info("Calling trainer.train() ...")
|
|
|
train_result = trainer.train()
|