#!/usr/bin/env python3
"""
Convert Mel-Band-Roformer / BS-Roformer PyTorch checkpoints to GGUF format.

Supports quantization: FP32, FP16, Q8_0, Q4_0, Q4_1, Q5_0, Q5_1
Mixed Quantization: Keeps Norms/Biases as FP32 to avoid CUDA alignment issues.
"""

import argparse
import os

import gguf
import librosa
import numpy as np
import torch
import yaml
from einops import rearrange, reduce, repeat
from gguf.quants import GGMLQuantizationType, quantize

# Accepted spellings of the architecture name -> canonical GGUF architecture.
_ARCH_ALIASES = {
    "bs": "bs_roformer",
    "bs_roformer": "bs_roformer",
    "mel": "mel_band_roformer",
    "mel_band": "mel_band_roformer",
    "mel_band_roformer": "mel_band_roformer",
}

# Default band layout, taken from bs_roformer.py (62 bands, 1025 bins in total).
_DEFAULT_FREQS_PER_BANDS = (
    2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
    2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
    4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
    12, 12, 12, 12, 12, 12, 12, 12,
    24, 24, 24, 24, 24, 24, 24, 24,
    48, 48, 48, 48, 48, 48, 48, 48,
    128, 129,
)

# Command-line dtype spellings -> GGML tensor type.
_QTYPE_BY_NAME = {
    "f32": GGMLQuantizationType.F32,
    "fp32": GGMLQuantizationType.F32,
    "f16": GGMLQuantizationType.F16,
    "fp16": GGMLQuantizationType.F16,
    "q8_0": GGMLQuantizationType.Q8_0,
    "q4_0": GGMLQuantizationType.Q4_0,
    "q4_1": GGMLQuantizationType.Q4_1,
    "q5_0": GGMLQuantizationType.Q5_0,
    "q5_1": GGMLQuantizationType.Q5_1,
}

# State-dict entries that are index buffers, not learned weights. They are
# regenerated from the config and written separately, so they must be skipped
# both when counting parameters and when writing weights.
_BUFFER_KEY_MARKERS = ("freq_indices", "num_bands_per_freq", "num_freqs_per_band")


def detect_architecture(config_dict):
    """
    Detect the architecture from the config's ``model`` section.

    Returns:
        'bs_roformer' when ``freqs_per_bands`` is present, otherwise
        'mel_band_roformer' when ``num_bands`` is present.

    Raises:
        ValueError: if neither key is present.
    """
    # ``or {}`` also covers a config where ``model:`` is present but empty.
    model_config = config_dict.get("model") or {}

    if "freqs_per_bands" in model_config:
        return "bs_roformer"
    if "num_bands" in model_config:
        return "mel_band_roformer"

    raise ValueError(
        "Auto-detection failed: Config missing 'freqs_per_bands' (BS) or 'num_bands' (Mel-Band). "
        "Please specify --arch manually."
    )


def normalize_arch(arch: str) -> str:
    """
    Normalize an architecture alias to the full GGUF architecture name.

    Raises:
        ValueError: if ``arch`` (case-insensitive) is not a known alias.
    """
    result = _ARCH_ALIASES.get(arch.lower())
    if result is None:
        raise ValueError(
            f"Unknown architecture: '{arch}'. Supported: {list(_ARCH_ALIASES.keys())}"
        )
    return result


def generate_buffers_bs(hparams):
    """
    BS Roformer: generate the buffers from the ``freqs_per_bands`` tuple.

    Args:
        hparams: The ``model`` section of the config. Reads ``freqs_per_bands``,
            ``stereo`` and ``stft_n_fft`` (all optional).

    Returns:
        Dict with ``freq_indices``, ``num_freqs_per_band``, ``num_bands_per_freq``
        (int32 numpy arrays), ``num_bands``, ``freqs_per_bands_with_complex`` and
        ``freqs_per_bands_tuple`` (the raw band widths, kept for metadata).
    """
    freqs_per_bands = tuple(hparams.get("freqs_per_bands", _DEFAULT_FREQS_PER_BANDS))
    audio_channels = 2 if hparams.get("stereo", False) else 1

    stft_n_fft = hparams.get("stft_n_fft", 2048)
    expected_freqs = stft_n_fft // 2 + 1

    # The band widths are expected to add up to the number of STFT bins. A
    # mismatch is only reported: nothing is adjusted here.
    sum_freqs = sum(freqs_per_bands)
    if sum_freqs != expected_freqs:
        print(
            f"[WARNING] sum(freqs_per_bands)={sum_freqs} != expected {expected_freqs} "
            "(stft_n_fft // 2 + 1). Buffers are written unchanged; check the config."
        )

    num_bands = len(freqs_per_bands)
    freqs_per_bands_with_complex = tuple(
        2 * f * audio_channels for f in freqs_per_bands
    )
    num_freqs_per_band = np.array(freqs_per_bands, dtype=np.int32)

    # BS does not re-index frequencies, but to keep the file structure the same
    # as Mel-Band we emit identity indices and a band count of one per bin.
    freq_indices = np.arange(expected_freqs * audio_channels, dtype=np.int32)
    num_bands_per_freq = np.ones(expected_freqs, dtype=np.int32)

    print(f"Generated BS buffers: {num_bands} bands, {len(freq_indices)} indices")

    return {
        "freq_indices": freq_indices,
        "num_freqs_per_band": num_freqs_per_band,
        "num_bands_per_freq": num_bands_per_freq,
        "num_bands": num_bands,
        "freqs_per_bands_with_complex": freqs_per_bands_with_complex,
        "freqs_per_bands_tuple": freqs_per_bands,  # Keep raw tuple for metadata
    }


def generate_buffers(hparams, arch="mel_band_roformer"):
    """
    Generate buffers for the specified architecture.

    For 'bs_roformer' this delegates to ``generate_buffers_bs``. Otherwise the
    buffers (freq_indices, num_bands_per_freq, etc.) are generated mimicking
    the logic in MelBandRoformer.__init__.

    Args:
        hparams: Model hyperparameters (``num_bands`` is required for Mel-Band;
            ``sample_rate``, ``stft_n_fft`` and ``stereo`` are optional).
        arch: Architecture name ('bs_roformer' or 'mel_band_roformer')

    Returns:
        Dict with ``freq_indices``, ``num_freqs_per_band``, ``num_bands_per_freq``
        and (Mel-Band only) the boolean ``freqs_per_band`` mask.

    Raises:
        ValueError: if some frequency bin is not covered by any mel band.
    """
    if arch == "bs_roformer":
        return generate_buffers_bs(hparams)

    num_bands = hparams["num_bands"]
    sample_rate = hparams.get("sample_rate", 44100)
    stft_n_fft = hparams.get("stft_n_fft", 2048)
    stereo = hparams.get("stereo", False)

    # 1. Number of STFT frequency bins
    freqs = stft_n_fft // 2 + 1

    # 2. Mel filter bank
    mel_filter_bank = torch.from_numpy(
        librosa.filters.mel(sr=sample_rate, n_fft=stft_n_fft, n_mels=num_bands)
    )

    # 3. Ensure edge values are positive (required for mask generation).
    #    The exact value doesn't matter as long as it's > 0.
    mel_filter_bank[0, 0] = max(mel_filter_bank[0, 0].item(), 1e-6)
    mel_filter_bank[-1, -1] = max(mel_filter_bank[-1, -1].item(), 1e-6)

    # 4. Band membership mask, indexed [band, frequency]
    freqs_per_band = mel_filter_bank > 0
    # Raised explicitly (not asserted) so the check survives ``python -O``.
    if not freqs_per_band.any(dim=0).all():
        raise ValueError("every frequency bin must be covered by at least one band")

    # 5. Frequency indices of every (band, bin) pair selected by the mask
    repeated_freq_indices = repeat(torch.arange(freqs), "f -> b f", b=num_bands)
    freq_indices = repeated_freq_indices[freqs_per_band]

    if stereo:
        # Interleave channels: s=0 -> 2*f, s=1 -> 2*f+1
        freq_indices = repeat(freq_indices, "f -> f s", s=2)
        freq_indices = freq_indices * 2 + torch.arange(2)
        freq_indices = rearrange(freq_indices, "f s -> (f s)")

    # 6. Aggregate counts
    num_freqs_per_band = reduce(freqs_per_band, "b f -> b", "sum")
    num_bands_per_freq = reduce(freqs_per_band, "b f -> f", "sum")

    return {
        "freq_indices": freq_indices,
        "num_freqs_per_band": num_freqs_per_band,
        "num_bands_per_freq": num_bands_per_freq,
        "freqs_per_band": freqs_per_band,  # Kept if needed, though usually not saved
    }


# ============================================================================
# Quantization Helper
# ============================================================================


def get_target_quantization_type(dtype_str: str) -> GGMLQuantizationType:
    """
    Translate a dtype string (case-insensitive) to a GGML tensor type.

    Unknown names fall back to F32; a warning is printed so that the fallback
    is not silent.
    """
    qtype = _QTYPE_BY_NAME.get(dtype_str.lower())
    if qtype is None:
        print(f"[WARNING] Unknown dtype '{dtype_str}', falling back to F32.")
        return GGMLQuantizationType.F32
    return qtype


def get_file_type_id(qtype: GGMLQuantizationType) -> int:
    """
    Return the integer stored as the GGUF file type for ``qtype``.

    Unknown types map to 0 (all F32).
    """
    # See GGUF spec: https://github.com/ggerganov/ggml/blob/master/docs/gguf.md
    # These values line up with ``enum ggml_ftype`` in ggml.h, where 4 is
    # MOSTLY_Q4_1_SOME_F16 and 5/6 are not assigned.
    # NOTE(review): confirm against the ggml version the runtime is built with.
    mapping = {
        GGMLQuantizationType.F32: 0,
        GGMLQuantizationType.F16: 1,
        GGMLQuantizationType.Q4_0: 2,
        GGMLQuantizationType.Q4_1: 3,
        GGMLQuantizationType.Q8_0: 7,
        GGMLQuantizationType.Q5_0: 8,
        GGMLQuantizationType.Q5_1: 9,
        GGMLQuantizationType.Q2_K: 10,
        GGMLQuantizationType.Q3_K: 11,
        GGMLQuantizationType.Q4_K: 12,
        GGMLQuantizationType.Q5_K: 13,
        GGMLQuantizationType.Q6_K: 14,
    }
    return mapping.get(qtype, 0)  # Default to ALL_F32 if unknown


def should_quantize(name: str) -> bool:
    """
    Determine if a tensor should be quantized.

    Keep norms and biases as FP32 to avoid CUDA alignment issues; every other
    tensor whose name contains "weight" (Linear, Conv, Embedding if any) is
    quantized. Names without "weight" are never quantized.
    """
    # Biases are small and sensitive; norm weights (gamma) must stay F32 to
    # avoid mixed-type mul issues in CUDA.
    if "bias" in name or "norm.weight" in name:
        return False
    return "weight" in name


# ============================================================================
# Key Name Mapping
# ============================================================================


def _map_transformer_key(parts, suffix):
    """Map a split ``layers.*`` key to its ``blk.*`` name, or None if unrecognised."""
    if len(parts) < 3:
        return None

    layer_idx = parts[1]
    # parts[2]: 0=Time transformer, anything else=Freq transformer
    prefix = f"blk.{layer_idx}.{'time' if parts[2] == '0' else 'freq'}"

    # Final norm of the transformer: layers.0.0.norm.gamma
    if len(parts) >= 5 and parts[3] == "norm":
        return f"{prefix}_norm.{suffix}"

    if len(parts) < 6 or parts[3] != "layers":
        return None

    # Sub-layers: parts[5] is 0 for Attention, 1 for FeedForward
    sub_idx = parts[5]
    if sub_idx == "0" and len(parts) > 6:
        attn_names = {
            "norm": "attn_norm",
            "to_qkv": "attn_qkv",
            "to_out": "attn_out",
            "to_gates": "attn_gate",
        }
        mapped = attn_names.get(parts[6])
        if mapped is not None:
            return f"{prefix}_{mapped}.{suffix}"
    elif sub_idx == "1" and len(parts) >= 8 and parts[6] == "net":
        ff_names = {"0": "ff_norm", "1": "ff_in", "4": "ff_out"}
        mapped = ff_names.get(parts[7])
        if mapped is not None:
            return f"{prefix}_{mapped}.{suffix}"
    return None


def map_key_name(key: str) -> str:
    """
    Map PyTorch state_dict keys to GGUF format (blk.{bid}.*).

    Standardizes suffixes: gamma -> weight, beta -> bias. Keys that match no
    known pattern (including keys too short to carry the expected indices) are
    returned with every "." replaced by "_".
    """
    parts = key.split(".")
    suffix = {"gamma": "weight", "beta": "bias"}.get(parts[-1], parts[-1])

    # Transformer layers
    if key.startswith("layers."):
        mapped = _map_transformer_key(parts, suffix)
        if mapped is not None:
            return mapped

    # BandSplit: band_split.to_features.{band}.{0=Norm|1=Linear}.*
    if key.startswith("band_split.to_features") and len(parts) >= 4:
        band_idx = parts[2]
        if parts[3] == "0":
            return f"band_split.{band_idx}.norm.{suffix}"
        if parts[3] == "1":
            return f"band_split.{band_idx}.linear.{suffix}"

    # Mask estimator: parts[1]=estimator, parts[3]=band, parts[5]=MLP layer (0, 2, 4)
    if key.startswith("mask_estimators") and len(parts) >= 6:
        return f"mask_est.{parts[1]}.freq.{parts[3]}.mlp.{parts[5]}.{suffix}"

    # Final norm
    if key.startswith("final_norm"):
        return f"final_norm.{suffix}"

    return key.replace(".", "_")


# ============================================================================
# Main Conversion
# ============================================================================


def _extract_state_dict(checkpoint):
    """Return the checkpoint's parameter dict with any DDP "module." prefix removed."""
    if "state_dict" in checkpoint:
        raw_state_dict = checkpoint["state_dict"]
    elif "model" in checkpoint:
        raw_state_dict = checkpoint["model"]
    else:
        raw_state_dict = checkpoint
    return {k.removeprefix("module."): v for k, v in raw_state_dict.items()}


def _is_buffer_key(key: str) -> bool:
    """Tell whether a state-dict key names one of the regenerated index buffers."""
    return any(marker in key for marker in _BUFFER_KEY_MARKERS)


def _as_int32_array(values):
    """Convert a torch tensor (Mel-Band) or numpy array (BS) to an int32 numpy array."""
    array = values.numpy() if hasattr(values, "numpy") else values
    return array.astype(np.int32)


def _write_hparams(gguf_writer, arch_name, hparams, buffers):
    """Write the architecture hyperparameters as ``{arch_name}.*`` key/values."""
    gguf_writer.add_uint32(f"{arch_name}.dim", hparams["dim"])
    gguf_writer.add_uint32(f"{arch_name}.depth", hparams["depth"])

    # BS uses freqs_per_bands (no explicit num_bands), MelBand uses num_bands
    num_bands = buffers.get("num_bands", hparams.get("num_bands", 60))
    gguf_writer.add_uint32(f"{arch_name}.num_bands", num_bands)

    # STFT parameters
    gguf_writer.add_uint32(f"{arch_name}.stft_n_fft", hparams.get("stft_n_fft", 2048))
    if "stft_hop_length" not in hparams:
        print("[WARNING] Config has no 'stft_hop_length'; writing default 441.")
    gguf_writer.add_uint32(
        f"{arch_name}.stft_hop_length", hparams.get("stft_hop_length", 441)
    )
    gguf_writer.add_uint32(
        f"{arch_name}.stft_win_length", hparams.get("stft_win_length", 2048)
    )
    gguf_writer.add_bool(
        f"{arch_name}.stft_normalized", hparams.get("stft_normalized", False)
    )
    # Defaults to True in reference implementation
    gguf_writer.add_bool(f"{arch_name}.zero_dc", hparams.get("zero_dc", True))

    # Architecture details
    gguf_writer.add_uint32(f"{arch_name}.num_stems", hparams.get("num_stems", 1))
    gguf_writer.add_bool(f"{arch_name}.stereo", hparams.get("stereo", False))
    gguf_writer.add_uint32(
        f"{arch_name}.sample_rate", hparams.get("sample_rate", 44100)
    )

    # Remaining integer settings: (key, default when absent from the config)
    uint_defaults = (
        ("time_transformer_depth", 0),
        ("freq_transformer_depth", 0),
        ("linear_transformer_depth", 0),
        ("mask_estimator_depth", 1),
        ("dim_head", 64),
        ("heads", 8),
        ("mlp_expansion_factor", 4),
    )
    for field, default in uint_defaults:
        gguf_writer.add_uint32(f"{arch_name}.{field}", hparams.get(field, default))

    gguf_writer.add_bool(
        f"{arch_name}.skip_connection", hparams.get("skip_connection", False)
    )


def _write_weights(gguf_writer, state_dict, target_qtype):
    """
    Add every weight tensor to the writer, quantizing where allowed.

    Returns:
        Tuple ``(n_tensors, n_quantized, warnings_list)``.
    """
    n_tensors = 0
    n_quantized = 0
    warnings_list = []
    wants_quantization = target_qtype != GGMLQuantizationType.F32

    for key, tensor in state_dict.items():
        if _is_buffer_key(key):
            continue

        new_key = map_key_name(key)
        # detach() + float32 cast in torch: Tensor.numpy() rejects tensors that
        # track gradients and dtypes numpy lacks (e.g. bfloat16).
        data = tensor.detach().cpu().to(torch.float32).numpy()

        is_quantized = False
        if wants_quantization and should_quantize(new_key):
            try:
                # Use gguf-py built-in quantization
                quantized_data = quantize(data, target_qtype)
            except Exception as e:
                # Deliberate best-effort: any tensor gguf-py cannot quantize is
                # stored as F32 and reported at the end.
                warnings_list.append(
                    f"Failed to quantize {new_key} to {target_qtype.name}, falling back to F32. Error: {e}"
                )
            else:
                # Pass raw_dtype so GGUFWriter knows how to treat the byte array
                # (for Q types) or float array (for F16)
                gguf_writer.add_tensor(new_key, quantized_data, raw_dtype=target_qtype)
                is_quantized = True
                n_quantized += 1

        if not is_quantized:
            gguf_writer.add_tensor(new_key, data)

        status = target_qtype.name if is_quantized else "F32"
        print(f"  {new_key:<50} | {str(data.shape):<20} | {status}")
        n_tensors += 1

    return n_tensors, n_quantized, warnings_list


def convert(
    ckpt_path: str,
    output_path: str,
    config_path: str,
    dtype: str = "fp32",
    name: str | None = None,
    description: str | None = None,
    arch: str | None = None,
):
    """
    Convert PyTorch checkpoint to GGUF format.

    Args:
        ckpt_path: Path to the PyTorch checkpoint.
        output_path: Path of the GGUF file to write.
        config_path: Path to the YAML config; its ``model`` section supplies the
            hyperparameters.
        dtype: Target type for quantizable weights (norms/biases stay F32).
        name: Model name stored in the metadata (a default is used when empty).
        description: Model description stored in the metadata.
        arch: Architecture name or alias; auto-detected from the config if None.

    Raises:
        ValueError: if the architecture is unknown or cannot be auto-detected.
    """
    print(f"Loading checkpoint: {ckpt_path}")
    # NOTE(security): torch.load unpickles the file; only convert checkpoints
    # from a trusted source.
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    state_dict = _extract_state_dict(checkpoint)

    print(f"Loading config: {config_path}")
    # NOTE(security): FullLoader is kept instead of safe_load because configs
    # may carry Python-specific tags (e.g. !!python/tuple) that SafeLoader
    # rejects; only load trusted configs.
    with open(config_path, encoding="utf-8") as f:
        config_dict = yaml.load(f, Loader=yaml.FullLoader)

    # Detect architecture. A failed detection raises, so that callers and the
    # process exit code reflect that nothing was converted.
    if arch is None:
        arch = detect_architecture(config_dict)
        print(f"Auto-detected architecture: {arch}")
    else:
        # Normalize provided arch to full name
        arch = normalize_arch(arch)
    arch_name = arch

    hparams = config_dict["model"]

    print("Generating buffers (standalone)...")
    buffers = generate_buffers(hparams, arch=arch)

    target_qtype = get_target_quantization_type(dtype)

    gguf_writer = gguf.GGUFWriter(output_path, arch_name)
    try:
        # =====================================================================
        # 1. Write Standard GGUF Metadata
        # =====================================================================
        print("Writing metadata...")

        model_name = name if name else "Mel-Band-Roformer Separator"
        model_description = (
            description if description else "Music source separation model"
        )
        gguf_writer.add_name(model_name)
        gguf_writer.add_description(model_description)
        gguf_writer.add_file_type(get_file_type_id(target_qtype))

        if arch_name == "bs_roformer" and "freqs_per_bands_tuple" in buffers:
            # Must be list for GGUFWriter
            gguf_writer.add_array(
                f"{arch_name}.freqs_per_bands", list(buffers["freqs_per_bands_tuple"])
            )

        # Quantization version (required when quantized)
        if target_qtype != GGMLQuantizationType.F32:
            gguf_writer.add_quantization_version(2)

        # Parameter count over exactly the tensors that are written as weights
        total_params = sum(
            tensor.numel()
            for key, tensor in state_dict.items()
            if not _is_buffer_key(key)
        )
        print(f"Total parameters: {total_params}")
        gguf_writer.add_uint64("general.parameter_count", total_params)

        # =====================================================================
        # 2. Write Hyperparameters
        # =====================================================================
        print("Writing hyperparameters...")
        _write_hparams(gguf_writer, arch_name, hparams, buffers)

        # =====================================================================
        # 3. Write Inference Defaults (Optional, can be overridden at runtime)
        # =====================================================================
        print("Writing inference defaults...")
        inference_config = config_dict.get("inference", {})
        audio_config = config_dict.get("audio", {})

        # chunk_size: prefer inference.chunk_size, fallback to audio.chunk_size
        default_chunk_size = inference_config.get(
            "chunk_size", audio_config.get("chunk_size", 352800)
        )
        # num_overlap: from inference section
        default_num_overlap = inference_config.get("num_overlap", 0)

        gguf_writer.add_uint32(f"{arch_name}.default_chunk_size", default_chunk_size)
        gguf_writer.add_uint32(f"{arch_name}.default_num_overlap", default_num_overlap)

        # =====================================================================
        # 4. Write Buffers (always int32)
        # =====================================================================
        print("Writing buffers...")
        gguf_writer.add_tensor(
            "buffer_freq_indices", _as_int32_array(buffers["freq_indices"])
        )
        gguf_writer.add_tensor(
            "buffer_num_bands_per_freq", _as_int32_array(buffers["num_bands_per_freq"])
        )
        gguf_writer.add_tensor(
            "buffer_num_freqs_per_band", _as_int32_array(buffers["num_freqs_per_band"])
        )

        # =====================================================================
        # 5. Write Weights (Mixed Quantization)
        # =====================================================================
        print(f"Writing weights ({dtype} -> {target_qtype.name})...")
        print("Strategy: Quantize weights, Keep Norm/Bias as F32")
        n_tensors, n_quantized, warnings_list = _write_weights(
            gguf_writer, state_dict, target_qtype
        )

        # =====================================================================
        # 6. Write File
        # =====================================================================
        print(f"\nWriting GGUF to {output_path}")
        gguf_writer.write_header_to_file()
        gguf_writer.write_kv_data_to_file()
        gguf_writer.write_tensors_to_file()
    finally:
        # Release the output handle even when a step above fails.
        gguf_writer.close()

    if warnings_list:
        print("\n" + "=" * 80)
        print(
            f"WARNING: {len(warnings_list)} tensors failed to quantize (fallback to F32):"
        )
        for msg in warnings_list:
            print(f"  - {msg}")
        print("=" * 80 + "\n")

    file_size = os.path.getsize(output_path)
    print(f"\nDone! Converted {n_tensors} tensors ({n_quantized} quantized)")
    print(f"Output file size: {file_size / 1024 / 1024:.2f} MB")


def main():
    """Parse the command line and run the conversion."""
    parser = argparse.ArgumentParser(
        description="Convert Mel-Band-Roformer checkpoint to GGUF format with Mixed Quantization",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python convert_to_gguf.py --ckpt model.ckpt --config config.yaml --out model_f16.gguf --dtype fp16
  python convert_to_gguf.py --ckpt model.ckpt --config config.yaml --out model_q8.gguf --dtype q8_0
""",
    )
    parser.add_argument(
        "--ckpt", type=str, required=True, help="Path to PyTorch checkpoint"
    )
    parser.add_argument("--config", type=str, required=True, help="Path to YAML config")
    parser.add_argument("--out", type=str, required=True, help="Output GGUF file path")
    parser.add_argument(
        "--dtype",
        type=str,
        default="fp32",
        choices=[
            "fp32",
            "f32",
            "fp16",
            "f16",
            "q8_0",
            "q4_0",
            "q4_1",
            "q5_0",
            "q5_1",
        ],
        help="Target quantization type. Norms/Biases will be kept as F32. (K-Quants not supported due to dim=384)",
    )
    parser.add_argument(
        "--name",
        type=str,
        default=None,
        help="Model name (default: 'Mel-Band-Roformer Separator')",
    )
    parser.add_argument(
        "--description",
        type=str,
        default=None,
        help="Model description (default: 'Music source separation model')",
    )
    parser.add_argument(
        "--arch",
        choices=["mel_band", "mel_band_roformer", "bs", "bs_roformer"],
        default=None,
        help="Architecture type (auto-detected if not specified)",
    )

    args = parser.parse_args()

    convert(
        args.ckpt,
        args.out,
        args.config,
        args.dtype,
        args.name,
        args.description,
        args.arch,
    )


if __name__ == "__main__":
    main()