فهرست منبع

Deploy Qwen3-Coder-Next uncensored fine-tuner

Sameric 4 ماه پیش
والد
کامیت
56667b9f65
1فایلهای تغییر یافته به همراه13 افزوده شده و 3 حذف شده
  1. 13 3
      train.py

+ 13 - 3
train.py

@@ -343,6 +343,18 @@ def train(
         bnb_4bit_use_double_quant=q_cfg["bnb_4bit_use_double_quant"],
     )
 
+    # Pick best available attention: flash_attention_2 > sdpa > eager
+    if torch.cuda.is_available():
+        try:
+            import flash_attn  # noqa: F401
+
+            attn_impl = "flash_attention_2"
+        except ImportError:
+            attn_impl = "sdpa"  # PyTorch native, no extra install needed
+    else:
+        attn_impl = "eager"
+    logger.info(f"Using attention implementation: {attn_impl}")
+
     model = AutoModelForCausalLM.from_pretrained(
         model_name,
         quantization_config=bnb_config,
@@ -350,9 +362,7 @@ def train(
         trust_remote_code=config["model"]["trust_remote_code"],
         torch_dtype=getattr(torch, config["model"]["torch_dtype"]),
         token=hf_token,
-        attn_implementation="flash_attention_2"
-        if torch.cuda.is_available()
-        else "eager",
+        attn_implementation=attn_impl,
     )
 
     model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)