|
|
@@ -126,23 +126,38 @@ def prepare_dataset(
|
|
|
|
|
|
# Auto-detect fields
|
|
|
cols = ds.column_names
|
|
|
- instruction_field = next(
|
|
|
- (
|
|
|
- c
|
|
|
- for c in ["instruction", "prompt", "input", "question", "user"]
|
|
|
- if c in cols
|
|
|
- ),
|
|
|
- cols[0],
|
|
|
- )
|
|
|
- output_field = next(
|
|
|
- (
|
|
|
- c
|
|
|
- for c in ["output", "response", "answer", "chosen", "assistant"]
|
|
|
- if c in cols
|
|
|
- ),
|
|
|
- cols[1] if len(cols) > 1 else cols[0],
|
|
|
- )
|
|
|
- system_field = "system" if "system" in cols else None
|
|
|
+
|
|
|
+ # Check if it's a conversations-format dataset
|
|
|
+ if "conversations" in cols:
|
|
|
+ ds_format = "conversations"
|
|
|
+ conversations_field = "conversations"
|
|
|
+ human_key = "human"
|
|
|
+ assistant_key = "gpt"
|
|
|
+ instruction_field = None
|
|
|
+ output_field = None
|
|
|
+ system_field = None
|
|
|
+ else:
|
|
|
+ ds_format = "flat"
|
|
|
+ conversations_field = None
|
|
|
+ human_key = None
|
|
|
+ assistant_key = None
|
|
|
+ instruction_field = next(
|
|
|
+ (
|
|
|
+ c
|
|
|
+ for c in ["instruction", "prompt", "input", "question", "user"]
|
|
|
+ if c in cols
|
|
|
+ ),
|
|
|
+ cols[0],
|
|
|
+ )
|
|
|
+ output_field = next(
|
|
|
+ (
|
|
|
+ c
|
|
|
+ for c in ["output", "response", "answer", "chosen", "assistant"]
|
|
|
+ if c in cols
|
|
|
+ ),
|
|
|
+ cols[1] if len(cols) > 1 else cols[0],
|
|
|
+ )
|
|
|
+ system_field = "system" if "system" in cols else None
|
|
|
else:
|
|
|
ds_cfg = config["datasets"].get(dataset_name)
|
|
|
if ds_cfg is None:
|
|
|
@@ -152,33 +167,79 @@ def prepare_dataset(
|
|
|
|
|
|
logger.info(f"Loading dataset: {ds_cfg['name']}")
|
|
|
ds = load_dataset(ds_cfg["name"], split=ds_cfg["split"])
|
|
|
- instruction_field = ds_cfg["instruction_field"]
|
|
|
- output_field = ds_cfg["output_field"]
|
|
|
- system_field = ds_cfg.get("system_field")
|
|
|
+
|
|
|
+ ds_format = ds_cfg.get("format", "flat")
|
|
|
+ if ds_format == "conversations":
|
|
|
+ conversations_field = ds_cfg.get("conversations_field", "conversations")
|
|
|
+ human_key = ds_cfg.get("human_key", "human")
|
|
|
+ assistant_key = ds_cfg.get("assistant_key", "gpt")
|
|
|
+ instruction_field = None
|
|
|
+ output_field = None
|
|
|
+ system_field = None
|
|
|
+ else:
|
|
|
+ conversations_field = None
|
|
|
+ human_key = None
|
|
|
+ assistant_key = None
|
|
|
+ instruction_field = ds_cfg["instruction_field"]
|
|
|
+ output_field = ds_cfg["output_field"]
|
|
|
+ system_field = ds_cfg.get("system_field")
|
|
|
|
|
|
if max_samples and max_samples < len(ds):
|
|
|
ds = ds.shuffle(seed=42).select(range(max_samples))
|
|
|
|
|
|
- logger.info(f"Dataset loaded: {len(ds)} samples")
|
|
|
+ logger.info(f"Dataset loaded: {len(ds)} samples (format: {ds_format})")
|
|
|
|
|
|
- def format_chat(example):
|
|
|
- messages = []
|
|
|
- # System prompt
|
|
|
- if system_field and example.get(system_field):
|
|
|
- messages.append({"role": "system", "content": example[system_field]})
|
|
|
- elif system_prompt:
|
|
|
- messages.append({"role": "system", "content": system_prompt})
|
|
|
+ if ds_format == "conversations":
|
|
|
|
|
|
- # User message
|
|
|
- messages.append({"role": "user", "content": str(example[instruction_field])})
|
|
|
+ def format_chat(example):
|
|
|
+ """Convert conversations [{from, value}] into chat template."""
|
|
|
+ convos = example[conversations_field]
|
|
|
+ messages = []
|
|
|
|
|
|
- # Assistant response
|
|
|
- messages.append({"role": "assistant", "content": str(example[output_field])})
|
|
|
+ # Prepend system prompt
|
|
|
+ if system_prompt:
|
|
|
+ messages.append({"role": "system", "content": system_prompt})
|
|
|
|
|
|
- text = tokenizer.apply_chat_template(
|
|
|
- messages, tokenize=False, add_generation_prompt=False
|
|
|
- )
|
|
|
- return {"text": text}
|
|
|
+ for turn in convos:
|
|
|
+ role_key = turn.get("from", "")
|
|
|
+ value = turn.get("value", "")
|
|
|
+ if role_key == human_key:
|
|
|
+ messages.append({"role": "user", "content": value})
|
|
|
+ elif role_key == assistant_key:
|
|
|
+ messages.append({"role": "assistant", "content": value})
|
|
|
+
|
|
|
+ # Skip samples with no user/assistant turns
|
|
|
+ if len(messages) < 2 or not any(m["role"] == "assistant" for m in messages):
|
|
|
+ return {"text": ""}
|
|
|
+
|
|
|
+ text = tokenizer.apply_chat_template(
|
|
|
+ messages, tokenize=False, add_generation_prompt=False
|
|
|
+ )
|
|
|
+ return {"text": text}
|
|
|
+ else:
|
|
|
+
|
|
|
+ def format_chat(example):
|
|
|
+ messages = []
|
|
|
+ # System prompt
|
|
|
+ if system_field and example.get(system_field):
|
|
|
+ messages.append({"role": "system", "content": example[system_field]})
|
|
|
+ elif system_prompt:
|
|
|
+ messages.append({"role": "system", "content": system_prompt})
|
|
|
+
|
|
|
+ # User message
|
|
|
+ messages.append(
|
|
|
+ {"role": "user", "content": str(example[instruction_field])}
|
|
|
+ )
|
|
|
+
|
|
|
+ # Assistant response
|
|
|
+ messages.append(
|
|
|
+ {"role": "assistant", "content": str(example[output_field])}
|
|
|
+ )
|
|
|
+
|
|
|
+ text = tokenizer.apply_chat_template(
|
|
|
+ messages, tokenize=False, add_generation_prompt=False
|
|
|
+ )
|
|
|
+ return {"text": text}
|
|
|
|
|
|
ds = ds.map(
|
|
|
format_chat,
|
|
|
@@ -186,6 +247,11 @@ def prepare_dataset(
|
|
|
num_proc=4,
|
|
|
desc="Formatting dataset",
|
|
|
)
|
|
|
+
|
|
|
+ # Filter out empty texts (e.g. conversations with no assistant turn)
|
|
|
+ ds = ds.filter(lambda x: len(x["text"].strip()) > 0)
|
|
|
+ logger.info(f"After filtering: {len(ds)} samples")
|
|
|
+
|
|
|
return ds
|
|
|
|
|
|
|