|
@@ -332,7 +332,7 @@ def train(
|
|
|
# -----------------------------------------------------------------------
|
|
# -----------------------------------------------------------------------
|
|
|
write_status(
|
|
write_status(
|
|
|
"loading_model",
|
|
"loading_model",
|
|
|
- "Loading Qwen3-Coder-Next in 4-bit quantization... (this takes a while)",
|
|
|
|
|
|
|
+ "Loading model in 4-bit quantization... (this takes a while)",
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
q_cfg = config["quantization"]
|
|
q_cfg = config["quantization"]
|
|
@@ -355,18 +355,51 @@ def train(
|
|
|
attn_impl = "eager"
|
|
attn_impl = "eager"
|
|
|
logger.info(f"Using attention implementation: {attn_impl}")
|
|
logger.info(f"Using attention implementation: {attn_impl}")
|
|
|
|
|
|
|
|
|
|
+ # Log transformers version to confirm qwen3_next support
|
|
|
|
|
+ import transformers
|
|
|
|
|
+
|
|
|
|
|
+ logger.info(f"transformers version: {transformers.__version__}")
|
|
|
|
|
+
|
|
|
|
|
+ # Pre-quantized fallback model for Qwen3-Next architecture
|
|
|
|
|
+ PRE_QUANTIZED_FALLBACK = "unsloth/Qwen3-Next-80B-A3B-Instruct-bnb-4bit"
|
|
|
|
|
+
|
|
|
# Force ALL layers onto GPU 0. device_map="auto" is too conservative
|
|
# Force ALL layers onto GPU 0. device_map="auto" is too conservative
|
|
|
# with large MoE models and offloads to CPU where bnb 4-bit can't run.
|
|
# with large MoE models and offloads to CPU where bnb 4-bit can't run.
|
|
|
# 80B params in 4-bit ≈ 40GB — fits comfortably on A100 80GB.
|
|
# 80B params in 4-bit ≈ 40GB — fits comfortably on A100 80GB.
|
|
|
- model = AutoModelForCausalLM.from_pretrained(
|
|
|
|
|
- model_name,
|
|
|
|
|
- quantization_config=bnb_config,
|
|
|
|
|
- device_map={"": 0},
|
|
|
|
|
- trust_remote_code=config["model"]["trust_remote_code"],
|
|
|
|
|
- torch_dtype=getattr(torch, config["model"]["torch_dtype"]),
|
|
|
|
|
- token=hf_token,
|
|
|
|
|
- attn_implementation=attn_impl,
|
|
|
|
|
- )
|
|
|
|
|
|
|
+ try:
|
|
|
|
|
+ logger.info(
|
|
|
|
|
+ f"Attempting to load {model_name} with on-the-fly 4-bit quantization..."
|
|
|
|
|
+ )
|
|
|
|
|
+ model = AutoModelForCausalLM.from_pretrained(
|
|
|
|
|
+ model_name,
|
|
|
|
|
+ quantization_config=bnb_config,
|
|
|
|
|
+ device_map={"": 0},
|
|
|
|
|
+ trust_remote_code=config["model"]["trust_remote_code"],
|
|
|
|
|
+ torch_dtype=getattr(torch, config["model"]["torch_dtype"]),
|
|
|
|
|
+ token=hf_token,
|
|
|
|
|
+ attn_implementation=attn_impl,
|
|
|
|
|
+ )
|
|
|
|
|
+ logger.info(f"Successfully loaded {model_name} with 4-bit quantization")
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ logger.warning(f"On-the-fly 4-bit quantization failed for {model_name}: {e}")
|
|
|
|
|
+ logger.info(f"Falling back to pre-quantized model: {PRE_QUANTIZED_FALLBACK}")
|
|
|
|
|
+ write_status(
|
|
|
|
|
+ "loading_model",
|
|
|
|
|
+ f"On-the-fly quantization failed, loading pre-quantized fallback...",
|
|
|
|
|
+ )
|
|
|
|
|
+ # Pre-quantized model already has bnb 4-bit weights baked in —
|
|
|
|
|
+ # do NOT pass quantization_config again, just load directly.
|
|
|
|
|
+ model = AutoModelForCausalLM.from_pretrained(
|
|
|
|
|
+ PRE_QUANTIZED_FALLBACK,
|
|
|
|
|
+ device_map={"": 0},
|
|
|
|
|
+ trust_remote_code=True,
|
|
|
|
|
+ torch_dtype=torch.bfloat16,
|
|
|
|
|
+ token=hf_token,
|
|
|
|
|
+ attn_implementation=attn_impl,
|
|
|
|
|
+ )
|
|
|
|
|
+ logger.info(
|
|
|
|
|
+ f"Successfully loaded pre-quantized fallback: {PRE_QUANTIZED_FALLBACK}"
|
|
|
|
|
+ )
|
|
|
|
|
|
|
|
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
|
|
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
|
|
|
logger.info("Model loaded and prepared for k-bit training")
|
|
logger.info("Model loaded and prepared for k-bit training")
|
|
@@ -377,10 +410,15 @@ def train(
|
|
|
write_status("loading_model", "Applying LoRA adapters...")
|
|
write_status("loading_model", "Applying LoRA adapters...")
|
|
|
|
|
|
|
|
lora_cfg = config["lora"]
|
|
lora_cfg = config["lora"]
|
|
|
|
|
+ target_modules = lora_cfg["target_modules"]
|
|
|
|
|
+ # target_modules can be a list of strings or the string "all-linear"
|
|
|
|
|
+ if isinstance(target_modules, str) and target_modules != "all-linear":
|
|
|
|
|
+ target_modules = [target_modules]
|
|
|
|
|
+
|
|
|
lora_config = LoraConfig(
|
|
lora_config = LoraConfig(
|
|
|
r=lora_cfg["r"],
|
|
r=lora_cfg["r"],
|
|
|
lora_alpha=lora_cfg["lora_alpha"],
|
|
lora_alpha=lora_cfg["lora_alpha"],
|
|
|
- target_modules=lora_cfg["target_modules"],
|
|
|
|
|
|
|
+ target_modules=target_modules,
|
|
|
lora_dropout=lora_cfg["lora_dropout"],
|
|
lora_dropout=lora_cfg["lora_dropout"],
|
|
|
bias=lora_cfg["bias"],
|
|
bias=lora_cfg["bias"],
|
|
|
task_type=lora_cfg["task_type"],
|
|
task_type=lora_cfg["task_type"],
|