convert_to_gguf.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760
  1. #!/usr/bin/env python3
  2. """
  3. Convert Mel-Band-Roformer PyTorch checkpoint to GGUF format.
  4. Supports quantization: FP32, FP16, Q8_0, Q4_0, Q4_1, Q5_0, Q5_1
  5. Mixed Quantization: Keeps Norms/Biases as FP32 to avoid CUDA alignment issues.
  6. """
  7. import argparse
  8. import os
  9. import gguf
  10. import librosa
  11. import numpy as np
  12. import torch
  13. import yaml
  14. from einops import rearrange, reduce, repeat
  15. from gguf.quants import GGMLQuantizationType, quantize
  16. from safetensors.torch import load_file as load_safetensors
  17. def detect_architecture(config_dict):
  18. """
  19. Detect architecture from config.
  20. Returns: 'bs_roformer', 'bs_roformer_v2', or 'mel_band_roformer'
  21. """
  22. model_config = config_dict.get("model", config_dict)
  23. has_freqs = "freqs_per_bands" in model_config
  24. has_freqs_out = "freqs_per_bands_out" in model_config
  25. has_num_bands = "num_bands" in model_config
  26. if has_freqs and has_freqs_out:
  27. return "bs_roformer_v2"
  28. if has_freqs:
  29. return "bs_roformer"
  30. if has_num_bands:
  31. return "mel_band_roformer"
  32. raise ValueError(
  33. "Auto-detection failed: Config missing 'freqs_per_bands'/'freqs_per_bands_out' (BS_V2), 'freqs_per_bands' (BS), or 'num_bands' (Mel-Band). "
  34. "Please specify --arch manually."
  35. )
  36. def normalize_arch(arch: str) -> str:
  37. """Normalize architecture name to full GGUF name."""
  38. mapping = {
  39. "bs": "bs_roformer",
  40. "bs_roformer": "bs_roformer",
  41. "bs_roformer_v2": "bs_roformer_v2",
  42. "mel": "mel_band_roformer",
  43. "mel_band": "mel_band_roformer",
  44. "mel_band_roformer": "mel_band_roformer",
  45. }
  46. result = mapping.get(arch.lower())
  47. if result is None:
  48. raise ValueError(
  49. f"Unknown architecture: '{arch}'. Supported: {list(mapping.keys())}"
  50. )
  51. return result
  52. def generate_buffers_bs(hparams):
  53. """BS Roformer: 从 freqs_per_bands 元组生成缓冲区"""
  54. # Default from bs_roformer.py
  55. DEFAULT_FREQS_PER_BANDS = (
  56. 2,
  57. 2,
  58. 2,
  59. 2,
  60. 2,
  61. 2,
  62. 2,
  63. 2,
  64. 2,
  65. 2,
  66. 2,
  67. 2,
  68. 2,
  69. 2,
  70. 2,
  71. 2,
  72. 2,
  73. 2,
  74. 2,
  75. 2,
  76. 2,
  77. 2,
  78. 2,
  79. 2,
  80. 4,
  81. 4,
  82. 4,
  83. 4,
  84. 4,
  85. 4,
  86. 4,
  87. 4,
  88. 4,
  89. 4,
  90. 4,
  91. 4,
  92. 12,
  93. 12,
  94. 12,
  95. 12,
  96. 12,
  97. 12,
  98. 12,
  99. 12,
  100. 24,
  101. 24,
  102. 24,
  103. 24,
  104. 24,
  105. 24,
  106. 24,
  107. 24,
  108. 48,
  109. 48,
  110. 48,
  111. 48,
  112. 48,
  113. 48,
  114. 48,
  115. 48,
  116. 128,
  117. 129,
  118. )
  119. freqs_per_bands = hparams.get("freqs_per_bands", DEFAULT_FREQS_PER_BANDS)
  120. stereo = hparams.get("stereo", False)
  121. audio_channels = 2 if stereo else 1
  122. # Validate
  123. stft_n_fft = hparams.get("stft_n_fft", 2048)
  124. expected_freqs = stft_n_fft // 2 + 1
  125. # Check sum
  126. sum_freqs = sum(freqs_per_bands)
  127. if sum_freqs != expected_freqs:
  128. print(
  129. f"[WARNING] sum(freqs_per_bands)={sum_freqs} != expected {expected_freqs}. Adjusting last band..."
  130. )
  131. # Note: In C++ logic relying on exact match might be strict, but let's warn for now.
  132. # Actually BS Roformer paper/code implies strict match for STFT reconstruction.
  133. num_bands = len(freqs_per_bands)
  134. freqs_per_bands_with_complex = tuple(
  135. 2 * f * audio_channels for f in freqs_per_bands
  136. )
  137. # num_freqs_per_band: i32 array
  138. num_freqs_per_band = np.array(freqs_per_bands, dtype=np.int32)
  139. # BS doesn't use freq_indices re-indexing, but to keep compatible file structure
  140. # we create dummy full-range indices.
  141. total_freqs_stereo = expected_freqs * audio_channels
  142. freq_indices = np.arange(total_freqs_stereo, dtype=np.int32)
  143. num_bands_per_freq = np.ones(expected_freqs, dtype=np.int32)
  144. print(f"Generated BS buffers: {num_bands} bands, {len(freq_indices)} indices")
  145. return {
  146. "freq_indices": freq_indices,
  147. "num_freqs_per_band": num_freqs_per_band,
  148. "num_bands_per_freq": num_bands_per_freq,
  149. "num_bands": num_bands,
  150. "freqs_per_bands_with_complex": freqs_per_bands_with_complex,
  151. "freqs_per_bands_tuple": freqs_per_bands, # Keep raw tuple for metadata
  152. }
  153. def generate_buffers(hparams, arch="mel_band_roformer"):
  154. """
  155. Generate buffers for the specified architecture.
  156. Args:
  157. hparams: Model hyperparameters
  158. arch: Architecture name ('bs_roformer' or 'mel_band_roformer')
  159. """
  160. if arch == "bs_roformer" or arch == "bs_roformer_v2":
  161. return generate_buffers_bs(hparams)
  162. # Mel-Band-Roformer Logic
  163. # ------------------------------------------------------------------------
  164. """
  165. Generate the buffers (freq_indices, num_bands_per_freq, etc.)
  166. mimicking the logic in MelBandRoformer.__init__.
  167. """
  168. num_bands = hparams["num_bands"]
  169. sample_rate = hparams.get("sample_rate", 44100)
  170. stft_n_fft = hparams.get("stft_n_fft", 2048)
  171. stereo = hparams.get("stereo", False)
  172. # 1. Calculate number of frequencies
  173. freqs = stft_n_fft // 2 + 1
  174. # 2. Create Mel Filter Bank
  175. mel_filter_bank_numpy = librosa.filters.mel(
  176. sr=sample_rate, n_fft=stft_n_fft, n_mels=num_bands
  177. )
  178. mel_filter_bank = torch.from_numpy(mel_filter_bank_numpy)
  179. # 3. Ensure edge values are positive (required for mask generation)
  180. # The exact value doesn't matter as long as it's > 0
  181. mel_filter_bank[0, 0] = max(mel_filter_bank[0, 0].item(), 1e-6)
  182. mel_filter_bank[-1, -1] = max(mel_filter_bank[-1, -1].item(), 1e-6)
  183. # 4. Create Masks
  184. freqs_per_band = mel_filter_bank > 0
  185. assert freqs_per_band.any(dim=0).all(), (
  186. "all frequencies need to be covered by all bands"
  187. )
  188. # 5. Generate Indices
  189. repeated_freq_indices = repeat(torch.arange(freqs), "f -> b f", b=num_bands)
  190. freq_indices = repeated_freq_indices[freqs_per_band]
  191. if stereo:
  192. freq_indices = repeat(freq_indices, "f -> f s", s=2)
  193. # s=0 -> 2*f, s=1 -> 2*f+1
  194. freq_indices = freq_indices * 2 + torch.arange(2)
  195. freq_indices = rearrange(freq_indices, "f s -> (f s)")
  196. # 6. Aggregate Counts
  197. num_freqs_per_band = reduce(freqs_per_band, "b f -> b", "sum")
  198. num_bands_per_freq = reduce(freqs_per_band, "b f -> f", "sum")
  199. return {
  200. "freq_indices": freq_indices,
  201. "num_freqs_per_band": num_freqs_per_band,
  202. "num_bands_per_freq": num_bands_per_freq,
  203. "freqs_per_band": freqs_per_band, # Kept if needed, though usually not saved
  204. }
  205. # ============================================================================
  206. # Quantization Helper
  207. # ============================================================================
  208. def get_target_quantization_type(dtype_str: str) -> GGMLQuantizationType:
  209. mapping = {
  210. "f32": GGMLQuantizationType.F32,
  211. "fp32": GGMLQuantizationType.F32,
  212. "f16": GGMLQuantizationType.F16,
  213. "fp16": GGMLQuantizationType.F16,
  214. "q8_0": GGMLQuantizationType.Q8_0,
  215. "q4_0": GGMLQuantizationType.Q4_0,
  216. "q4_1": GGMLQuantizationType.Q4_1,
  217. "q5_0": GGMLQuantizationType.Q5_0,
  218. "q5_1": GGMLQuantizationType.Q5_1,
  219. }
  220. return mapping.get(dtype_str.lower(), GGMLQuantizationType.F32)
  221. def get_file_type_id(qtype: GGMLQuantizationType) -> int:
  222. # See GGUF spec: https://github.com/ggerganov/ggml/blob/master/docs/gguf.md
  223. mapping = {
  224. GGMLQuantizationType.F32: 0,
  225. GGMLQuantizationType.F16: 1,
  226. GGMLQuantizationType.Q4_0: 2,
  227. GGMLQuantizationType.Q4_1: 3,
  228. # 4 is Q4_1_O (deprecated/legacy?)
  229. # 5 is Q4_0_O ?
  230. # 6 is Q4_1_O ?
  231. GGMLQuantizationType.Q8_0: 7,
  232. GGMLQuantizationType.Q5_0: 8,
  233. GGMLQuantizationType.Q5_1: 9,
  234. GGMLQuantizationType.Q2_K: 10,
  235. GGMLQuantizationType.Q3_K: 11,
  236. GGMLQuantizationType.Q4_K: 12,
  237. GGMLQuantizationType.Q5_K: 13,
  238. GGMLQuantizationType.Q6_K: 14,
  239. # IQ2_XXS etc might have IDs but let's stick to these for now
  240. }
  241. return mapping.get(qtype, 0) # Default to ALL_F32 if unknown
  242. def should_quantize(name: str) -> bool:
  243. """
  244. Determine if a tensor should be quantized.
  245. Keep norms and biases as FP32 to avoid CUDA alignment issues.
  246. """
  247. # Biases are always small and sensitive
  248. if "bias" in name:
  249. return False
  250. # Norm weights (gamma) must be F32 to avoid mixed-type mul issues in CUDA
  251. if "norm.weight" in name:
  252. return False
  253. # Quantize all other "weight" matrices (Linear, Conv, Embedding if any)
  254. if "weight" in name:
  255. return True
  256. return False
  257. # ============================================================================
  258. # Key Name Mapping
  259. # ============================================================================
  260. def map_key_name(key: str) -> str:
  261. """
  262. Map PyTorch state_dict keys to GGUF format (blk.{bid}.*).
  263. Standardizes suffixes: gamma -> weight, beta -> bias.
  264. """
  265. def standardize_suffix(param_name: str) -> str:
  266. if param_name == "gamma":
  267. return "weight"
  268. if param_name == "beta":
  269. return "bias"
  270. return param_name
  271. parts = key.split(".")
  272. suffix = standardize_suffix(parts[-1])
  273. # Transformer Layers
  274. if key.startswith("layers."):
  275. layer_idx = parts[1]
  276. tf_idx = parts[2] # 0=Time, 1=Freq
  277. type_str = "time" if tf_idx == "0" else "freq"
  278. # Final Norm: layers.0.0.norm.gamma
  279. if len(parts) >= 5 and parts[3] == "norm":
  280. return f"blk.{layer_idx}.{type_str}_norm.{suffix}"
  281. # Sub-layers (Attention=0, FF=1)
  282. if len(parts) >= 6 and parts[3] == "layers":
  283. block_sub_idx = parts[5]
  284. if block_sub_idx == "0": # Attention
  285. if len(parts) > 6:
  286. sub_name = parts[6]
  287. if sub_name == "norm":
  288. return f"blk.{layer_idx}.{type_str}_attn_norm.{suffix}"
  289. if sub_name == "to_qkv":
  290. return f"blk.{layer_idx}.{type_str}_attn_qkv.{suffix}"
  291. if sub_name == "to_out":
  292. return f"blk.{layer_idx}.{type_str}_attn_out.{suffix}"
  293. if sub_name == "to_gates":
  294. return f"blk.{layer_idx}.{type_str}_attn_gate.{suffix}"
  295. elif block_sub_idx == "1": # FeedForward
  296. if len(parts) >= 8 and parts[6] == "net":
  297. net_idx = parts[7]
  298. if net_idx == "0":
  299. return f"blk.{layer_idx}.{type_str}_ff_norm.{suffix}"
  300. if net_idx == "1":
  301. return f"blk.{layer_idx}.{type_str}_ff_in.{suffix}"
  302. if net_idx == "4":
  303. return f"blk.{layer_idx}.{type_str}_ff_out.{suffix}"
  304. # BandSplit
  305. if key.startswith("band_split.to_features"):
  306. band_idx = parts[2]
  307. layer_idx = parts[3] # 0=Norm, 1=Linear
  308. if layer_idx == "0":
  309. return f"band_split.{band_idx}.norm.{suffix}"
  310. if layer_idx == "1":
  311. return f"band_split.{band_idx}.linear.{suffix}"
  312. # Mask Estimator
  313. if key.startswith("mask_estimators"):
  314. est_idx = parts[1]
  315. freq_idx = parts[3]
  316. layer_idx = parts[5] # 0, 2, 4
  317. return f"mask_est.{est_idx}.freq.{freq_idx}.mlp.{layer_idx}.{suffix}"
  318. # Final Norm
  319. if key.startswith("final_norm"):
  320. return f"final_norm.{suffix}"
  321. return key.replace(".", "_")
  322. # ============================================================================
  323. # Main Conversion
  324. # ============================================================================
  325. def convert(
  326. ckpt_path: str,
  327. output_path: str,
  328. config_path: str,
  329. dtype: str = "fp32",
  330. name: str | None = None,
  331. description: str | None = None,
  332. arch: str | None = None,
  333. ):
  334. """
  335. Convert PyTorch checkpoint to GGUF format.
  336. """
  337. print(f"Loading checkpoint: {ckpt_path}")
  338. if ckpt_path.endswith(".safetensors"):
  339. state_dict = load_safetensors(ckpt_path)
  340. else:
  341. checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)
  342. if "state_dict" in checkpoint:
  343. state_dict = checkpoint["state_dict"]
  344. elif "model" in checkpoint:
  345. state_dict = checkpoint["model"]
  346. else:
  347. state_dict = checkpoint
  348. print(f"Loading config: {config_path}")
  349. with open(config_path) as f:
  350. config_dict = yaml.safe_load(f)
  351. # Detect architecture
  352. if arch is None:
  353. try:
  354. arch = detect_architecture(config_dict)
  355. print(f"Auto-detected architecture: {arch}")
  356. except ValueError as e:
  357. print(f"Error: {e}")
  358. return
  359. else:
  360. # Normalize provided arch to full name
  361. arch = normalize_arch(arch)
  362. # Generate buffers
  363. print("Generating buffers (standalone)...")
  364. buffers = generate_buffers(config_dict, arch=arch)
  365. freq_indices = buffers["freq_indices"]
  366. num_bands_per_freq = buffers["num_bands_per_freq"]
  367. num_freqs_per_band = buffers["num_freqs_per_band"]
  368. arch_name = arch
  369. # Create GGUF writer
  370. gguf_writer = gguf.GGUFWriter(output_path, arch_name)
  371. # =========================================================================
  372. # 1. Write Standard GGUF Metadata
  373. # =========================================================================
  374. print("Writing metadata...")
  375. # General metadata
  376. model_name = name if name else "BSRoformer Separator"
  377. model_description = description if description else "Music source separation model"
  378. gguf_writer.add_name(model_name)
  379. gguf_writer.add_description(model_description)
  380. # Determine types
  381. target_qtype = get_target_quantization_type(dtype)
  382. file_type_id = get_file_type_id(target_qtype)
  383. gguf_writer.add_file_type(file_type_id)
  384. # Write Architecture
  385. # gguf_writer.add_string(f"{arch_name}.architecture", arch) # Redundant with general.architecture
  386. if arch_name == "bs_roformer" and "freqs_per_bands_tuple" in buffers:
  387. freqs_tuple = buffers["freqs_per_bands_tuple"]
  388. # Must be list for GGUFWriter
  389. gguf_writer.add_array(f"{arch_name}.freqs_per_bands", list(freqs_tuple))
  390. if arch_name == "bs_roformer_v2":
  391. gguf_writer.add_array(f"{arch_name}.freqs_per_bands", list(config_dict["freqs_per_bands"]))
  392. gguf_writer.add_array(f"{arch_name}.freqs_per_bands_out", list(config_dict["freqs_per_bands_out"]))
  393. # Quantization version (required when quantized)
  394. if target_qtype != GGMLQuantizationType.F32:
  395. gguf_writer.add_quantization_version(2)
  396. # Calculate parameter count
  397. total_params = 0
  398. for key, tensor in state_dict.items():
  399. if "freq_indices" in key or "num_bands" in key:
  400. continue
  401. total_params += tensor.numel()
  402. print(f"Total parameters: {total_params}")
  403. gguf_writer.add_uint64("general.parameter_count", total_params)
  404. # =========================================================================
  405. # 2. Write Hyperparameters
  406. # =========================================================================
  407. print("Writing hyperparameters...")
  408. hparams = config_dict
  409. # Load state dict directly (no model class dependency)
  410. print(f"Loading checkpoint for architecture: {arch}")
  411. raw_state_dict = state_dict
  412. if raw_state_dict is None:
  413. raise ValueError("Could not find state_dict in checkpoint")
  414. # Clean up state dict (handle DDP "module." prefix)
  415. state_dict = {}
  416. for k, v in raw_state_dict.items():
  417. if k.startswith("module."):
  418. k = k[7:]
  419. state_dict[k] = v
  420. # Architecture specific parameters
  421. gguf_writer.add_uint32(f"{arch_name}.dim", hparams["hidden_size"])
  422. gguf_writer.add_uint32(f"{arch_name}.depth", hparams["num_hidden_layers"])
  423. # BS uses freqs_per_bands (no explicit num_bands), MelBand uses num_bands
  424. num_bands = buffers.get("num_bands", len(hparams.get("freqs_per_bands", [])))
  425. gguf_writer.add_uint32(f"{arch_name}.num_bands", num_bands)
  426. # STFT parameters
  427. gguf_writer.add_uint32(f"{arch_name}.stft_n_fft", hparams.get("stft_n_fft", 2048))
  428. # Remove default for hop_length, must be present or fail/warn
  429. gguf_writer.add_uint32(
  430. f"{arch_name}.stft_hop_length", hparams.get("stft_hop_length", 441)
  431. )
  432. gguf_writer.add_uint32(
  433. f"{arch_name}.stft_win_length", hparams.get("stft_win_length", 2048)
  434. )
  435. gguf_writer.add_bool(
  436. f"{arch_name}.stft_normalized", hparams.get("stft_normalized", False)
  437. )
  438. gguf_writer.add_bool(
  439. f"{arch_name}.zero_dc", hparams.get("zero_dc", True) # Defaults to True in reference implementation
  440. )
  441. # Architecture details
  442. gguf_writer.add_uint32(f"{arch_name}.num_stems", hparams.get("num_stems", 1))
  443. gguf_writer.add_bool(f"{arch_name}.stereo", hparams.get("stereo", False))
  444. gguf_writer.add_uint32(
  445. f"{arch_name}.sample_rate", hparams.get("wave_sample_rate", 44100)
  446. )
  447. if arch_name == "bs_roformer_v2":
  448. gguf_writer.add_uint32(
  449. f"{arch_name}.time_transformer_depth",
  450. hparams.get("time_transformer_depth", 1),
  451. )
  452. gguf_writer.add_uint32(
  453. f"{arch_name}.freq_transformer_depth",
  454. hparams.get("freq_transformer_depth", 1),
  455. )
  456. gguf_writer.add_uint32(
  457. f"{arch_name}.num_key_value_heads", hparams.get("num_key_value_heads", 4)
  458. )
  459. gguf_writer.add_uint32(
  460. f"{arch_name}.intermediate_size", hparams.get("intermediate_size", 1152)
  461. )
  462. gguf_writer.add_uint32(
  463. f"{arch_name}.num_input_channels", hparams.get("num_input_channels", 2)
  464. )
  465. gguf_writer.add_uint32(
  466. f"{arch_name}.band_proj_size", hparams.get("band_proj_size", 256)
  467. )
  468. gguf_writer.add_uint32(
  469. f"{arch_name}.register_token_num", hparams.get("register_token_num", 4)
  470. )
  471. else:
  472. gguf_writer.add_uint32(
  473. f"{arch_name}.time_transformer_depth",
  474. hparams.get("time_transformer_depth", 0),
  475. )
  476. gguf_writer.add_uint32(
  477. f"{arch_name}.freq_transformer_depth",
  478. hparams.get("freq_transformer_depth", 0),
  479. )
  480. gguf_writer.add_uint32(
  481. f"{arch_name}.linear_transformer_depth",
  482. hparams.get("linear_transformer_depth", 0),
  483. )
  484. gguf_writer.add_uint32(
  485. f"{arch_name}.mask_estimator_depth", hparams.get("mask_estimator_depth", 1)
  486. )
  487. gguf_writer.add_uint32(f"{arch_name}.dim_head", hparams.get("head_dim", 64))
  488. gguf_writer.add_uint32(f"{arch_name}.heads", hparams.get("num_attention_heads", 8))
  489. gguf_writer.add_uint32(
  490. f"{arch_name}.mlp_expansion_factor", hparams.get("mlp_expansion_factor", 4)
  491. )
  492. gguf_writer.add_bool(
  493. f"{arch_name}.skip_connection", hparams.get("skip_connection", False)
  494. )
  495. # =========================================================================
  496. # 3. Write Inference Defaults (Optional, can be overridden at runtime)
  497. # =========================================================================
  498. print("Writing inference defaults...")
  499. inference_config = config_dict.get("inference", {})
  500. audio_config = config_dict.get("audio", {})
  501. # chunk_size: prefer inference.chunk_size, fallback to audio.chunk_size
  502. default_chunk_size = hparams.get(
  503. "wave_chunk_size", 352800
  504. )
  505. # num_overlap: from inference section
  506. default_num_overlap = inference_config.get("num_overlap", 2)
  507. gguf_writer.add_uint32(f"{arch_name}.default_chunk_size", default_chunk_size)
  508. gguf_writer.add_uint32(f"{arch_name}.default_num_overlap", default_num_overlap)
  509. # =========================================================================
  510. # 4. Write Buffers (Always FP32/I32)
  511. # =========================================================================
  512. print("Writing buffers...")
  513. # freq_indices (int32) - may be torch.Tensor (MelBand) or np.ndarray (BS)
  514. fi = freq_indices.numpy() if hasattr(freq_indices, "numpy") else freq_indices
  515. gguf_writer.add_tensor("buffer_freq_indices", fi.astype(np.int32))
  516. # num_bands_per_freq (int32)
  517. nbpf = (
  518. num_bands_per_freq.numpy()
  519. if hasattr(num_bands_per_freq, "numpy")
  520. else num_bands_per_freq
  521. )
  522. gguf_writer.add_tensor("buffer_num_bands_per_freq", nbpf.astype(np.int32))
  523. # num_freqs_per_band (int32)
  524. nfpb = (
  525. num_freqs_per_band.numpy()
  526. if hasattr(num_freqs_per_band, "numpy")
  527. else num_freqs_per_band
  528. )
  529. gguf_writer.add_tensor("buffer_num_freqs_per_band", nfpb.astype(np.int32))
  530. # =========================================================================
  531. # 5. Write Weights (Mixed Quantization)
  532. # =========================================================================
  533. print(f"Writing weights ({dtype} -> {target_qtype.name})...")
  534. print("Strategy: Quantize weights, Keep Norm/Bias as F32")
  535. n_tensors = 0
  536. n_quantized = 0
  537. warnings_list = []
  538. for key, tensor in state_dict.items():
  539. new_key = map_key_name(key)
  540. # Skip buffers
  541. if (
  542. "freq_indices" in key
  543. or "num_bands_per_freq" in key
  544. or "num_freqs_per_band" in key
  545. ):
  546. continue
  547. data = tensor.numpy().astype(np.float32)
  548. # Decide whether to quantize
  549. is_quantized = False
  550. if target_qtype != GGMLQuantizationType.F32 and should_quantize(new_key):
  551. try:
  552. # Use gguf-py built-in quantization
  553. quantized_data = quantize(data, target_qtype)
  554. # Pass raw_dtype so GGUFWriter knows how to treat the byte array (for Q types)
  555. # or float array (for F16)
  556. gguf_writer.add_tensor(new_key, quantized_data, raw_dtype=target_qtype)
  557. is_quantized = True
  558. n_quantized += 1
  559. except Exception as e:
  560. msg = f"Failed to quantize {new_key} to {target_qtype.name}, falling back to F32. Error: {e}"
  561. warnings_list.append(msg)
  562. gguf_writer.add_tensor(new_key, data)
  563. else:
  564. # Keep as F32
  565. gguf_writer.add_tensor(new_key, data)
  566. status = target_qtype.name if is_quantized else "F32"
  567. print(f" {new_key:<50} | {str(data.shape):<20} | {status}")
  568. n_tensors += 1
  569. # =========================================================================
  570. # 6. Write File
  571. # =========================================================================
  572. print(f"\nWriting GGUF to {output_path}")
  573. gguf_writer.write_header_to_file()
  574. gguf_writer.write_kv_data_to_file()
  575. gguf_writer.write_tensors_to_file()
  576. gguf_writer.close()
  577. if warnings_list:
  578. print("\n" + "=" * 80)
  579. print(
  580. f"WARNING: {len(warnings_list)} tensors failed to quantize (fallback to F32):"
  581. )
  582. for msg in warnings_list:
  583. print(f" - {msg}")
  584. print("=" * 80 + "\n")
  585. file_size = os.path.getsize(output_path)
  586. print(f"\nDone! Converted {n_tensors} tensors ({n_quantized} quantized)")
  587. print(f"Output file size: {file_size / 1024 / 1024:.2f} MB")
  588. if __name__ == "__main__":
  589. parser = argparse.ArgumentParser(
  590. description="Convert Mel-Band-Roformer checkpoint to GGUF format with Mixed Quantization",
  591. formatter_class=argparse.RawDescriptionHelpFormatter,
  592. epilog="""
  593. Examples:
  594. python convert_to_gguf.py --ckpt model.ckpt --config config.yaml --out model_f16.gguf --dtype fp16
  595. python convert_to_gguf.py --ckpt model.ckpt --config config.yaml --out model_q8.gguf --dtype q8_0
  596. """,
  597. )
  598. parser.add_argument(
  599. "--ckpt", type=str, required=True, help="Path to PyTorch checkpoint"
  600. )
  601. parser.add_argument("--config", type=str, required=True, help="Path to YAML or JSON config")
  602. parser.add_argument("--out", type=str, required=True, help="Output GGUF file path")
  603. parser.add_argument(
  604. "--dtype",
  605. type=str,
  606. default="fp32",
  607. choices=[
  608. "fp32",
  609. "f32",
  610. "fp16",
  611. "f16",
  612. "q8_0",
  613. "q4_0",
  614. "q4_1",
  615. "q5_0",
  616. "q5_1",
  617. ],
  618. help="Target quantization type. Norms/Biases will be kept as F32. (K-Quants not supported due to dim=384)",
  619. )
  620. parser.add_argument(
  621. "--name",
  622. type=str,
  623. default=None,
  624. help="Model name (default: 'BSRoformer Vocal Separator')",
  625. )
  626. parser.add_argument(
  627. "--description",
  628. type=str,
  629. default=None,
  630. help="Model description (default: 'Audio source separation model for vocal extraction')",
  631. )
  632. parser.add_argument(
  633. "--arch",
  634. choices=["mel_band", "mel_band_roformer", "bs", "bs_roformer", "bs_roformer_v2"],
  635. default=None,
  636. help="Architecture type (auto-detected if not specified)",
  637. )
  638. args = parser.parse_args()
  639. convert(
  640. args.ckpt,
  641. args.out,
  642. args.config,
  643. args.dtype,
  644. args.name,
  645. args.description,
  646. args.arch,
  647. )