Procházet zdrojové kódy

Deploy Qwen3-Coder-Next uncensored fine-tuner

Sameric před 4 měsíci
rodič
revize
56667b9f65
1 změnil soubory, kde provedl 13 přidání a 3 odebrání
  1. 13 3
      train.py

+ 13 - 3
train.py

@@ -343,6 +343,18 @@ def train(
         bnb_4bit_use_double_quant=q_cfg["bnb_4bit_use_double_quant"],
         bnb_4bit_use_double_quant=q_cfg["bnb_4bit_use_double_quant"],
     )
     )
 
 
+    # Pick best available attention: flash_attention_2 > sdpa > eager
+    if torch.cuda.is_available():
+        try:
+            import flash_attn  # noqa: F401
+
+            attn_impl = "flash_attention_2"
+        except ImportError:
+            attn_impl = "sdpa"  # PyTorch native, no extra install needed
+    else:
+        attn_impl = "eager"
+    logger.info(f"Using attention implementation: {attn_impl}")
+
     model = AutoModelForCausalLM.from_pretrained(
     model = AutoModelForCausalLM.from_pretrained(
         model_name,
         model_name,
         quantization_config=bnb_config,
         quantization_config=bnb_config,
@@ -350,9 +362,7 @@ def train(
         trust_remote_code=config["model"]["trust_remote_code"],
         trust_remote_code=config["model"]["trust_remote_code"],
         torch_dtype=getattr(torch, config["model"]["torch_dtype"]),
         torch_dtype=getattr(torch, config["model"]["torch_dtype"]),
         token=hf_token,
         token=hf_token,
-        attn_implementation="flash_attention_2"
-        if torch.cuda.is_available()
-        else "eager",
+        attn_implementation=attn_impl,
     )
     )
 
 
     model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
     model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)