convert_to_gguf.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485
  1. #!/usr/bin/env python3
  2. """
  3. Convert Mel-Band-Roformer PyTorch checkpoint to GGUF format.
  4. Supports quantization: FP32, FP16, Q8_0, Q4_0, Q4_1, Q5_0, Q5_1
  5. Mixed Quantization: Keeps Norms/Biases as FP32 to avoid CUDA alignment issues.
  6. """
  7. import os
  8. import argparse
  9. import torch
  10. import numpy as np
  11. import yaml
  12. import librosa
  13. from einops import repeat, reduce, rearrange
  14. import gguf
  15. from gguf.quants import quantize, GGMLQuantizationType
  16. def generate_buffers(hparams):
  17. """
  18. Generate the buffers (freq_indices, num_bands_per_freq, etc.)
  19. mimicking the logic in MelBandRoformer.__init__.
  20. """
  21. num_bands = hparams["num_bands"]
  22. sample_rate = hparams.get("sample_rate", 44100)
  23. stft_n_fft = hparams.get("stft_n_fft", 2048)
  24. stereo = hparams.get("stereo", False)
  25. # 1. Calculate number of frequencies
  26. freqs = stft_n_fft // 2 + 1
  27. # 2. Create Mel Filter Bank
  28. mel_filter_bank_numpy = librosa.filters.mel(
  29. sr=sample_rate, n_fft=stft_n_fft, n_mels=num_bands
  30. )
  31. mel_filter_bank = torch.from_numpy(mel_filter_bank_numpy)
  32. # 3. Ensure edge values are positive (required for mask generation)
  33. # The exact value doesn't matter as long as it's > 0
  34. mel_filter_bank[0, 0] = max(mel_filter_bank[0, 0].item(), 1e-6)
  35. mel_filter_bank[-1, -1] = max(mel_filter_bank[-1, -1].item(), 1e-6)
  36. # 4. Create Masks
  37. freqs_per_band = mel_filter_bank > 0
  38. assert freqs_per_band.any(dim=0).all(), (
  39. "all frequencies need to be covered by all bands"
  40. )
  41. # 5. Generate Indices
  42. repeated_freq_indices = repeat(torch.arange(freqs), "f -> b f", b=num_bands)
  43. freq_indices = repeated_freq_indices[freqs_per_band]
  44. if stereo:
  45. freq_indices = repeat(freq_indices, "f -> f s", s=2)
  46. # s=0 -> 2*f, s=1 -> 2*f+1
  47. freq_indices = freq_indices * 2 + torch.arange(2)
  48. freq_indices = rearrange(freq_indices, "f s -> (f s)")
  49. # 6. Aggregate Counts
  50. num_freqs_per_band = reduce(freqs_per_band, "b f -> b", "sum")
  51. num_bands_per_freq = reduce(freqs_per_band, "b f -> f", "sum")
  52. return {
  53. "freq_indices": freq_indices,
  54. "num_freqs_per_band": num_freqs_per_band,
  55. "num_bands_per_freq": num_bands_per_freq,
  56. "freqs_per_band": freqs_per_band, # Kept if needed, though usually not saved
  57. }
  58. # ============================================================================
  59. # Quantization Helper
  60. # ============================================================================
  61. def get_target_quantization_type(dtype_str: str) -> GGMLQuantizationType:
  62. mapping = {
  63. "f32": GGMLQuantizationType.F32,
  64. "fp32": GGMLQuantizationType.F32,
  65. "f16": GGMLQuantizationType.F16,
  66. "fp16": GGMLQuantizationType.F16,
  67. "q8_0": GGMLQuantizationType.Q8_0,
  68. "q4_0": GGMLQuantizationType.Q4_0,
  69. "q4_1": GGMLQuantizationType.Q4_1,
  70. "q5_0": GGMLQuantizationType.Q5_0,
  71. "q5_1": GGMLQuantizationType.Q5_1,
  72. }
  73. return mapping.get(dtype_str.lower(), GGMLQuantizationType.F32)
  74. def get_file_type_id(qtype: GGMLQuantizationType) -> int:
  75. # See GGUF spec: https://github.com/ggerganov/ggml/blob/master/docs/gguf.md
  76. mapping = {
  77. GGMLQuantizationType.F32: 0,
  78. GGMLQuantizationType.F16: 1,
  79. GGMLQuantizationType.Q4_0: 2,
  80. GGMLQuantizationType.Q4_1: 3,
  81. # 4 is Q4_1_O (deprecated/legacy?)
  82. # 5 is Q4_0_O ?
  83. # 6 is Q4_1_O ?
  84. GGMLQuantizationType.Q8_0: 7,
  85. GGMLQuantizationType.Q5_0: 8,
  86. GGMLQuantizationType.Q5_1: 9,
  87. GGMLQuantizationType.Q2_K: 10,
  88. GGMLQuantizationType.Q3_K: 11,
  89. GGMLQuantizationType.Q4_K: 12,
  90. GGMLQuantizationType.Q5_K: 13,
  91. GGMLQuantizationType.Q6_K: 14,
  92. # IQ2_XXS etc might have IDs but let's stick to these for now
  93. }
  94. return mapping.get(qtype, 0) # Default to ALL_F32 if unknown
  95. def should_quantize(name: str) -> bool:
  96. """
  97. Determine if a tensor should be quantized.
  98. Keep norms and biases as FP32 to avoid CUDA alignment issues.
  99. """
  100. # Biases are always small and sensitive
  101. if "bias" in name:
  102. return False
  103. # Norm weights (gamma) must be F32 to avoid mixed-type mul issues in CUDA
  104. if "norm.weight" in name:
  105. return False
  106. # Quantize all other "weight" matrices (Linear, Conv, Embedding if any)
  107. if "weight" in name:
  108. return True
  109. return False
  110. # ============================================================================
  111. # Key Name Mapping
  112. # ============================================================================
  113. def map_key_name(key: str) -> str:
  114. """
  115. Map PyTorch state_dict keys to GGUF format (blk.{bid}.*).
  116. Standardizes suffixes: gamma -> weight, beta -> bias.
  117. """
  118. def standardize_suffix(param_name: str) -> str:
  119. if param_name == "gamma":
  120. return "weight"
  121. if param_name == "beta":
  122. return "bias"
  123. return param_name
  124. parts = key.split(".")
  125. suffix = standardize_suffix(parts[-1])
  126. # Transformer Layers
  127. if key.startswith("layers."):
  128. layer_idx = parts[1]
  129. tf_idx = parts[2] # 0=Time, 1=Freq
  130. type_str = "time" if tf_idx == "0" else "freq"
  131. # Final Norm: layers.0.0.norm.gamma
  132. if len(parts) >= 5 and parts[3] == "norm":
  133. return f"blk.{layer_idx}.{type_str}_norm.{suffix}"
  134. # Sub-layers (Attention=0, FF=1)
  135. if len(parts) >= 6 and parts[3] == "layers":
  136. block_sub_idx = parts[5]
  137. if block_sub_idx == "0": # Attention
  138. if len(parts) > 6:
  139. sub_name = parts[6]
  140. if sub_name == "norm":
  141. return f"blk.{layer_idx}.{type_str}_attn_norm.{suffix}"
  142. if sub_name == "to_qkv":
  143. return f"blk.{layer_idx}.{type_str}_attn_qkv.{suffix}"
  144. if sub_name == "to_out":
  145. return f"blk.{layer_idx}.{type_str}_attn_out.{suffix}"
  146. if sub_name == "to_gates":
  147. return f"blk.{layer_idx}.{type_str}_attn_gate.{suffix}"
  148. elif block_sub_idx == "1": # FeedForward
  149. if len(parts) >= 8 and parts[6] == "net":
  150. net_idx = parts[7]
  151. if net_idx == "0":
  152. return f"blk.{layer_idx}.{type_str}_ff_norm.{suffix}"
  153. if net_idx == "1":
  154. return f"blk.{layer_idx}.{type_str}_ff_in.{suffix}"
  155. if net_idx == "4":
  156. return f"blk.{layer_idx}.{type_str}_ff_out.{suffix}"
  157. # BandSplit
  158. if key.startswith("band_split.to_features"):
  159. band_idx = parts[2]
  160. layer_idx = parts[3] # 0=Norm, 1=Linear
  161. if layer_idx == "0":
  162. return f"band_split.{band_idx}.norm.{suffix}"
  163. if layer_idx == "1":
  164. return f"band_split.{band_idx}.linear.{suffix}"
  165. # Mask Estimator
  166. if key.startswith("mask_estimators"):
  167. est_idx = parts[1]
  168. freq_idx = parts[3]
  169. layer_idx = parts[5] # 0, 2, 4
  170. return f"mask_est.{est_idx}.freq.{freq_idx}.mlp.{layer_idx}.{suffix}"
  171. return key.replace(".", "_")
  172. # ============================================================================
  173. # Main Conversion
  174. # ============================================================================
  175. def convert(
  176. ckpt_path: str,
  177. output_path: str,
  178. config_path: str,
  179. dtype: str = "fp32",
  180. name: str = None,
  181. description: str = None,
  182. ):
  183. """
  184. Convert PyTorch checkpoint to GGUF format.
  185. """
  186. print(f"Loading checkpoint: {ckpt_path}")
  187. checkpoint = torch.load(ckpt_path, map_location="cpu")
  188. if "state_dict" in checkpoint:
  189. state_dict = checkpoint["state_dict"]
  190. elif "model" in checkpoint:
  191. state_dict = checkpoint["model"]
  192. else:
  193. state_dict = checkpoint
  194. print(f"Loading config: {config_path}")
  195. with open(config_path) as f:
  196. config_dict = yaml.load(f, Loader=yaml.FullLoader)
  197. # Generate buffers
  198. print("Generating buffers (standalone)...")
  199. buffers = generate_buffers(config_dict["model"])
  200. freq_indices = buffers["freq_indices"]
  201. num_bands_per_freq = buffers["num_bands_per_freq"]
  202. num_freqs_per_band = buffers["num_freqs_per_band"]
  203. # Create GGUF writer
  204. gguf_writer = gguf.GGUFWriter(output_path, "mel_band_roformer")
  205. # =========================================================================
  206. # 1. Write Standard GGUF Metadata
  207. # =========================================================================
  208. print("Writing metadata...")
  209. # General metadata
  210. model_name = name if name else "Mel-Band-Roformer Separator"
  211. model_description = description if description else "Music source separation model"
  212. gguf_writer.add_name(model_name)
  213. gguf_writer.add_description(model_description)
  214. # Determine types
  215. target_qtype = get_target_quantization_type(dtype)
  216. file_type_id = get_file_type_id(target_qtype)
  217. gguf_writer.add_file_type(file_type_id)
  218. # Quantization version (required when quantized)
  219. if target_qtype != GGMLQuantizationType.F32:
  220. gguf_writer.add_quantization_version(2)
  221. # Calculate parameter count
  222. total_params = 0
  223. for key, tensor in state_dict.items():
  224. if "freq_indices" in key or "num_bands" in key:
  225. continue
  226. total_params += tensor.numel()
  227. print(f"Total parameters: {total_params}")
  228. gguf_writer.add_uint64("general.parameter_count", total_params)
  229. # =========================================================================
  230. # 2. Write Hyperparameters
  231. # =========================================================================
  232. print("Writing hyperparameters...")
  233. hparams = config_dict["model"]
  234. # Architecture specific parameters
  235. gguf_writer.add_uint32("mel_band_roformer.dim", hparams["dim"])
  236. gguf_writer.add_uint32("mel_band_roformer.depth", hparams["depth"])
  237. gguf_writer.add_uint32("mel_band_roformer.num_bands", hparams["num_bands"])
  238. # STFT parameters
  239. gguf_writer.add_uint32(
  240. "mel_band_roformer.stft_n_fft", hparams.get("stft_n_fft", 2048)
  241. )
  242. # Remove default for hop_length, must be present or fail/warn
  243. gguf_writer.add_uint32(
  244. "mel_band_roformer.stft_hop_length", hparams.get("stft_hop_length", 441)
  245. )
  246. gguf_writer.add_uint32(
  247. "mel_band_roformer.stft_win_length", hparams.get("stft_win_length", 2048)
  248. )
  249. gguf_writer.add_bool(
  250. "mel_band_roformer.stft_normalized", hparams.get("stft_normalized", False)
  251. )
  252. gguf_writer.add_bool(
  253. "mel_band_roformer.zero_dc", hparams.get("zero_dc", True)
  254. ) # Defaults to True in reference implementation
  255. # Architecture details
  256. gguf_writer.add_uint32("mel_band_roformer.num_stems", hparams.get("num_stems", 1))
  257. gguf_writer.add_bool("mel_band_roformer.stereo", hparams.get("stereo", False))
  258. gguf_writer.add_uint32(
  259. "mel_band_roformer.sample_rate", hparams.get("sample_rate", 44100)
  260. )
  261. gguf_writer.add_uint32(
  262. "mel_band_roformer.time_transformer_depth",
  263. hparams.get("time_transformer_depth", 0),
  264. )
  265. gguf_writer.add_uint32(
  266. "mel_band_roformer.freq_transformer_depth",
  267. hparams.get("freq_transformer_depth", 0),
  268. )
  269. gguf_writer.add_uint32(
  270. "mel_band_roformer.linear_transformer_depth",
  271. hparams.get("linear_transformer_depth", 0),
  272. )
  273. gguf_writer.add_uint32(
  274. "mel_band_roformer.mask_estimator_depth", hparams.get("mask_estimator_depth", 1)
  275. )
  276. gguf_writer.add_uint32("mel_band_roformer.dim_head", hparams.get("dim_head", 64))
  277. gguf_writer.add_uint32("mel_band_roformer.heads", hparams.get("heads", 8))
  278. gguf_writer.add_uint32(
  279. "mel_band_roformer.mlp_expansion_factor", hparams.get("mlp_expansion_factor", 4)
  280. )
  281. gguf_writer.add_bool(
  282. "mel_band_roformer.skip_connection", hparams.get("skip_connection", False)
  283. )
  284. # =========================================================================
  285. # 3. Write Inference Defaults (Optional, can be overridden at runtime)
  286. # =========================================================================
  287. print("Writing inference defaults...")
  288. inference_config = config_dict.get("inference", {})
  289. audio_config = config_dict.get("audio", {})
  290. # chunk_size: prefer inference.chunk_size, fallback to audio.chunk_size
  291. default_chunk_size = inference_config.get(
  292. "chunk_size", audio_config.get("chunk_size", 352800)
  293. )
  294. # num_overlap: from inference section
  295. default_num_overlap = inference_config.get("num_overlap", 0)
  296. gguf_writer.add_uint32("mel_band_roformer.default_chunk_size", default_chunk_size)
  297. gguf_writer.add_uint32("mel_band_roformer.default_num_overlap", default_num_overlap)
  298. # =========================================================================
  299. # 4. Write Buffers (Always FP32/I32)
  300. # =========================================================================
  301. print("Writing buffers...")
  302. # freq_indices (int32)
  303. gguf_writer.add_tensor("buffer_freq_indices", freq_indices.numpy().astype(np.int32))
  304. # num_bands_per_freq (int32)
  305. gguf_writer.add_tensor(
  306. "buffer_num_bands_per_freq", num_bands_per_freq.numpy().astype(np.int32)
  307. )
  308. # num_freqs_per_band (int32)
  309. gguf_writer.add_tensor(
  310. "buffer_num_freqs_per_band", num_freqs_per_band.numpy().astype(np.int32)
  311. )
  312. # =========================================================================
  313. # 5. Write Weights (Mixed Quantization)
  314. # =========================================================================
  315. print(f"Writing weights ({dtype} -> {target_qtype.name})...")
  316. print("Strategy: Quantize weights, Keep Norm/Bias as F32")
  317. n_tensors = 0
  318. n_quantized = 0
  319. for key, tensor in state_dict.items():
  320. new_key = map_key_name(key)
  321. # Skip buffers
  322. if (
  323. "freq_indices" in key
  324. or "num_bands_per_freq" in key
  325. or "num_freqs_per_band" in key
  326. ):
  327. continue
  328. data = tensor.numpy().astype(np.float32)
  329. # Decide whether to quantize
  330. is_quantized = False
  331. if target_qtype != GGMLQuantizationType.F32 and should_quantize(new_key):
  332. try:
  333. # Use gguf-py built-in quantization
  334. quantized_data = quantize(data, target_qtype)
  335. # Pass raw_dtype so GGUFWriter knows how to treat the byte array (for Q types)
  336. # or float array (for F16)
  337. gguf_writer.add_tensor(new_key, quantized_data, raw_dtype=target_qtype)
  338. is_quantized = True
  339. n_quantized += 1
  340. except Exception as e:
  341. print(
  342. f"Warning: Failed to quantize {new_key} to {target_qtype.name}, falling back to F32. Error: {e}"
  343. )
  344. gguf_writer.add_tensor(new_key, data)
  345. else:
  346. # Keep as F32
  347. gguf_writer.add_tensor(new_key, data)
  348. status = target_qtype.name if is_quantized else "F32"
  349. print(f" {new_key:<50} | {str(data.shape):<20} | {status}")
  350. n_tensors += 1
  351. # =========================================================================
  352. # 6. Write File
  353. # =========================================================================
  354. print(f"\nWriting GGUF to {output_path}")
  355. gguf_writer.write_header_to_file()
  356. gguf_writer.write_kv_data_to_file()
  357. gguf_writer.write_tensors_to_file()
  358. gguf_writer.close()
  359. file_size = os.path.getsize(output_path)
  360. print(f"\nDone! Converted {n_tensors} tensors ({n_quantized} quantized)")
  361. print(f"Output file size: {file_size / 1024 / 1024:.2f} MB")
  362. if __name__ == "__main__":
  363. parser = argparse.ArgumentParser(
  364. description="Convert Mel-Band-Roformer checkpoint to GGUF format with Mixed Quantization",
  365. formatter_class=argparse.RawDescriptionHelpFormatter,
  366. epilog="""
  367. Examples:
  368. python convert_to_gguf.py --ckpt model.ckpt --config config.yaml --out model_f16.gguf --dtype fp16
  369. python convert_to_gguf.py --ckpt model.ckpt --config config.yaml --out model_q8.gguf --dtype q8_0
  370. """,
  371. )
  372. parser.add_argument(
  373. "--ckpt", type=str, required=True, help="Path to PyTorch checkpoint"
  374. )
  375. parser.add_argument("--config", type=str, required=True, help="Path to YAML config")
  376. parser.add_argument("--out", type=str, required=True, help="Output GGUF file path")
  377. parser.add_argument(
  378. "--dtype",
  379. type=str,
  380. default="fp32",
  381. choices=[
  382. "fp32",
  383. "f32",
  384. "fp16",
  385. "f16",
  386. "q8_0",
  387. "q4_0",
  388. "q4_1",
  389. "q5_0",
  390. "q5_1",
  391. ],
  392. help="Target quantization type. Norms/Biases will be kept as F32. (K-Quants not supported due to dim=384)",
  393. )
  394. parser.add_argument(
  395. "--name",
  396. type=str,
  397. default=None,
  398. help="Model name (default: 'Mel-Band-Roformer Vocal Separator')",
  399. )
  400. parser.add_argument(
  401. "--description",
  402. type=str,
  403. default=None,
  404. help="Model description (default: 'Audio source separation model for vocal extraction')",
  405. )
  406. args = parser.parse_args()
  407. convert(args.ckpt, args.out, args.config, args.dtype, args.name, args.description)