convert_to_gguf.py 24 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729
  1. #!/usr/bin/env python3
  2. """
  3. Convert Mel-Band-Roformer or BS-Roformer PyTorch checkpoints to GGUF format.
  4. Supports quantization: FP32, FP16, Q8_0, Q4_0, Q4_1, Q5_0, Q5_1
  5. Mixed Quantization: Keeps Norms/Biases as FP32 to avoid CUDA alignment issues.
  6. """
  7. import os
  8. import argparse
  9. import torch
  10. import numpy as np
  11. import yaml
  12. import librosa
  13. from einops import repeat, reduce, rearrange
  14. import gguf
  15. from gguf.quants import quantize, GGMLQuantizationType
  16. def detect_architecture(config_dict):
  17. """
  18. Detect architecture from config.
  19. Returns: 'bs_roformer' or 'mel_band_roformer'
  20. """
  21. # Check structural signatures in 'model' section
  22. model_config = config_dict.get("model", {})
  23. has_freqs = "freqs_per_bands" in model_config
  24. has_num_bands = "num_bands" in model_config
  25. if has_freqs:
  26. return "bs_roformer"
  27. if has_num_bands:
  28. return "mel_band_roformer"
  29. # 3. If neither found, fail
  30. raise ValueError(
  31. "Auto-detection failed: Config missing 'freqs_per_bands' (BS) or 'num_bands' (Mel-Band). "
  32. "Please specify --arch manually."
  33. )
  34. def normalize_arch(arch: str) -> str:
  35. """Normalize architecture name to full GGUF name."""
  36. mapping = {
  37. "bs": "bs_roformer",
  38. "bs_roformer": "bs_roformer",
  39. "mel": "mel_band_roformer",
  40. "mel_band": "mel_band_roformer",
  41. "mel_band_roformer": "mel_band_roformer",
  42. }
  43. result = mapping.get(arch.lower())
  44. if result is None:
  45. raise ValueError(
  46. f"Unknown architecture: '{arch}'. Supported: {list(mapping.keys())}"
  47. )
  48. return result
  49. def generate_buffers_bs(hparams):
  50. """BS Roformer: 从 freqs_per_bands 元组生成缓冲区"""
  51. # Default from bs_roformer.py
  52. DEFAULT_FREQS_PER_BANDS = (
  53. 2,
  54. 2,
  55. 2,
  56. 2,
  57. 2,
  58. 2,
  59. 2,
  60. 2,
  61. 2,
  62. 2,
  63. 2,
  64. 2,
  65. 2,
  66. 2,
  67. 2,
  68. 2,
  69. 2,
  70. 2,
  71. 2,
  72. 2,
  73. 2,
  74. 2,
  75. 2,
  76. 2,
  77. 4,
  78. 4,
  79. 4,
  80. 4,
  81. 4,
  82. 4,
  83. 4,
  84. 4,
  85. 4,
  86. 4,
  87. 4,
  88. 4,
  89. 12,
  90. 12,
  91. 12,
  92. 12,
  93. 12,
  94. 12,
  95. 12,
  96. 12,
  97. 24,
  98. 24,
  99. 24,
  100. 24,
  101. 24,
  102. 24,
  103. 24,
  104. 24,
  105. 48,
  106. 48,
  107. 48,
  108. 48,
  109. 48,
  110. 48,
  111. 48,
  112. 48,
  113. 128,
  114. 129,
  115. )
  116. freqs_per_bands = hparams.get("freqs_per_bands", DEFAULT_FREQS_PER_BANDS)
  117. stereo = hparams.get("stereo", False)
  118. audio_channels = 2 if stereo else 1
  119. # Validate
  120. stft_n_fft = hparams.get("stft_n_fft", 2048)
  121. expected_freqs = stft_n_fft // 2 + 1
  122. # Check sum
  123. sum_freqs = sum(freqs_per_bands)
  124. if sum_freqs != expected_freqs:
  125. print(
  126. f"[WARNING] sum(freqs_per_bands)={sum_freqs} != expected {expected_freqs}. Adjusting last band..."
  127. )
  128. # Note: In C++ logic relying on exact match might be strict, but let's warn for now.
  129. # Actually BS Roformer paper/code implies strict match for STFT reconstruction.
  130. num_bands = len(freqs_per_bands)
  131. freqs_per_bands_with_complex = tuple(
  132. 2 * f * audio_channels for f in freqs_per_bands
  133. )
  134. # num_freqs_per_band: i32 array
  135. num_freqs_per_band = np.array(freqs_per_bands, dtype=np.int32)
  136. # BS doesn't use freq_indices re-indexing, but to keep compatible file structure
  137. # we create dummy full-range indices.
  138. total_freqs_stereo = expected_freqs * audio_channels
  139. freq_indices = np.arange(total_freqs_stereo, dtype=np.int32)
  140. num_bands_per_freq = np.ones(expected_freqs, dtype=np.int32)
  141. print(f"Generated BS buffers: {num_bands} bands, {len(freq_indices)} indices")
  142. return {
  143. "freq_indices": freq_indices,
  144. "num_freqs_per_band": num_freqs_per_band,
  145. "num_bands_per_freq": num_bands_per_freq,
  146. "num_bands": num_bands,
  147. "freqs_per_bands_with_complex": freqs_per_bands_with_complex,
  148. "freqs_per_bands_tuple": freqs_per_bands, # Keep raw tuple for metadata
  149. }
  150. def generate_buffers(hparams, arch="mel_band_roformer"):
  151. """
  152. Generate buffers for the specified architecture.
  153. Args:
  154. hparams: Model hyperparameters
  155. arch: Architecture name ('bs_roformer' or 'mel_band_roformer')
  156. """
  157. if arch == "bs_roformer":
  158. return generate_buffers_bs(hparams)
  159. # Mel-Band-Roformer Logic
  160. # ------------------------------------------------------------------------
  161. """
  162. Generate the buffers (freq_indices, num_bands_per_freq, etc.)
  163. mimicking the logic in MelBandRoformer.__init__.
  164. """
  165. num_bands = hparams["num_bands"]
  166. sample_rate = hparams.get("sample_rate", 44100)
  167. stft_n_fft = hparams.get("stft_n_fft", 2048)
  168. stereo = hparams.get("stereo", False)
  169. # 1. Calculate number of frequencies
  170. freqs = stft_n_fft // 2 + 1
  171. # 2. Create Mel Filter Bank
  172. mel_filter_bank_numpy = librosa.filters.mel(
  173. sr=sample_rate, n_fft=stft_n_fft, n_mels=num_bands
  174. )
  175. mel_filter_bank = torch.from_numpy(mel_filter_bank_numpy)
  176. # 3. Ensure edge values are positive (required for mask generation)
  177. # The exact value doesn't matter as long as it's > 0
  178. mel_filter_bank[0, 0] = max(mel_filter_bank[0, 0].item(), 1e-6)
  179. mel_filter_bank[-1, -1] = max(mel_filter_bank[-1, -1].item(), 1e-6)
  180. # 4. Create Masks
  181. freqs_per_band = mel_filter_bank > 0
  182. assert freqs_per_band.any(dim=0).all(), (
  183. "all frequencies need to be covered by all bands"
  184. )
  185. # 5. Generate Indices
  186. repeated_freq_indices = repeat(torch.arange(freqs), "f -> b f", b=num_bands)
  187. freq_indices = repeated_freq_indices[freqs_per_band]
  188. if stereo:
  189. freq_indices = repeat(freq_indices, "f -> f s", s=2)
  190. # s=0 -> 2*f, s=1 -> 2*f+1
  191. freq_indices = freq_indices * 2 + torch.arange(2)
  192. freq_indices = rearrange(freq_indices, "f s -> (f s)")
  193. # 6. Aggregate Counts
  194. num_freqs_per_band = reduce(freqs_per_band, "b f -> b", "sum")
  195. num_bands_per_freq = reduce(freqs_per_band, "b f -> f", "sum")
  196. return {
  197. "freq_indices": freq_indices,
  198. "num_freqs_per_band": num_freqs_per_band,
  199. "num_bands_per_freq": num_bands_per_freq,
  200. "freqs_per_band": freqs_per_band, # Kept if needed, though usually not saved
  201. }
  202. # ============================================================================
  203. # Quantization Helper
  204. # ============================================================================
  205. def get_target_quantization_type(dtype_str: str) -> GGMLQuantizationType:
  206. mapping = {
  207. "f32": GGMLQuantizationType.F32,
  208. "fp32": GGMLQuantizationType.F32,
  209. "f16": GGMLQuantizationType.F16,
  210. "fp16": GGMLQuantizationType.F16,
  211. "q8_0": GGMLQuantizationType.Q8_0,
  212. "q4_0": GGMLQuantizationType.Q4_0,
  213. "q4_1": GGMLQuantizationType.Q4_1,
  214. "q5_0": GGMLQuantizationType.Q5_0,
  215. "q5_1": GGMLQuantizationType.Q5_1,
  216. }
  217. return mapping.get(dtype_str.lower(), GGMLQuantizationType.F32)
  218. def get_file_type_id(qtype: GGMLQuantizationType) -> int:
  219. # See GGUF spec: https://github.com/ggerganov/ggml/blob/master/docs/gguf.md
  220. mapping = {
  221. GGMLQuantizationType.F32: 0,
  222. GGMLQuantizationType.F16: 1,
  223. GGMLQuantizationType.Q4_0: 2,
  224. GGMLQuantizationType.Q4_1: 3,
  225. # 4 is Q4_1_O (deprecated/legacy?)
  226. # 5 is Q4_0_O ?
  227. # 6 is Q4_1_O ?
  228. GGMLQuantizationType.Q8_0: 7,
  229. GGMLQuantizationType.Q5_0: 8,
  230. GGMLQuantizationType.Q5_1: 9,
  231. GGMLQuantizationType.Q2_K: 10,
  232. GGMLQuantizationType.Q3_K: 11,
  233. GGMLQuantizationType.Q4_K: 12,
  234. GGMLQuantizationType.Q5_K: 13,
  235. GGMLQuantizationType.Q6_K: 14,
  236. # IQ2_XXS etc might have IDs but let's stick to these for now
  237. }
  238. return mapping.get(qtype, 0) # Default to ALL_F32 if unknown
  239. def should_quantize(name: str) -> bool:
  240. """
  241. Determine if a tensor should be quantized.
  242. Keep norms and biases as FP32 to avoid CUDA alignment issues.
  243. """
  244. # Biases are always small and sensitive
  245. if "bias" in name:
  246. return False
  247. # Norm weights (gamma) must be F32 to avoid mixed-type mul issues in CUDA
  248. if "norm.weight" in name:
  249. return False
  250. # Quantize all other "weight" matrices (Linear, Conv, Embedding if any)
  251. if "weight" in name:
  252. return True
  253. return False
  254. # ============================================================================
  255. # Key Name Mapping
  256. # ============================================================================
  257. def map_key_name(key: str) -> str:
  258. """
  259. Map PyTorch state_dict keys to GGUF format (blk.{bid}.*).
  260. Standardizes suffixes: gamma -> weight, beta -> bias.
  261. """
  262. def standardize_suffix(param_name: str) -> str:
  263. if param_name == "gamma":
  264. return "weight"
  265. if param_name == "beta":
  266. return "bias"
  267. return param_name
  268. parts = key.split(".")
  269. suffix = standardize_suffix(parts[-1])
  270. # Transformer Layers
  271. if key.startswith("layers."):
  272. layer_idx = parts[1]
  273. tf_idx = parts[2] # 0=Time, 1=Freq
  274. type_str = "time" if tf_idx == "0" else "freq"
  275. # Final Norm: layers.0.0.norm.gamma
  276. if len(parts) >= 5 and parts[3] == "norm":
  277. return f"blk.{layer_idx}.{type_str}_norm.{suffix}"
  278. # Sub-layers (Attention=0, FF=1)
  279. if len(parts) >= 6 and parts[3] == "layers":
  280. block_sub_idx = parts[5]
  281. if block_sub_idx == "0": # Attention
  282. if len(parts) > 6:
  283. sub_name = parts[6]
  284. if sub_name == "norm":
  285. return f"blk.{layer_idx}.{type_str}_attn_norm.{suffix}"
  286. if sub_name == "to_qkv":
  287. return f"blk.{layer_idx}.{type_str}_attn_qkv.{suffix}"
  288. if sub_name == "to_out":
  289. return f"blk.{layer_idx}.{type_str}_attn_out.{suffix}"
  290. if sub_name == "to_gates":
  291. return f"blk.{layer_idx}.{type_str}_attn_gate.{suffix}"
  292. elif block_sub_idx == "1": # FeedForward
  293. if len(parts) >= 8 and parts[6] == "net":
  294. net_idx = parts[7]
  295. if net_idx == "0":
  296. return f"blk.{layer_idx}.{type_str}_ff_norm.{suffix}"
  297. if net_idx == "1":
  298. return f"blk.{layer_idx}.{type_str}_ff_in.{suffix}"
  299. if net_idx == "4":
  300. return f"blk.{layer_idx}.{type_str}_ff_out.{suffix}"
  301. # BandSplit
  302. if key.startswith("band_split.to_features"):
  303. band_idx = parts[2]
  304. layer_idx = parts[3] # 0=Norm, 1=Linear
  305. if layer_idx == "0":
  306. return f"band_split.{band_idx}.norm.{suffix}"
  307. if layer_idx == "1":
  308. return f"band_split.{band_idx}.linear.{suffix}"
  309. # Mask Estimator
  310. if key.startswith("mask_estimators"):
  311. est_idx = parts[1]
  312. freq_idx = parts[3]
  313. layer_idx = parts[5] # 0, 2, 4
  314. return f"mask_est.{est_idx}.freq.{freq_idx}.mlp.{layer_idx}.{suffix}"
  315. # Final Norm
  316. if key.startswith("final_norm"):
  317. return f"final_norm.{suffix}"
  318. return key.replace(".", "_")
  319. # ============================================================================
  320. # Main Conversion
  321. # ============================================================================
  322. def convert(
  323. ckpt_path: str,
  324. output_path: str,
  325. config_path: str,
  326. dtype: str = "fp32",
  327. name: str | None = None,
  328. description: str | None = None,
  329. arch: str | None = None,
  330. ):
  331. """
  332. Convert PyTorch checkpoint to GGUF format.
  333. """
  334. print(f"Loading checkpoint: {ckpt_path}")
  335. checkpoint = torch.load(ckpt_path, map_location="cpu")
  336. if "state_dict" in checkpoint:
  337. state_dict = checkpoint["state_dict"]
  338. elif "model" in checkpoint:
  339. state_dict = checkpoint["model"]
  340. else:
  341. state_dict = checkpoint
  342. print(f"Loading config: {config_path}")
  343. with open(config_path) as f:
  344. config_dict = yaml.load(f, Loader=yaml.FullLoader)
  345. # Detect architecture
  346. if arch is None:
  347. try:
  348. arch = detect_architecture(config_dict)
  349. print(f"Auto-detected architecture: {arch}")
  350. except ValueError as e:
  351. print(f"Error: {e}")
  352. return
  353. else:
  354. # Normalize provided arch to full name
  355. arch = normalize_arch(arch)
  356. # Generate buffers
  357. print("Generating buffers (standalone)...")
  358. buffers = generate_buffers(config_dict["model"], arch=arch)
  359. freq_indices = buffers["freq_indices"]
  360. num_bands_per_freq = buffers["num_bands_per_freq"]
  361. num_freqs_per_band = buffers["num_freqs_per_band"]
  362. arch_name = arch
  363. # Create GGUF writer
  364. gguf_writer = gguf.GGUFWriter(output_path, arch_name)
  365. # =========================================================================
  366. # 1. Write Standard GGUF Metadata
  367. # =========================================================================
  368. print("Writing metadata...")
  369. # General metadata
  370. model_name = name if name else "Mel-Band-Roformer Separator"
  371. model_description = description if description else "Music source separation model"
  372. gguf_writer.add_name(model_name)
  373. gguf_writer.add_description(model_description)
  374. # Determine types
  375. target_qtype = get_target_quantization_type(dtype)
  376. file_type_id = get_file_type_id(target_qtype)
  377. gguf_writer.add_file_type(file_type_id)
  378. # Write Architecture
  379. # gguf_writer.add_string(f"{arch_name}.architecture", arch) # Redundant with general.architecture
  380. if arch_name == "bs_roformer" and "freqs_per_bands_tuple" in buffers:
  381. freqs_tuple = buffers["freqs_per_bands_tuple"]
  382. # Must be list for GGUFWriter
  383. gguf_writer.add_array(f"{arch_name}.freqs_per_bands", list(freqs_tuple))
  384. # Quantization version (required when quantized)
  385. if target_qtype != GGMLQuantizationType.F32:
  386. gguf_writer.add_quantization_version(2)
  387. # Calculate parameter count
  388. total_params = 0
  389. for key, tensor in state_dict.items():
  390. if "freq_indices" in key or "num_bands" in key:
  391. continue
  392. total_params += tensor.numel()
  393. print(f"Total parameters: {total_params}")
  394. gguf_writer.add_uint64("general.parameter_count", total_params)
  395. # =========================================================================
  396. # 2. Write Hyperparameters
  397. # =========================================================================
  398. print("Writing hyperparameters...")
  399. hparams = config_dict["model"]
  400. # Load state dict directly (no model class dependency)
  401. print(f"Loading checkpoint for architecture: {arch}")
  402. raw_state_dict = None
  403. if "state_dict" in checkpoint:
  404. raw_state_dict = checkpoint["state_dict"]
  405. elif "model" in checkpoint:
  406. raw_state_dict = checkpoint["model"]
  407. else:
  408. raw_state_dict = checkpoint
  409. if raw_state_dict is None:
  410. raise ValueError("Could not find state_dict in checkpoint")
  411. # Clean up state dict (handle DDP "module." prefix)
  412. state_dict = {}
  413. for k, v in raw_state_dict.items():
  414. if k.startswith("module."):
  415. k = k[7:]
  416. state_dict[k] = v
  417. # Architecture specific parameters
  418. gguf_writer.add_uint32(f"{arch_name}.dim", hparams["dim"])
  419. gguf_writer.add_uint32(f"{arch_name}.depth", hparams["depth"])
  420. # BS uses freqs_per_bands (no explicit num_bands), MelBand uses num_bands
  421. num_bands = buffers.get("num_bands", hparams.get("num_bands", 60))
  422. gguf_writer.add_uint32(f"{arch_name}.num_bands", num_bands)
  423. # STFT parameters
  424. gguf_writer.add_uint32(f"{arch_name}.stft_n_fft", hparams.get("stft_n_fft", 2048))
  425. # Remove default for hop_length, must be present or fail/warn
  426. gguf_writer.add_uint32(
  427. f"{arch_name}.stft_hop_length", hparams.get("stft_hop_length", 441)
  428. )
  429. gguf_writer.add_uint32(
  430. f"{arch_name}.stft_win_length", hparams.get("stft_win_length", 2048)
  431. )
  432. gguf_writer.add_bool(
  433. f"{arch_name}.stft_normalized", hparams.get("stft_normalized", False)
  434. )
  435. gguf_writer.add_bool(
  436. f"{arch_name}.zero_dc", hparams.get("zero_dc", True)
  437. ) # Defaults to True in reference implementation
  438. # Architecture details
  439. gguf_writer.add_uint32(f"{arch_name}.num_stems", hparams.get("num_stems", 1))
  440. gguf_writer.add_bool(f"{arch_name}.stereo", hparams.get("stereo", False))
  441. gguf_writer.add_uint32(
  442. f"{arch_name}.sample_rate", hparams.get("sample_rate", 44100)
  443. )
  444. gguf_writer.add_uint32(
  445. f"{arch_name}.time_transformer_depth",
  446. hparams.get("time_transformer_depth", 0),
  447. )
  448. gguf_writer.add_uint32(
  449. f"{arch_name}.freq_transformer_depth",
  450. hparams.get("freq_transformer_depth", 0),
  451. )
  452. gguf_writer.add_uint32(
  453. f"{arch_name}.linear_transformer_depth",
  454. hparams.get("linear_transformer_depth", 0),
  455. )
  456. gguf_writer.add_uint32(
  457. f"{arch_name}.mask_estimator_depth", hparams.get("mask_estimator_depth", 1)
  458. )
  459. gguf_writer.add_uint32(f"{arch_name}.dim_head", hparams.get("dim_head", 64))
  460. gguf_writer.add_uint32(f"{arch_name}.heads", hparams.get("heads", 8))
  461. gguf_writer.add_uint32(
  462. f"{arch_name}.mlp_expansion_factor", hparams.get("mlp_expansion_factor", 4)
  463. )
  464. gguf_writer.add_bool(
  465. f"{arch_name}.skip_connection", hparams.get("skip_connection", False)
  466. )
  467. # =========================================================================
  468. # 3. Write Inference Defaults (Optional, can be overridden at runtime)
  469. # =========================================================================
  470. print("Writing inference defaults...")
  471. inference_config = config_dict.get("inference", {})
  472. audio_config = config_dict.get("audio", {})
  473. # chunk_size: prefer inference.chunk_size, fallback to audio.chunk_size
  474. default_chunk_size = inference_config.get(
  475. "chunk_size", audio_config.get("chunk_size", 352800)
  476. )
  477. # num_overlap: from inference section
  478. default_num_overlap = inference_config.get("num_overlap", 0)
  479. gguf_writer.add_uint32(f"{arch_name}.default_chunk_size", default_chunk_size)
  480. gguf_writer.add_uint32(f"{arch_name}.default_num_overlap", default_num_overlap)
  481. # =========================================================================
  482. # 4. Write Buffers (Always FP32/I32)
  483. # =========================================================================
  484. print("Writing buffers...")
  485. # freq_indices (int32) - may be torch.Tensor (MelBand) or np.ndarray (BS)
  486. fi = freq_indices.numpy() if hasattr(freq_indices, "numpy") else freq_indices
  487. gguf_writer.add_tensor("buffer_freq_indices", fi.astype(np.int32))
  488. # num_bands_per_freq (int32)
  489. nbpf = (
  490. num_bands_per_freq.numpy()
  491. if hasattr(num_bands_per_freq, "numpy")
  492. else num_bands_per_freq
  493. )
  494. gguf_writer.add_tensor("buffer_num_bands_per_freq", nbpf.astype(np.int32))
  495. # num_freqs_per_band (int32)
  496. nfpb = (
  497. num_freqs_per_band.numpy()
  498. if hasattr(num_freqs_per_band, "numpy")
  499. else num_freqs_per_band
  500. )
  501. gguf_writer.add_tensor("buffer_num_freqs_per_band", nfpb.astype(np.int32))
  502. # =========================================================================
  503. # 5. Write Weights (Mixed Quantization)
  504. # =========================================================================
  505. print(f"Writing weights ({dtype} -> {target_qtype.name})...")
  506. print("Strategy: Quantize weights, Keep Norm/Bias as F32")
  507. n_tensors = 0
  508. n_quantized = 0
  509. warnings_list = []
  510. for key, tensor in state_dict.items():
  511. new_key = map_key_name(key)
  512. # Skip buffers
  513. if (
  514. "freq_indices" in key
  515. or "num_bands_per_freq" in key
  516. or "num_freqs_per_band" in key
  517. ):
  518. continue
  519. data = tensor.numpy().astype(np.float32)
  520. # Decide whether to quantize
  521. is_quantized = False
  522. if target_qtype != GGMLQuantizationType.F32 and should_quantize(new_key):
  523. try:
  524. # Use gguf-py built-in quantization
  525. quantized_data = quantize(data, target_qtype)
  526. # Pass raw_dtype so GGUFWriter knows how to treat the byte array (for Q types)
  527. # or float array (for F16)
  528. gguf_writer.add_tensor(new_key, quantized_data, raw_dtype=target_qtype)
  529. is_quantized = True
  530. n_quantized += 1
  531. except Exception as e:
  532. msg = f"Failed to quantize {new_key} to {target_qtype.name}, falling back to F32. Error: {e}"
  533. warnings_list.append(msg)
  534. gguf_writer.add_tensor(new_key, data)
  535. else:
  536. # Keep as F32
  537. gguf_writer.add_tensor(new_key, data)
  538. status = target_qtype.name if is_quantized else "F32"
  539. print(f" {new_key:<50} | {str(data.shape):<20} | {status}")
  540. n_tensors += 1
  541. # =========================================================================
  542. # 6. Write File
  543. # =========================================================================
  544. print(f"\nWriting GGUF to {output_path}")
  545. gguf_writer.write_header_to_file()
  546. gguf_writer.write_kv_data_to_file()
  547. gguf_writer.write_tensors_to_file()
  548. gguf_writer.close()
  549. if warnings_list:
  550. print("\n" + "=" * 80)
  551. print(
  552. f"WARNING: {len(warnings_list)} tensors failed to quantize (fallback to F32):"
  553. )
  554. for msg in warnings_list:
  555. print(f" - {msg}")
  556. print("=" * 80 + "\n")
  557. file_size = os.path.getsize(output_path)
  558. print(f"\nDone! Converted {n_tensors} tensors ({n_quantized} quantized)")
  559. print(f"Output file size: {file_size / 1024 / 1024:.2f} MB")
  560. if __name__ == "__main__":
  561. parser = argparse.ArgumentParser(
  562. description="Convert Mel-Band-Roformer checkpoint to GGUF format with Mixed Quantization",
  563. formatter_class=argparse.RawDescriptionHelpFormatter,
  564. epilog="""
  565. Examples:
  566. python convert_to_gguf.py --ckpt model.ckpt --config config.yaml --out model_f16.gguf --dtype fp16
  567. python convert_to_gguf.py --ckpt model.ckpt --config config.yaml --out model_q8.gguf --dtype q8_0
  568. """,
  569. )
  570. parser.add_argument(
  571. "--ckpt", type=str, required=True, help="Path to PyTorch checkpoint"
  572. )
  573. parser.add_argument("--config", type=str, required=True, help="Path to YAML config")
  574. parser.add_argument("--out", type=str, required=True, help="Output GGUF file path")
  575. parser.add_argument(
  576. "--dtype",
  577. type=str,
  578. default="fp32",
  579. choices=[
  580. "fp32",
  581. "f32",
  582. "fp16",
  583. "f16",
  584. "q8_0",
  585. "q4_0",
  586. "q4_1",
  587. "q5_0",
  588. "q5_1",
  589. ],
  590. help="Target quantization type. Norms/Biases will be kept as F32. (K-Quants not supported due to dim=384)",
  591. )
  592. parser.add_argument(
  593. "--name",
  594. type=str,
  595. default=None,
  596. help="Model name (default: 'Mel-Band-Roformer Vocal Separator')",
  597. )
  598. parser.add_argument(
  599. "--description",
  600. type=str,
  601. default=None,
  602. help="Model description (default: 'Audio source separation model for vocal extraction')",
  603. )
  604. parser.add_argument(
  605. "--arch",
  606. choices=["mel_band", "mel_band_roformer", "bs", "bs_roformer"],
  607. default=None,
  608. help="Architecture type (auto-detected if not specified)",
  609. )
  610. args = parser.parse_args()
  611. convert(
  612. args.ckpt,
  613. args.out,
  614. args.config,
  615. args.dtype,
  616. args.name,
  617. args.description,
  618. args.arch,
  619. )