|
|
@@ -355,10 +355,21 @@ def train(
|
|
|
attn_impl = "eager"
|
|
|
logger.info(f"Using attention implementation: {attn_impl}")
|
|
|
|
|
|
+ # Force model onto GPU — 80B in 4-bit ≈ 40GB, fits on A100 80GB.
|
|
|
+ # Without max_memory, device_map="auto" is too conservative and
|
|
|
+ # offloads layers to CPU where bitsandbytes 4-bit cannot run.
|
|
|
+ if torch.cuda.is_available():
|
|
|
+ gpu_mem = torch.cuda.get_device_properties(0).total_mem / (1024**3)
|
|
|
+ max_memory = {0: f"{int(gpu_mem * 0.92)}GiB", "cpu": "30GiB"}
|
|
|
+ logger.info(f"GPU memory: {gpu_mem:.1f} GiB, max_memory: {max_memory}")
|
|
|
+ else:
|
|
|
+ max_memory = None
|
|
|
+
|
|
|
model = AutoModelForCausalLM.from_pretrained(
|
|
|
model_name,
|
|
|
quantization_config=bnb_config,
|
|
|
device_map="auto",
|
|
|
+ max_memory=max_memory,
|
|
|
trust_remote_code=config["model"]["trust_remote_code"],
|
|
|
torch_dtype=getattr(torch, config["model"]["torch_dtype"]),
|
|
|
token=hf_token,
|