|
|
@@ -343,6 +343,18 @@ def train(
|
|
|
bnb_4bit_use_double_quant=q_cfg["bnb_4bit_use_double_quant"],
|
|
|
)
|
|
|
|
|
|
+ # Pick best available attention: flash_attention_2 > sdpa > eager
|
|
|
+ if torch.cuda.is_available():
|
|
|
+ try:
|
|
|
+ import flash_attn # noqa: F401
|
|
|
+
|
|
|
+ attn_impl = "flash_attention_2"
|
|
|
+ except ImportError:
|
|
|
+ attn_impl = "sdpa" # PyTorch native, no extra install needed
|
|
|
+ else:
|
|
|
+ attn_impl = "eager"
|
|
|
+ logger.info(f"Using attention implementation: {attn_impl}")
|
|
|
+
|
|
|
model = AutoModelForCausalLM.from_pretrained(
|
|
|
model_name,
|
|
|
quantization_config=bnb_config,
|
|
|
@@ -350,9 +362,7 @@ def train(
|
|
|
trust_remote_code=config["model"]["trust_remote_code"],
|
|
|
torch_dtype=getattr(torch, config["model"]["torch_dtype"]),
|
|
|
token=hf_token,
|
|
|
- attn_implementation="flash_attention_2"
|
|
|
- if torch.cuda.is_available()
|
|
|
- else "eager",
|
|
|
+ attn_implementation=attn_impl,
|
|
|
)
|
|
|
|
|
|
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
|