| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469 |
- #!/usr/bin/env python3
- """
- Convert Mel-Band-Roformer PyTorch checkpoint to GGUF format.
- Supports quantization: FP32, FP16, Q8_0, Q4_0, Q4_1, Q5_0, Q5_1
- Mixed Quantization: Keeps Norms/Biases as FP32 to avoid CUDA alignment issues.
- """
- import os
- import argparse
- import torch
- import numpy as np
- import yaml
- import librosa
- from einops import repeat, reduce, rearrange
- import gguf
- from gguf.quants import quantize, GGMLQuantizationType
def generate_buffers(hparams):
    """
    Generate the band-split buffers, mimicking the logic in MelBandRoformer.__init__.

    Args:
        hparams: The config's "model" section. Reads "num_bands" (required) and the
            optional "sample_rate" (default 44100), "stft_n_fft" (default 2048) and
            "stereo" (default False).

    Returns:
        dict of torch tensors:
            "freq_indices": STFT bin index of every (band, bin) membership,
                concatenated band by band; with stereo, each bin f becomes 2*f, 2*f+1.
            "num_freqs_per_band": number of bins in each band.
            "num_bands_per_freq": number of bands covering each bin.
            "freqs_per_band": boolean (num_bands, stft_n_fft // 2 + 1) membership mask.

    Raises:
        KeyError: if "num_bands" is missing from hparams.
        ValueError: if some STFT bin is not covered by any mel band.
    """
    num_bands = hparams["num_bands"]
    sample_rate = hparams.get("sample_rate", 44100)
    stft_n_fft = hparams.get("stft_n_fft", 2048)
    stereo = hparams.get("stereo", False)

    # 1. Number of one-sided STFT frequency bins.
    freqs = stft_n_fft // 2 + 1

    # 2. Mel filter bank with one row per band.
    mel_filter_bank = torch.from_numpy(
        librosa.filters.mel(sr=sample_rate, n_fft=stft_n_fft, n_mels=num_bands)
    )

    # 3. Ensure edge values are positive (required for mask generation).
    #    The exact value doesn't matter as long as it's > 0.
    mel_filter_bank[0, 0] = max(mel_filter_bank[0, 0].item(), 1e-6)
    mel_filter_bank[-1, -1] = max(mel_filter_bank[-1, -1].item(), 1e-6)

    # 4. Binary band-membership mask, shape (num_bands, freqs).
    freqs_per_band = mel_filter_bank > 0
    # Raise explicitly: an `assert` is stripped under `python -O`, which would let a
    # config that leaves bins uncovered slip through.
    if not freqs_per_band.any(dim=0).all():
        raise ValueError("all frequencies need to be covered by all bands")

    # 5. Bin indices of every band, concatenated band by band (row-major mask order).
    repeated_freq_indices = repeat(torch.arange(freqs), "f -> b f", b=num_bands)
    freq_indices = repeated_freq_indices[freqs_per_band]
    if stereo:
        # Interleave channels: s=0 -> 2*f, s=1 -> 2*f+1
        freq_indices = repeat(freq_indices, "f -> f s", s=2)
        freq_indices = freq_indices * 2 + torch.arange(2)
        freq_indices = rearrange(freq_indices, "f s -> (f s)")

    # 6. Aggregate counts.
    num_freqs_per_band = reduce(freqs_per_band, "b f -> b", "sum")
    num_bands_per_freq = reduce(freqs_per_band, "b f -> f", "sum")

    return {
        "freq_indices": freq_indices,
        "num_freqs_per_band": num_freqs_per_band,
        "num_bands_per_freq": num_bands_per_freq,
        "freqs_per_band": freqs_per_band,  # Kept if needed, though usually not saved
    }
- # ============================================================================
- # Quantization Helper
- # ============================================================================
def get_target_quantization_type(dtype_str: str) -> GGMLQuantizationType:
    """
    Translate a user-facing dtype name into a GGML quantization type.

    Matching is case-insensitive, and "fp32"/"fp16" are accepted as aliases of
    "f32"/"f16". Any unrecognized name yields F32.
    """
    canonical = dtype_str.lower()
    canonical = {"fp32": "f32", "fp16": "f16"}.get(canonical, canonical)
    by_name = {
        "f32": GGMLQuantizationType.F32,
        "f16": GGMLQuantizationType.F16,
        "q8_0": GGMLQuantizationType.Q8_0,
        "q4_0": GGMLQuantizationType.Q4_0,
        "q4_1": GGMLQuantizationType.Q4_1,
        "q5_0": GGMLQuantizationType.Q5_0,
        "q5_1": GGMLQuantizationType.Q5_1,
    }
    if canonical in by_name:
        return by_name[canonical]
    return GGMLQuantizationType.F32
def get_file_type_id(qtype: GGMLQuantizationType) -> int:
    """
    Return the GGUF ``general.file_type`` ID for a file whose weights use *qtype*.

    IDs follow the ``general.file_type`` enumeration of the GGUF spec
    (https://github.com/ggerganov/ggml/blob/master/docs/gguf.md).

    This converter applies one quantization type to every quantizable weight. K-quant
    types are therefore mapped to the plain / "_S" file type; the "_M"/"_L" IDs
    describe llama.cpp's mixed-type recipes. Unknown types fall back to 0 (ALL_F32).
    """
    mapping = {
        GGMLQuantizationType.F32: 0,  # ALL_F32
        GGMLQuantizationType.F16: 1,  # MOSTLY_F16
        GGMLQuantizationType.Q4_0: 2,  # MOSTLY_Q4_0
        GGMLQuantizationType.Q4_1: 3,  # MOSTLY_Q4_1
        # 4 = MOSTLY_Q4_1_SOME_F16; 5 = MOSTLY_Q4_2 and 6 = MOSTLY_Q4_3 (support removed)
        GGMLQuantizationType.Q8_0: 7,  # MOSTLY_Q8_0
        GGMLQuantizationType.Q5_0: 8,  # MOSTLY_Q5_0
        GGMLQuantizationType.Q5_1: 9,  # MOSTLY_Q5_1
        GGMLQuantizationType.Q2_K: 10,  # MOSTLY_Q2_K
        GGMLQuantizationType.Q3_K: 11,  # MOSTLY_Q3_K_S (12/13 are Q3_K_M/Q3_K_L)
        GGMLQuantizationType.Q4_K: 14,  # MOSTLY_Q4_K_S (15 is Q4_K_M)
        GGMLQuantizationType.Q5_K: 16,  # MOSTLY_Q5_K_S (17 is Q5_K_M)
        GGMLQuantizationType.Q6_K: 18,  # MOSTLY_Q6_K
    }
    return mapping.get(qtype, 0)  # Default to ALL_F32 if unknown
def should_quantize(name: str) -> bool:
    """
    Decide whether the (already mapped) GGUF tensor *name* may be quantized.

    Names containing "bias" or "norm.weight" are always kept as F32. Biases are small
    and sensitive, and norm gains must stay F32 to avoid mixed-type mul / CUDA
    alignment issues. Any other name containing "weight" is quantized; everything
    else is not.
    """
    keep_f32 = ("bias" in name) or ("norm.weight" in name)
    return not keep_f32 and "weight" in name
- # ============================================================================
- # Key Name Mapping
- # ============================================================================
def map_key_name(key: str) -> str:
    """
    Map a PyTorch state_dict key to its GGUF tensor name (blk.{bid}.*).

    Recognized layouts (dot-separated):

    * layers.{i}.{t}.norm.{p}                   -> blk.{i}.{time|freq}_norm.{p}
    * layers.{i}.{t}.layers.0.0.{sub}[...].{p}  -> blk.{i}.{time|freq}_attn_{norm|qkv|out|gate}.{p}
    * layers.{i}.{t}.layers.0.1.net.{n}.{p}     -> blk.{i}.{time|freq}_ff_{norm|in|out}.{p}
    * band_split.to_features.{b}.{0|1}.{p}      -> band_split.{b}.{norm|linear}.{p}
    * mask_estimators.{e}.?.{f}.?.{l}.{p}       -> mask_est.{e}.freq.{f}.mlp.{l}.{p}

    Here t is 0 (time) or 1 (freq). Any other key is returned with "." replaced by "_".
    The trailing parameter name is standardized: gamma -> weight, beta -> bias.

    Raises:
        ValueError: for a transformer key the flat blk.{i}.* naming cannot represent,
            i.e. a transformer index other than 0/1 or a sub-layer index other than 0.
            Mapping these would make distinct tensors collide on one GGUF name.
    """
    parts = key.split(".")
    suffix = {"gamma": "weight", "beta": "bias"}.get(parts[-1], parts[-1])

    # Transformer layers: layers.{i}.{t}...
    if key.startswith("layers.") and len(parts) >= 3:
        layer_idx = parts[1]
        tf_idx = parts[2]  # 0=Time, 1=Freq
        if tf_idx not in ("0", "1"):
            raise ValueError(
                f"Unsupported transformer index {tf_idx!r} in {key!r}: only the time (0) "
                "and freq (1) transformers of each layer can be mapped"
            )
        type_str = "time" if tf_idx == "0" else "freq"

        # Final Norm: layers.0.0.norm.gamma
        if len(parts) >= 5 and parts[3] == "norm":
            return f"blk.{layer_idx}.{type_str}_norm.{suffix}"

        # Sub-layers: layers.{i}.{t}.layers.{j}.{0=Attention|1=FeedForward}...
        if len(parts) >= 6 and parts[3] == "layers":
            if parts[4] != "0":
                # blk.{i}.{time|freq}_* holds exactly one attention/FF pair, so a
                # second sub-layer would silently overwrite (or collide with) the first.
                raise ValueError(
                    f"Unsupported transformer sub-layer {parts[4]!r} in {key!r}: "
                    "only one attention/feed-forward pair per transformer can be mapped"
                )
            block_sub_idx = parts[5]
            if block_sub_idx == "0" and len(parts) > 6:  # Attention
                attn_names = {
                    "norm": "attn_norm",
                    "to_qkv": "attn_qkv",
                    "to_out": "attn_out",
                    "to_gates": "attn_gate",
                }
                mapped = attn_names.get(parts[6])
                if mapped is not None:
                    return f"blk.{layer_idx}.{type_str}_{mapped}.{suffix}"
            elif block_sub_idx == "1" and len(parts) >= 8 and parts[6] == "net":  # FeedForward
                ff_names = {"0": "ff_norm", "1": "ff_in", "4": "ff_out"}
                mapped = ff_names.get(parts[7])
                if mapped is not None:
                    return f"blk.{layer_idx}.{type_str}_{mapped}.{suffix}"

    # BandSplit: band_split.to_features.{b}.{0=Norm|1=Linear}.{param}
    if key.startswith("band_split.to_features") and len(parts) >= 5:
        band_idx = parts[2]
        sub_idx = parts[3]
        if sub_idx == "0":
            return f"band_split.{band_idx}.norm.{suffix}"
        if sub_idx == "1":
            return f"band_split.{band_idx}.linear.{suffix}"

    # Mask Estimator: parts[5] is the MLP layer index (0, 2, 4)
    if key.startswith("mask_estimators") and len(parts) >= 7:
        return f"mask_est.{parts[1]}.freq.{parts[3]}.mlp.{parts[5]}.{suffix}"

    return key.replace(".", "_")
- # ============================================================================
- # Main Conversion
- # ============================================================================
def _load_state_dict(ckpt_path: str):
    """Load *ckpt_path* on CPU and return its state_dict.

    A top-level "state_dict" entry is preferred, then "model"; otherwise the loaded
    object itself is returned.
    """
    # NOTE(security): torch.load unpickles the file; only convert checkpoints you trust.
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    if "state_dict" in checkpoint:
        return checkpoint["state_dict"]
    if "model" in checkpoint:
        return checkpoint["model"]
    return checkpoint


def _is_buffer_key(key: str) -> bool:
    """Return True for state_dict entries that generate_buffers() recomputes.

    Used both for the parameter count and for the weight writer, so the two agree.
    """
    return any(
        marker in key
        for marker in ("freq_indices", "num_bands_per_freq", "num_freqs_per_band")
    )


def _to_float32_numpy(tensor) -> np.ndarray:
    """Return *tensor* as a float32 NumPy array.

    The cast happens in torch first because Tensor.numpy() rejects bfloat16 tensors
    and tensors that require grad.
    """
    return tensor.detach().cpu().to(torch.float32).numpy()


def _write_hparams(gguf_writer: "gguf.GGUFWriter", hparams: dict) -> None:
    """Write the mel_band_roformer.* architecture hyperparameters.

    "dim", "depth" and "num_bands" are required (KeyError if missing); all other keys
    fall back to the defaults below.
    """
    gguf_writer.add_uint32("mel_band_roformer.dim", hparams["dim"])
    gguf_writer.add_uint32("mel_band_roformer.depth", hparams["depth"])
    gguf_writer.add_uint32("mel_band_roformer.num_bands", hparams["num_bands"])

    # STFT parameters
    gguf_writer.add_uint32(
        "mel_band_roformer.stft_n_fft", hparams.get("stft_n_fft", 2048)
    )
    if "stft_hop_length" not in hparams:
        # The hop length sets the STFT framing used at inference. A guessed value
        # yields a silently wrong model, so make the fallback visible.
        print(
            "Warning: model.stft_hop_length missing from config; falling back to 441. "
            "Verify this matches the trained model."
        )
    gguf_writer.add_uint32(
        "mel_band_roformer.stft_hop_length", hparams.get("stft_hop_length", 441)
    )
    gguf_writer.add_uint32(
        "mel_band_roformer.stft_win_length", hparams.get("stft_win_length", 2048)
    )
    gguf_writer.add_bool(
        "mel_band_roformer.stft_normalized", hparams.get("stft_normalized", False)
    )
    # Defaults to True in reference implementation
    gguf_writer.add_bool("mel_band_roformer.zero_dc", hparams.get("zero_dc", True))

    # Architecture details
    gguf_writer.add_uint32("mel_band_roformer.num_stems", hparams.get("num_stems", 1))
    gguf_writer.add_bool("mel_band_roformer.stereo", hparams.get("stereo", False))
    gguf_writer.add_uint32(
        "mel_band_roformer.sample_rate", hparams.get("sample_rate", 44100)
    )
    # NOTE(review): the 0 defaults below may not match the model class's own defaults;
    # confirm that configs always set these keys explicitly.
    gguf_writer.add_uint32(
        "mel_band_roformer.time_transformer_depth",
        hparams.get("time_transformer_depth", 0),
    )
    gguf_writer.add_uint32(
        "mel_band_roformer.freq_transformer_depth",
        hparams.get("freq_transformer_depth", 0),
    )
    gguf_writer.add_uint32(
        "mel_band_roformer.linear_transformer_depth",
        hparams.get("linear_transformer_depth", 0),
    )
    gguf_writer.add_uint32(
        "mel_band_roformer.mask_estimator_depth", hparams.get("mask_estimator_depth", 1)
    )
    gguf_writer.add_uint32("mel_band_roformer.dim_head", hparams.get("dim_head", 64))
    gguf_writer.add_uint32("mel_band_roformer.heads", hparams.get("heads", 8))
    gguf_writer.add_uint32(
        "mel_band_roformer.mlp_expansion_factor", hparams.get("mlp_expansion_factor", 4)
    )
    gguf_writer.add_bool(
        "mel_band_roformer.skip_connection", hparams.get("skip_connection", False)
    )


def _add_weight(
    gguf_writer: "gguf.GGUFWriter",
    name: str,
    data: np.ndarray,
    target_qtype: "GGMLQuantizationType",
) -> bool:
    """Add one weight tensor, quantized to *target_qtype* when should_quantize() allows.

    Returns True if the tensor was stored quantized. Returns False if it was stored as
    F32: the target is F32, the name is excluded, or quantization failed.
    """
    if target_qtype == GGMLQuantizationType.F32 or not should_quantize(name):
        gguf_writer.add_tensor(name, data)
        return False
    try:
        quantized_data = quantize(data, target_qtype)
    except Exception as e:  # broad on purpose: best-effort fallback to F32
        print(
            f"Warning: Failed to quantize {name} to {target_qtype.name}, "
            f"falling back to F32. Error: {e}"
        )
        gguf_writer.add_tensor(name, data)
        return False
    # raw_dtype tells GGUFWriter how to treat the array: a byte array for Q types,
    # or a float array for F16.
    gguf_writer.add_tensor(name, quantized_data, raw_dtype=target_qtype)
    return True


def convert(
    ckpt_path: str,
    output_path: str,
    config_path: str,
    dtype: str = "fp32",
):
    """
    Convert a Mel-Band-Roformer PyTorch checkpoint to a GGUF file.

    Args:
        ckpt_path: PyTorch checkpoint. This is a raw state_dict, or a dict wrapping
            one under "state_dict" or "model".
        output_path: Destination GGUF file.
        config_path: YAML config. Its "model" section holds the hyperparameters; the
            optional "inference"/"audio" sections supply chunking defaults.
        dtype: Target type name understood by get_target_quantization_type().
            Only weights accepted by should_quantize() use it. Norms and biases stay
            F32, and the generated buffers are written as int32.

    Raises:
        KeyError: if required config keys ("model", "dim", "depth", "num_bands") are missing.
    """
    print(f"Loading checkpoint: {ckpt_path}")
    state_dict = _load_state_dict(ckpt_path)

    print(f"Loading config: {config_path}")
    # FullLoader (rather than SafeLoader) is presumably kept so configs using
    # Python-specific YAML tags still load; only convert configs you trust.
    with open(config_path, encoding="utf-8") as f:
        config_dict = yaml.load(f, Loader=yaml.FullLoader)
    hparams = config_dict["model"]

    print("Generating buffers (standalone)...")
    buffers = generate_buffers(hparams)

    gguf_writer = gguf.GGUFWriter(output_path, "mel_band_roformer")

    # 1. Standard GGUF metadata
    print("Writing metadata...")
    gguf_writer.add_name("Mel-Band-Roformer Vocal Separator")
    gguf_writer.add_description("Audio source separation model for vocal extraction")
    target_qtype = get_target_quantization_type(dtype)
    gguf_writer.add_file_type(get_file_type_id(target_qtype))
    if target_qtype != GGMLQuantizationType.F32:
        # Quantization version (required when quantized)
        gguf_writer.add_quantization_version(2)
    total_params = sum(
        tensor.numel()
        for key, tensor in state_dict.items()
        if not _is_buffer_key(key)
    )
    print(f"Total parameters: {total_params}")
    gguf_writer.add_uint64("general.parameter_count", total_params)

    # 2. Hyperparameters
    print("Writing hyperparameters...")
    _write_hparams(gguf_writer, hparams)

    # 3. Inference defaults (optional, can be overridden at runtime)
    print("Writing inference defaults...")
    inference_config = config_dict.get("inference", {})
    audio_config = config_dict.get("audio", {})
    # chunk_size: prefer inference.chunk_size, fall back to audio.chunk_size
    default_chunk_size = inference_config.get(
        "chunk_size", audio_config.get("chunk_size", 352800)
    )
    default_num_overlap = inference_config.get("num_overlap", 0)
    gguf_writer.add_uint32("mel_band_roformer.default_chunk_size", default_chunk_size)
    gguf_writer.add_uint32("mel_band_roformer.default_num_overlap", default_num_overlap)

    # 4. Buffers (always int32)
    print("Writing buffers...")
    for buffer_name in ("freq_indices", "num_bands_per_freq", "num_freqs_per_band"):
        gguf_writer.add_tensor(
            f"buffer_{buffer_name}", buffers[buffer_name].numpy().astype(np.int32)
        )

    # 5. Weights (mixed quantization)
    print(f"Writing weights ({dtype} -> {target_qtype.name})...")
    print("Strategy: Quantize weights, Keep Norm/Bias as F32")
    n_tensors = 0
    n_quantized = 0
    for key, tensor in state_dict.items():
        # Buffers were regenerated above; skip any stale copies in the checkpoint.
        if _is_buffer_key(key):
            continue
        new_key = map_key_name(key)
        data = _to_float32_numpy(tensor)
        is_quantized = _add_weight(gguf_writer, new_key, data, target_qtype)
        if is_quantized:
            n_quantized += 1
        status = target_qtype.name if is_quantized else "F32"
        print(f" {new_key:<50} | {str(data.shape):<20} | {status}")
        n_tensors += 1

    # 6. Write file; always release the writer's file handle, even on failure.
    print(f"\nWriting GGUF to {output_path}")
    try:
        gguf_writer.write_header_to_file()
        gguf_writer.write_kv_data_to_file()
        gguf_writer.write_tensors_to_file()
    finally:
        gguf_writer.close()

    file_size = os.path.getsize(output_path)
    print(f"\nDone! Converted {n_tensors} tensors ({n_quantized} quantized)")
    print(f"Output file size: {file_size / 1024 / 1024:.2f} MB")
if __name__ == "__main__":
    # Command-line entry point: parse arguments, then run convert().
    usage_examples = """
Examples:
python convert_to_gguf.py --ckpt model.ckpt --config config.yaml --out model_f16.gguf --dtype fp16
python convert_to_gguf.py --ckpt model.ckpt --config config.yaml --out model_q8.gguf --dtype q8_0
"""
    supported_dtypes = [
        "fp32", "f32", "fp16", "f16",
        "q8_0", "q4_0", "q4_1", "q5_0", "q5_1",
    ]

    cli = argparse.ArgumentParser(
        description="Convert Mel-Band-Roformer checkpoint to GGUF format with Mixed Quantization",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=usage_examples,
    )
    cli.add_argument("--ckpt", type=str, required=True, help="Path to PyTorch checkpoint")
    cli.add_argument("--config", type=str, required=True, help="Path to YAML config")
    cli.add_argument("--out", type=str, required=True, help="Output GGUF file path")
    cli.add_argument(
        "--dtype",
        type=str,
        default="fp32",
        choices=supported_dtypes,
        help="Target quantization type. Norms/Biases will be kept as F32. (K-Quants not supported due to dim=384)",
    )

    opts = cli.parse_args()
    convert(opts.ckpt, opts.out, opts.config, opts.dtype)
|