convert_to_gguf.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469
  1. #!/usr/bin/env python3
  2. """
  3. Convert Mel-Band-Roformer PyTorch checkpoint to GGUF format.
  4. Supports quantization: FP32, FP16, Q8_0, Q4_0, Q4_1, Q5_0, Q5_1
  5. Mixed Quantization: Keeps Norms/Biases as FP32 to avoid CUDA alignment issues.
  6. """
  7. import os
  8. import argparse
  9. import torch
  10. import numpy as np
  11. import yaml
  12. import librosa
  13. from einops import repeat, reduce, rearrange
  14. import gguf
  15. from gguf.quants import quantize, GGMLQuantizationType
  16. def generate_buffers(hparams):
  17. """
  18. Generate the buffers (freq_indices, num_bands_per_freq, etc.)
  19. mimicking the logic in MelBandRoformer.__init__.
  20. """
  21. num_bands = hparams["num_bands"]
  22. sample_rate = hparams.get("sample_rate", 44100)
  23. stft_n_fft = hparams.get("stft_n_fft", 2048)
  24. stereo = hparams.get("stereo", False)
  25. # 1. Calculate number of frequencies
  26. freqs = stft_n_fft // 2 + 1
  27. # 2. Create Mel Filter Bank
  28. mel_filter_bank_numpy = librosa.filters.mel(
  29. sr=sample_rate, n_fft=stft_n_fft, n_mels=num_bands
  30. )
  31. mel_filter_bank = torch.from_numpy(mel_filter_bank_numpy)
  32. # 3. Ensure edge values are positive (required for mask generation)
  33. # The exact value doesn't matter as long as it's > 0
  34. mel_filter_bank[0, 0] = max(mel_filter_bank[0, 0].item(), 1e-6)
  35. mel_filter_bank[-1, -1] = max(mel_filter_bank[-1, -1].item(), 1e-6)
  36. # 4. Create Masks
  37. freqs_per_band = mel_filter_bank > 0
  38. assert freqs_per_band.any(dim=0).all(), (
  39. "all frequencies need to be covered by all bands"
  40. )
  41. # 5. Generate Indices
  42. repeated_freq_indices = repeat(torch.arange(freqs), "f -> b f", b=num_bands)
  43. freq_indices = repeated_freq_indices[freqs_per_band]
  44. if stereo:
  45. freq_indices = repeat(freq_indices, "f -> f s", s=2)
  46. # s=0 -> 2*f, s=1 -> 2*f+1
  47. freq_indices = freq_indices * 2 + torch.arange(2)
  48. freq_indices = rearrange(freq_indices, "f s -> (f s)")
  49. # 6. Aggregate Counts
  50. num_freqs_per_band = reduce(freqs_per_band, "b f -> b", "sum")
  51. num_bands_per_freq = reduce(freqs_per_band, "b f -> f", "sum")
  52. return {
  53. "freq_indices": freq_indices,
  54. "num_freqs_per_band": num_freqs_per_band,
  55. "num_bands_per_freq": num_bands_per_freq,
  56. "freqs_per_band": freqs_per_band, # Kept if needed, though usually not saved
  57. }
  58. # ============================================================================
  59. # Quantization Helper
  60. # ============================================================================
  61. def get_target_quantization_type(dtype_str: str) -> GGMLQuantizationType:
  62. mapping = {
  63. "f32": GGMLQuantizationType.F32,
  64. "fp32": GGMLQuantizationType.F32,
  65. "f16": GGMLQuantizationType.F16,
  66. "fp16": GGMLQuantizationType.F16,
  67. "q8_0": GGMLQuantizationType.Q8_0,
  68. "q4_0": GGMLQuantizationType.Q4_0,
  69. "q4_1": GGMLQuantizationType.Q4_1,
  70. "q5_0": GGMLQuantizationType.Q5_0,
  71. "q5_1": GGMLQuantizationType.Q5_1,
  72. }
  73. return mapping.get(dtype_str.lower(), GGMLQuantizationType.F32)
  74. def get_file_type_id(qtype: GGMLQuantizationType) -> int:
  75. # See GGUF spec: https://github.com/ggerganov/ggml/blob/master/docs/gguf.md
  76. mapping = {
  77. GGMLQuantizationType.F32: 0,
  78. GGMLQuantizationType.F16: 1,
  79. GGMLQuantizationType.Q4_0: 2,
  80. GGMLQuantizationType.Q4_1: 3,
  81. # 4 is Q4_1_O (deprecated/legacy?)
  82. # 5 is Q4_0_O ?
  83. # 6 is Q4_1_O ?
  84. GGMLQuantizationType.Q8_0: 7,
  85. GGMLQuantizationType.Q5_0: 8,
  86. GGMLQuantizationType.Q5_1: 9,
  87. GGMLQuantizationType.Q2_K: 10,
  88. GGMLQuantizationType.Q3_K: 11,
  89. GGMLQuantizationType.Q4_K: 12,
  90. GGMLQuantizationType.Q5_K: 13,
  91. GGMLQuantizationType.Q6_K: 14,
  92. # IQ2_XXS etc might have IDs but let's stick to these for now
  93. }
  94. return mapping.get(qtype, 0) # Default to ALL_F32 if unknown
  95. def should_quantize(name: str) -> bool:
  96. """
  97. Determine if a tensor should be quantized.
  98. Keep norms and biases as FP32 to avoid CUDA alignment issues.
  99. """
  100. # Biases are always small and sensitive
  101. if "bias" in name:
  102. return False
  103. # Norm weights (gamma) must be F32 to avoid mixed-type mul issues in CUDA
  104. if "norm.weight" in name:
  105. return False
  106. # Quantize all other "weight" matrices (Linear, Conv, Embedding if any)
  107. if "weight" in name:
  108. return True
  109. return False
  110. # ============================================================================
  111. # Key Name Mapping
  112. # ============================================================================
  113. def map_key_name(key: str) -> str:
  114. """
  115. Map PyTorch state_dict keys to GGUF format (blk.{bid}.*).
  116. Standardizes suffixes: gamma -> weight, beta -> bias.
  117. """
  118. def standardize_suffix(param_name: str) -> str:
  119. if param_name == "gamma":
  120. return "weight"
  121. if param_name == "beta":
  122. return "bias"
  123. return param_name
  124. parts = key.split(".")
  125. suffix = standardize_suffix(parts[-1])
  126. # Transformer Layers
  127. if key.startswith("layers."):
  128. layer_idx = parts[1]
  129. tf_idx = parts[2] # 0=Time, 1=Freq
  130. type_str = "time" if tf_idx == "0" else "freq"
  131. # Final Norm: layers.0.0.norm.gamma
  132. if len(parts) >= 5 and parts[3] == "norm":
  133. return f"blk.{layer_idx}.{type_str}_norm.{suffix}"
  134. # Sub-layers (Attention=0, FF=1)
  135. if len(parts) >= 6 and parts[3] == "layers":
  136. block_sub_idx = parts[5]
  137. if block_sub_idx == "0": # Attention
  138. if len(parts) > 6:
  139. sub_name = parts[6]
  140. if sub_name == "norm":
  141. return f"blk.{layer_idx}.{type_str}_attn_norm.{suffix}"
  142. if sub_name == "to_qkv":
  143. return f"blk.{layer_idx}.{type_str}_attn_qkv.{suffix}"
  144. if sub_name == "to_out":
  145. return f"blk.{layer_idx}.{type_str}_attn_out.{suffix}"
  146. if sub_name == "to_gates":
  147. return f"blk.{layer_idx}.{type_str}_attn_gate.{suffix}"
  148. elif block_sub_idx == "1": # FeedForward
  149. if len(parts) >= 8 and parts[6] == "net":
  150. net_idx = parts[7]
  151. if net_idx == "0":
  152. return f"blk.{layer_idx}.{type_str}_ff_norm.{suffix}"
  153. if net_idx == "1":
  154. return f"blk.{layer_idx}.{type_str}_ff_in.{suffix}"
  155. if net_idx == "4":
  156. return f"blk.{layer_idx}.{type_str}_ff_out.{suffix}"
  157. # BandSplit
  158. if key.startswith("band_split.to_features"):
  159. band_idx = parts[2]
  160. layer_idx = parts[3] # 0=Norm, 1=Linear
  161. if layer_idx == "0":
  162. return f"band_split.{band_idx}.norm.{suffix}"
  163. if layer_idx == "1":
  164. return f"band_split.{band_idx}.linear.{suffix}"
  165. # Mask Estimator
  166. if key.startswith("mask_estimators"):
  167. est_idx = parts[1]
  168. freq_idx = parts[3]
  169. layer_idx = parts[5] # 0, 2, 4
  170. return f"mask_est.{est_idx}.freq.{freq_idx}.mlp.{layer_idx}.{suffix}"
  171. return key.replace(".", "_")
  172. # ============================================================================
  173. # Main Conversion
  174. # ============================================================================
  175. def convert(
  176. ckpt_path: str,
  177. output_path: str,
  178. config_path: str,
  179. dtype: str = "fp32",
  180. ):
  181. """
  182. Convert PyTorch checkpoint to GGUF format.
  183. """
  184. print(f"Loading checkpoint: {ckpt_path}")
  185. checkpoint = torch.load(ckpt_path, map_location="cpu")
  186. if "state_dict" in checkpoint:
  187. state_dict = checkpoint["state_dict"]
  188. elif "model" in checkpoint:
  189. state_dict = checkpoint["model"]
  190. else:
  191. state_dict = checkpoint
  192. print(f"Loading config: {config_path}")
  193. with open(config_path) as f:
  194. config_dict = yaml.load(f, Loader=yaml.FullLoader)
  195. # Generate buffers
  196. print("Generating buffers (standalone)...")
  197. buffers = generate_buffers(config_dict["model"])
  198. freq_indices = buffers["freq_indices"]
  199. num_bands_per_freq = buffers["num_bands_per_freq"]
  200. num_freqs_per_band = buffers["num_freqs_per_band"]
  201. # Create GGUF writer
  202. gguf_writer = gguf.GGUFWriter(output_path, "mel_band_roformer")
  203. # =========================================================================
  204. # 1. Write Standard GGUF Metadata
  205. # =========================================================================
  206. print("Writing metadata...")
  207. # General metadata
  208. gguf_writer.add_name("Mel-Band-Roformer Vocal Separator")
  209. gguf_writer.add_description("Audio source separation model for vocal extraction")
  210. # Determine types
  211. target_qtype = get_target_quantization_type(dtype)
  212. file_type_id = get_file_type_id(target_qtype)
  213. gguf_writer.add_file_type(file_type_id)
  214. # Quantization version (required when quantized)
  215. if target_qtype != GGMLQuantizationType.F32:
  216. gguf_writer.add_quantization_version(2)
  217. # Calculate parameter count
  218. total_params = 0
  219. for key, tensor in state_dict.items():
  220. if "freq_indices" in key or "num_bands" in key:
  221. continue
  222. total_params += tensor.numel()
  223. print(f"Total parameters: {total_params}")
  224. gguf_writer.add_uint64("general.parameter_count", total_params)
  225. # =========================================================================
  226. # 2. Write Hyperparameters
  227. # =========================================================================
  228. print("Writing hyperparameters...")
  229. hparams = config_dict["model"]
  230. # Architecture specific parameters
  231. gguf_writer.add_uint32("mel_band_roformer.dim", hparams["dim"])
  232. gguf_writer.add_uint32("mel_band_roformer.depth", hparams["depth"])
  233. gguf_writer.add_uint32("mel_band_roformer.num_bands", hparams["num_bands"])
  234. # STFT parameters
  235. gguf_writer.add_uint32(
  236. "mel_band_roformer.stft_n_fft", hparams.get("stft_n_fft", 2048)
  237. )
  238. # Remove default for hop_length, must be present or fail/warn
  239. gguf_writer.add_uint32(
  240. "mel_band_roformer.stft_hop_length", hparams.get("stft_hop_length", 441)
  241. )
  242. gguf_writer.add_uint32(
  243. "mel_band_roformer.stft_win_length", hparams.get("stft_win_length", 2048)
  244. )
  245. gguf_writer.add_bool(
  246. "mel_band_roformer.stft_normalized", hparams.get("stft_normalized", False)
  247. )
  248. gguf_writer.add_bool(
  249. "mel_band_roformer.zero_dc", hparams.get("zero_dc", True)
  250. ) # Defaults to True in reference implementation
  251. # Architecture details
  252. gguf_writer.add_uint32("mel_band_roformer.num_stems", hparams.get("num_stems", 1))
  253. gguf_writer.add_bool("mel_band_roformer.stereo", hparams.get("stereo", False))
  254. gguf_writer.add_uint32(
  255. "mel_band_roformer.sample_rate", hparams.get("sample_rate", 44100)
  256. )
  257. gguf_writer.add_uint32(
  258. "mel_band_roformer.time_transformer_depth",
  259. hparams.get("time_transformer_depth", 0),
  260. )
  261. gguf_writer.add_uint32(
  262. "mel_band_roformer.freq_transformer_depth",
  263. hparams.get("freq_transformer_depth", 0),
  264. )
  265. gguf_writer.add_uint32(
  266. "mel_band_roformer.linear_transformer_depth",
  267. hparams.get("linear_transformer_depth", 0),
  268. )
  269. gguf_writer.add_uint32(
  270. "mel_band_roformer.mask_estimator_depth", hparams.get("mask_estimator_depth", 1)
  271. )
  272. gguf_writer.add_uint32("mel_band_roformer.dim_head", hparams.get("dim_head", 64))
  273. gguf_writer.add_uint32("mel_band_roformer.heads", hparams.get("heads", 8))
  274. gguf_writer.add_uint32(
  275. "mel_band_roformer.mlp_expansion_factor", hparams.get("mlp_expansion_factor", 4)
  276. )
  277. gguf_writer.add_bool(
  278. "mel_band_roformer.skip_connection", hparams.get("skip_connection", False)
  279. )
  280. # =========================================================================
  281. # 3. Write Inference Defaults (Optional, can be overridden at runtime)
  282. # =========================================================================
  283. print("Writing inference defaults...")
  284. inference_config = config_dict.get("inference", {})
  285. audio_config = config_dict.get("audio", {})
  286. # chunk_size: prefer inference.chunk_size, fallback to audio.chunk_size
  287. default_chunk_size = inference_config.get(
  288. "chunk_size", audio_config.get("chunk_size", 352800)
  289. )
  290. # num_overlap: from inference section
  291. default_num_overlap = inference_config.get("num_overlap", 0)
  292. gguf_writer.add_uint32("mel_band_roformer.default_chunk_size", default_chunk_size)
  293. gguf_writer.add_uint32("mel_band_roformer.default_num_overlap", default_num_overlap)
  294. # =========================================================================
  295. # 4. Write Buffers (Always FP32/I32)
  296. # =========================================================================
  297. print("Writing buffers...")
  298. # freq_indices (int32)
  299. gguf_writer.add_tensor("buffer_freq_indices", freq_indices.numpy().astype(np.int32))
  300. # num_bands_per_freq (int32)
  301. gguf_writer.add_tensor(
  302. "buffer_num_bands_per_freq", num_bands_per_freq.numpy().astype(np.int32)
  303. )
  304. # num_freqs_per_band (int32)
  305. gguf_writer.add_tensor(
  306. "buffer_num_freqs_per_band", num_freqs_per_band.numpy().astype(np.int32)
  307. )
  308. # =========================================================================
  309. # 5. Write Weights (Mixed Quantization)
  310. # =========================================================================
  311. print(f"Writing weights ({dtype} -> {target_qtype.name})...")
  312. print("Strategy: Quantize weights, Keep Norm/Bias as F32")
  313. n_tensors = 0
  314. n_quantized = 0
  315. for key, tensor in state_dict.items():
  316. new_key = map_key_name(key)
  317. # Skip buffers
  318. if (
  319. "freq_indices" in key
  320. or "num_bands_per_freq" in key
  321. or "num_freqs_per_band" in key
  322. ):
  323. continue
  324. data = tensor.numpy().astype(np.float32)
  325. # Decide whether to quantize
  326. is_quantized = False
  327. if target_qtype != GGMLQuantizationType.F32 and should_quantize(new_key):
  328. try:
  329. # Use gguf-py built-in quantization
  330. quantized_data = quantize(data, target_qtype)
  331. # Pass raw_dtype so GGUFWriter knows how to treat the byte array (for Q types)
  332. # or float array (for F16)
  333. gguf_writer.add_tensor(new_key, quantized_data, raw_dtype=target_qtype)
  334. is_quantized = True
  335. n_quantized += 1
  336. except Exception as e:
  337. print(
  338. f"Warning: Failed to quantize {new_key} to {target_qtype.name}, falling back to F32. Error: {e}"
  339. )
  340. gguf_writer.add_tensor(new_key, data)
  341. else:
  342. # Keep as F32
  343. gguf_writer.add_tensor(new_key, data)
  344. status = target_qtype.name if is_quantized else "F32"
  345. print(f" {new_key:<50} | {str(data.shape):<20} | {status}")
  346. n_tensors += 1
  347. # =========================================================================
  348. # 6. Write File
  349. # =========================================================================
  350. print(f"\nWriting GGUF to {output_path}")
  351. gguf_writer.write_header_to_file()
  352. gguf_writer.write_kv_data_to_file()
  353. gguf_writer.write_tensors_to_file()
  354. gguf_writer.close()
  355. file_size = os.path.getsize(output_path)
  356. print(f"\nDone! Converted {n_tensors} tensors ({n_quantized} quantized)")
  357. print(f"Output file size: {file_size / 1024 / 1024:.2f} MB")
  358. if __name__ == "__main__":
  359. parser = argparse.ArgumentParser(
  360. description="Convert Mel-Band-Roformer checkpoint to GGUF format with Mixed Quantization",
  361. formatter_class=argparse.RawDescriptionHelpFormatter,
  362. epilog="""
  363. Examples:
  364. python convert_to_gguf.py --ckpt model.ckpt --config config.yaml --out model_f16.gguf --dtype fp16
  365. python convert_to_gguf.py --ckpt model.ckpt --config config.yaml --out model_q8.gguf --dtype q8_0
  366. """,
  367. )
  368. parser.add_argument(
  369. "--ckpt", type=str, required=True, help="Path to PyTorch checkpoint"
  370. )
  371. parser.add_argument("--config", type=str, required=True, help="Path to YAML config")
  372. parser.add_argument("--out", type=str, required=True, help="Output GGUF file path")
  373. parser.add_argument(
  374. "--dtype",
  375. type=str,
  376. default="fp32",
  377. choices=[
  378. "fp32",
  379. "f32",
  380. "fp16",
  381. "f16",
  382. "q8_0",
  383. "q4_0",
  384. "q4_1",
  385. "q5_0",
  386. "q5_1",
  387. ],
  388. help="Target quantization type. Norms/Biases will be kept as F32. (K-Quants not supported due to dim=384)",
  389. )
  390. args = parser.parse_args()
  391. convert(args.ckpt, args.out, args.config, args.dtype)