merge_and_push.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181
  1. """
  2. Merge LoRA adapter into base model and push to Hugging Face Hub.
  3. Run this AFTER training completes to create a standalone model.
  4. """
  5. import os
  6. import sys
  7. import json
  8. import yaml
  9. import torch
  10. import logging
  11. from pathlib import Path
  12. from transformers import AutoModelForCausalLM, AutoTokenizer
  13. from peft import PeftModel
  14. from huggingface_hub import HfApi
  15. logging.basicConfig(
  16. level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
  17. )
  18. logger = logging.getLogger(__name__)
  19. STATUS_FILE = "/tmp/training_status.json"
  20. def write_status(status: str, detail: str = "", progress: float = 0.0):
  21. data = {"status": status, "detail": detail, "progress": progress, "metrics": {}}
  22. Path(STATUS_FILE).write_text(json.dumps(data))
  23. def merge_and_push(
  24. adapter_path: str = "/tmp/qwen3-uncensored-lora/final_adapter",
  25. hub_model_id: str = "",
  26. push_to_hub: bool = True,
  27. ):
  28. """
  29. Load the base model, merge the LoRA adapter, and optionally push to Hub.
  30. WARNING: This requires significant RAM/VRAM because the full model must be loaded.
  31. For the 80B MoE model, you'll need ~160GB RAM or ~80GB VRAM to merge in bf16.
  32. """
  33. hf_token = os.environ.get("HF_TOKEN")
  34. if not hf_token:
  35. raise ValueError("HF_TOKEN environment variable is required")
  36. with open("config.yaml") as f:
  37. config = yaml.safe_load(f)
  38. model_name = config["model"]["name"]
  39. # -----------------------------------------------------------------------
  40. # 1. Load base model in bf16
  41. # -----------------------------------------------------------------------
  42. write_status("merging", "Loading base model in bfloat16...", 0.1)
  43. logger.info(f"Loading base model: {model_name}")
  44. model = AutoModelForCausalLM.from_pretrained(
  45. model_name,
  46. torch_dtype=torch.bfloat16,
  47. device_map="auto",
  48. trust_remote_code=True,
  49. token=hf_token,
  50. )
  51. tokenizer = AutoTokenizer.from_pretrained(
  52. model_name,
  53. trust_remote_code=True,
  54. token=hf_token,
  55. )
  56. # -----------------------------------------------------------------------
  57. # 2. Load and merge LoRA adapter
  58. # -----------------------------------------------------------------------
  59. write_status("merging", "Merging LoRA adapter into base model...", 0.4)
  60. logger.info(f"Loading adapter from: {adapter_path}")
  61. model = PeftModel.from_pretrained(model, adapter_path)
  62. model = model.merge_and_unload()
  63. logger.info("LoRA adapter merged successfully")
  64. # -----------------------------------------------------------------------
  65. # 3. Save merged model
  66. # -----------------------------------------------------------------------
  67. output_path = "/tmp/merged_model"
  68. write_status("merging", "Saving merged model...", 0.6)
  69. logger.info(f"Saving merged model to: {output_path}")
  70. model.save_pretrained(output_path, safe_serialization=True, max_shard_size="4GB")
  71. tokenizer.save_pretrained(output_path)
  72. # -----------------------------------------------------------------------
  73. # 4. Push to Hub
  74. # -----------------------------------------------------------------------
  75. if push_to_hub and hub_model_id:
  76. write_status("pushing", f"Pushing merged model to {hub_model_id}...", 0.8)
  77. logger.info(f"Pushing to: {hub_model_id}")
  78. api = HfApi(token=hf_token)
  79. api.create_repo(hub_model_id, exist_ok=True)
  80. api.upload_folder(
  81. folder_path=output_path,
  82. repo_id=hub_model_id,
  83. commit_message="Upload merged Qwen3-Coder-Next uncensored (LoRA merged)",
  84. )
  85. logger.info(f"Model pushed to https://huggingface.co/{hub_model_id}")
  86. # Create model card
  87. model_card = f"""---
  88. license: apache-2.0
  89. base_model: {model_name}
  90. tags:
  91. - qwen3
  92. - uncensored
  93. - fine-tuned
  94. - qlora
  95. - merged
  96. ---
  97. # {hub_model_id.split("/")[-1]}
  98. Fine-tuned and uncensored version of [{model_name}](https://huggingface.co/{model_name}).
  99. ## Training Details
  100. - **Method**: QLoRA 4-bit fine-tuning
  101. - **Base Model**: {model_name} (80B MoE / 3B active parameters)
  102. - **LoRA Rank**: {config["lora"]["r"]}
  103. - **LoRA Alpha**: {config["lora"]["lora_alpha"]}
  104. - **Target Modules**: {", ".join(config["lora"]["target_modules"])}
  105. - **Epochs**: {config["training"]["num_train_epochs"]}
  106. - **Learning Rate**: {config["training"]["learning_rate"]}
  107. - **Max Seq Length**: {config["training"]["max_seq_length"]}
  108. ## Usage
  109. ```python
  110. from transformers import AutoModelForCausalLM, AutoTokenizer
  111. model = AutoModelForCausalLM.from_pretrained("{hub_model_id}", torch_dtype="auto", device_map="auto")
  112. tokenizer = AutoTokenizer.from_pretrained("{hub_model_id}")
  113. messages = [{{"role": "user", "content": "Your prompt here"}}]
  114. text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
  115. inputs = tokenizer([text], return_tensors="pt").to(model.device)
  116. outputs = model.generate(**inputs, max_new_tokens=4096)
  117. print(tokenizer.decode(outputs[0], skip_special_tokens=True))
  118. ```
  119. """
  120. api.upload_file(
  121. path_or_fileobj=model_card.encode(),
  122. path_in_repo="README.md",
  123. repo_id=hub_model_id,
  124. commit_message="Add model card",
  125. )
  126. write_status("completed", f"Merge complete! Model at {output_path}", 1.0)
  127. logger.info("Done!")
  128. return output_path
  129. if __name__ == "__main__":
  130. import argparse
  131. parser = argparse.ArgumentParser()
  132. parser.add_argument(
  133. "--adapter-path", default="/tmp/qwen3-uncensored-lora/final_adapter"
  134. )
  135. parser.add_argument(
  136. "--hub-model-id",
  137. required=True,
  138. help="e.g. your-username/qwen3-coder-uncensored-merged",
  139. )
  140. parser.add_argument("--no-push", action="store_true")
  141. args = parser.parse_args()
  142. merge_and_push(
  143. adapter_path=args.adapter_path,
  144. hub_model_id=args.hub_model_id,
  145. push_to_hub=not args.no_push,
  146. )