Selaa lähdekoodia

Deploy Qwen3-Coder-Next uncensored fine-tuner

Sameric 4 kuukautta sitten
vanhempi
sitoutus
8bbbc2053d
2 muutettua tiedostoa jossa 113 lisäystä ja 43 poistoa
  1. 11 7
      config.yaml
  2. 102 36
      train.py

+ 11 - 7
config.yaml

@@ -33,34 +33,38 @@ lora:
 
 # Dataset options (pick one or provide custom)
 datasets:
-  # Option 1: General uncensored instruction-following
+  # Option 1: General uncensored instruction-following (conversations format)
   wizard_vicuna:
     name: "ehartford/wizard_vicuna_70k_unfiltered"
     split: "train"
-    instruction_field: "instruction"
-    output_field: "output"
-    system_field: null
+    format: "conversations"          # [{from: "human"/"gpt", value: "..."}]
+    conversations_field: "conversations"
+    human_key: "human"
+    assistant_key: "gpt"
 
-  # Option 2: Toxic/uncensored DPO pairs
+  # Option 2: Toxic/uncensored DPO pairs (flat format)
   toxic_dpo:
     name: "NobodyExistsOnTheInternet/ToxicDPOqa"
     split: "train"
+    format: "flat"
     instruction_field: "prompt"
     output_field: "chosen"
     system_field: null
 
-  # Option 3: WizardLM uncensored
+  # Option 3: WizardLM uncensored (flat format)
   wizardlm_uncensored:
     name: "ehartford/WizardLM_alpaca_evol_instruct_70k_unfiltered"
     split: "train"
+    format: "flat"
     instruction_field: "instruction"
     output_field: "output"
     system_field: null
 
-  # Option 4: Synthia uncensored
+  # Option 4: Synthia uncensored (flat format)
   synthia:
     name: "migtissera/Synthia-v1.3"
     split: "train"
+    format: "flat"
     instruction_field: "instruction"
     output_field: "response"
     system_field: "system"

+ 102 - 36
train.py

@@ -126,23 +126,38 @@ def prepare_dataset(
 
         # Auto-detect fields
         cols = ds.column_names
-        instruction_field = next(
-            (
-                c
-                for c in ["instruction", "prompt", "input", "question", "user"]
-                if c in cols
-            ),
-            cols[0],
-        )
-        output_field = next(
-            (
-                c
-                for c in ["output", "response", "answer", "chosen", "assistant"]
-                if c in cols
-            ),
-            cols[1] if len(cols) > 1 else cols[0],
-        )
-        system_field = "system" if "system" in cols else None
+
+        # Check if it's a conversations-format dataset
+        if "conversations" in cols:
+            ds_format = "conversations"
+            conversations_field = "conversations"
+            human_key = "human"
+            assistant_key = "gpt"
+            instruction_field = None
+            output_field = None
+            system_field = None
+        else:
+            ds_format = "flat"
+            conversations_field = None
+            human_key = None
+            assistant_key = None
+            instruction_field = next(
+                (
+                    c
+                    for c in ["instruction", "prompt", "input", "question", "user"]
+                    if c in cols
+                ),
+                cols[0],
+            )
+            output_field = next(
+                (
+                    c
+                    for c in ["output", "response", "answer", "chosen", "assistant"]
+                    if c in cols
+                ),
+                cols[1] if len(cols) > 1 else cols[0],
+            )
+            system_field = "system" if "system" in cols else None
     else:
         ds_cfg = config["datasets"].get(dataset_name)
         if ds_cfg is None:
@@ -152,33 +167,79 @@ def prepare_dataset(
 
         logger.info(f"Loading dataset: {ds_cfg['name']}")
         ds = load_dataset(ds_cfg["name"], split=ds_cfg["split"])
-        instruction_field = ds_cfg["instruction_field"]
-        output_field = ds_cfg["output_field"]
-        system_field = ds_cfg.get("system_field")
+
+        ds_format = ds_cfg.get("format", "flat")
+        if ds_format == "conversations":
+            conversations_field = ds_cfg.get("conversations_field", "conversations")
+            human_key = ds_cfg.get("human_key", "human")
+            assistant_key = ds_cfg.get("assistant_key", "gpt")
+            instruction_field = None
+            output_field = None
+            system_field = None
+        else:
+            conversations_field = None
+            human_key = None
+            assistant_key = None
+            instruction_field = ds_cfg["instruction_field"]
+            output_field = ds_cfg["output_field"]
+            system_field = ds_cfg.get("system_field")
 
     if max_samples and max_samples < len(ds):
         ds = ds.shuffle(seed=42).select(range(max_samples))
 
-    logger.info(f"Dataset loaded: {len(ds)} samples")
+    logger.info(f"Dataset loaded: {len(ds)} samples (format: {ds_format})")
 
-    def format_chat(example):
-        messages = []
-        # System prompt
-        if system_field and example.get(system_field):
-            messages.append({"role": "system", "content": example[system_field]})
-        elif system_prompt:
-            messages.append({"role": "system", "content": system_prompt})
+    if ds_format == "conversations":
 
-        # User message
-        messages.append({"role": "user", "content": str(example[instruction_field])})
+        def format_chat(example):
+            """Convert conversations [{from, value}] into chat template."""
+            convos = example[conversations_field]
+            messages = []
 
-        # Assistant response
-        messages.append({"role": "assistant", "content": str(example[output_field])})
+            # Prepend system prompt
+            if system_prompt:
+                messages.append({"role": "system", "content": system_prompt})
 
-        text = tokenizer.apply_chat_template(
-            messages, tokenize=False, add_generation_prompt=False
-        )
-        return {"text": text}
+            for turn in convos:
+                role_key = turn.get("from", "")
+                value = turn.get("value", "")
+                if role_key == human_key:
+                    messages.append({"role": "user", "content": value})
+                elif role_key == assistant_key:
+                    messages.append({"role": "assistant", "content": value})
+
+            # Skip samples with no user/assistant turns
+            if len(messages) < 2 or not any(m["role"] == "assistant" for m in messages):
+                return {"text": ""}
+
+            text = tokenizer.apply_chat_template(
+                messages, tokenize=False, add_generation_prompt=False
+            )
+            return {"text": text}
+    else:
+
+        def format_chat(example):
+            messages = []
+            # System prompt
+            if system_field and example.get(system_field):
+                messages.append({"role": "system", "content": example[system_field]})
+            elif system_prompt:
+                messages.append({"role": "system", "content": system_prompt})
+
+            # User message
+            messages.append(
+                {"role": "user", "content": str(example[instruction_field])}
+            )
+
+            # Assistant response
+            messages.append(
+                {"role": "assistant", "content": str(example[output_field])}
+            )
+
+            text = tokenizer.apply_chat_template(
+                messages, tokenize=False, add_generation_prompt=False
+            )
+            return {"text": text}
 
     ds = ds.map(
         format_chat,
@@ -186,6 +247,11 @@ def prepare_dataset(
         num_proc=4,
         desc="Formatting dataset",
     )
+
+    # Filter out empty texts (e.g. conversations with no assistant turn)
+    ds = ds.filter(lambda x: len(x["text"].strip()) > 0)
+    logger.info(f"After filtering: {len(ds)} samples")
+
     return ds