|
|
@@ -17,7 +17,172 @@ import gguf
|
|
|
from gguf.quants import quantize, GGMLQuantizationType
|
|
|
|
|
|
|
|
|
-def generate_buffers(hparams):
|
|
|
def detect_architecture(config_dict):
    """
    Detect the model architecture from a parsed config.

    The decision is purely structural: the 'model' section is inspected for a
    key that identifies each architecture. 'freqs_per_bands' is checked first,
    so a config that contains both keys is reported as BS Roformer.

    Args:
        config_dict: Parsed configuration mapping (the whole config, not just
            its 'model' section). May be None or lack a 'model' section.

    Returns:
        'bs_roformer' or 'mel_band_roformer'

    Raises:
        ValueError: If the 'model' section is missing/empty or contains
            neither 'freqs_per_bands' nor 'num_bands'.
    """
    # An empty config or an empty "model:" section comes through as None;
    # treat both like a missing section so the caller gets the documented
    # ValueError instead of an AttributeError/TypeError.
    model_config = (config_dict or {}).get("model") or {}

    if "freqs_per_bands" in model_config:
        return "bs_roformer"
    if "num_bands" in model_config:
        return "mel_band_roformer"

    # Neither structural signature was found: refuse to guess.
    raise ValueError(
        "Auto-detection failed: Config missing 'freqs_per_bands' (BS) or 'num_bands' (Mel-Band). "
        "Please specify --arch manually."
    )
|
|
|
+
|
|
|
+
|
|
|
def normalize_arch(arch: str) -> str:
    """
    Normalize an architecture name or alias to the full GGUF name.

    Matching is case-insensitive, ignores surrounding whitespace and treats
    '-' like '_' (so 'Mel-Band' resolves the same way as 'mel_band').

    Args:
        arch: Architecture name or short alias.

    Returns:
        'bs_roformer' or 'mel_band_roformer'

    Raises:
        ValueError: If the name is not a known architecture or alias
            (including non-string input such as None).
    """
    mapping = {
        "bs": "bs_roformer",
        "bs_roformer": "bs_roformer",
        "mel": "mel_band_roformer",
        "mel_band": "mel_band_roformer",
        "mel_band_roformer": "mel_band_roformer",
    }

    # Non-string input falls through to the ValueError below rather than
    # failing with AttributeError on .lower().
    lookup_key = arch.strip().lower().replace("-", "_") if isinstance(arch, str) else None
    result = mapping.get(lookup_key)
    if result is None:
        raise ValueError(
            f"Unknown architecture: '{arch}'. Supported: {list(mapping.keys())}"
        )
    return result
|
|
|
+
|
|
|
+
|
|
|
def generate_buffers_bs(hparams):
    """
    BS Roformer: generate the band buffers from the freqs_per_bands sequence.

    Args:
        hparams: Model hyperparameters mapping. Reads the optional keys
            'freqs_per_bands' (missing or None -> built-in default),
            'stereo' (default False) and 'stft_n_fft' (default 2048).

    Returns:
        dict with:
            freq_indices: int32 array 0..N-1 with
                N = (stft_n_fft // 2 + 1) * audio_channels (identity placeholder)
            num_freqs_per_band: int32 array of the per-band frequency counts
            num_bands_per_freq: int32 array of ones, length stft_n_fft // 2 + 1
            num_bands: number of bands
            freqs_per_bands_with_complex: tuple of 2 * f * audio_channels per band
            freqs_per_bands_tuple: the per-band counts as a tuple (for metadata)
    """
    # Default from bs_roformer.py: 62 bands whose sizes sum to 1025
    # (= 2048 // 2 + 1), written as run lengths instead of 62 literals.
    DEFAULT_FREQS_PER_BANDS = (
        (2,) * 24 + (4,) * 12 + (12,) * 8 + (24,) * 8 + (48,) * 8 + (128, 129)
    )

    freqs_per_bands = hparams.get("freqs_per_bands")
    if freqs_per_bands is None:
        # Covers both a missing key and an explicit null in the config.
        freqs_per_bands = DEFAULT_FREQS_PER_BANDS
    # Configs deliver a list; normalize so the returned value is always a tuple.
    freqs_per_bands = tuple(freqs_per_bands)

    stereo = hparams.get("stereo", False)
    audio_channels = 2 if stereo else 1

    # Validate: the band sizes are expected to add up to the number of STFT bins.
    stft_n_fft = hparams.get("stft_n_fft", 2048)
    expected_freqs = stft_n_fft // 2 + 1

    sum_freqs = sum(freqs_per_bands)
    if sum_freqs != expected_freqs:
        # Warn only: nothing is adjusted here, so the message must not claim
        # otherwise. The buffers below are still sized from expected_freqs.
        # NOTE(review): the original author suspected consumers need a strict
        # match; consider turning this into an error once confirmed.
        print(
            f"[WARNING] sum(freqs_per_bands)={sum_freqs} != expected {expected_freqs} "
            f"(stft_n_fft // 2 + 1). No adjustment is made; check the config."
        )

    num_bands = len(freqs_per_bands)

    # Per-band width once real/imag parts (x2) and audio channels are interleaved.
    freqs_per_bands_with_complex = tuple(
        2 * f * audio_channels for f in freqs_per_bands
    )

    # num_freqs_per_band: i32 array
    num_freqs_per_band = np.array(freqs_per_bands, dtype=np.int32)

    # BS doesn't use freq_indices re-indexing, but to keep a compatible file
    # structure we create dummy full-range indices.
    total_freqs_stereo = expected_freqs * audio_channels
    freq_indices = np.arange(total_freqs_stereo, dtype=np.int32)
    num_bands_per_freq = np.ones(expected_freqs, dtype=np.int32)

    print(f"Generated BS buffers: {num_bands} bands, {len(freq_indices)} indices")

    return {
        "freq_indices": freq_indices,
        "num_freqs_per_band": num_freqs_per_band,
        "num_bands_per_freq": num_bands_per_freq,
        "num_bands": num_bands,
        "freqs_per_bands_with_complex": freqs_per_bands_with_complex,
        "freqs_per_bands_tuple": freqs_per_bands,  # Keep raw values for metadata
    }
|
|
|
+
|
|
|
+
|
|
|
+def generate_buffers(hparams, arch="mel_band_roformer"):
|
|
|
+ """
|
|
|
+ Generate buffers for the specified architecture.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ hparams: Model hyperparameters
|
|
|
+ arch: Architecture name ('bs_roformer' or 'mel_band_roformer')
|
|
|
+ """
|
|
|
+ if arch == "bs_roformer":
|
|
|
+ return generate_buffers_bs(hparams)
|
|
|
+
|
|
|
+ # Mel-Band-Roformer Logic
|
|
|
+ # ------------------------------------------------------------------------
|
|
|
"""
|
|
|
Generate the buffers (freq_indices, num_bands_per_freq, etc.)
|
|
|
mimicking the logic in MelBandRoformer.__init__.
|
|
|
@@ -206,6 +371,10 @@ def map_key_name(key: str) -> str:
|
|
|
layer_idx = parts[5] # 0, 2, 4
|
|
|
return f"mask_est.{est_idx}.freq.{freq_idx}.mlp.{layer_idx}.{suffix}"
|
|
|
|
|
|
+ # Final Norm
|
|
|
+ if key.startswith("final_norm"):
|
|
|
+ return f"final_norm.{suffix}"
|
|
|
+
|
|
|
return key.replace(".", "_")
|
|
|
|
|
|
|
|
|
@@ -219,8 +388,9 @@ def convert(
|
|
|
output_path: str,
|
|
|
config_path: str,
|
|
|
dtype: str = "fp32",
|
|
|
- name: str = None,
|
|
|
- description: str = None,
|
|
|
+ name: str | None = None,
|
|
|
+ description: str | None = None,
|
|
|
+ arch: str | None = None,
|
|
|
):
|
|
|
"""
|
|
|
Convert PyTorch checkpoint to GGUF format.
|
|
|
@@ -239,15 +409,29 @@ def convert(
|
|
|
with open(config_path) as f:
|
|
|
config_dict = yaml.load(f, Loader=yaml.FullLoader)
|
|
|
|
|
|
+ # Detect architecture
|
|
|
+ if arch is None:
|
|
|
+ try:
|
|
|
+ arch = detect_architecture(config_dict)
|
|
|
+ print(f"Auto-detected architecture: {arch}")
|
|
|
+ except ValueError as e:
|
|
|
+ print(f"Error: {e}")
|
|
|
+ return
|
|
|
+ else:
|
|
|
+ # Normalize provided arch to full name
|
|
|
+ arch = normalize_arch(arch)
|
|
|
+
|
|
|
# Generate buffers
|
|
|
print("Generating buffers (standalone)...")
|
|
|
- buffers = generate_buffers(config_dict["model"])
|
|
|
+ buffers = generate_buffers(config_dict["model"], arch=arch)
|
|
|
freq_indices = buffers["freq_indices"]
|
|
|
num_bands_per_freq = buffers["num_bands_per_freq"]
|
|
|
num_freqs_per_band = buffers["num_freqs_per_band"]
|
|
|
|
|
|
+ arch_name = arch
|
|
|
+
|
|
|
# Create GGUF writer
|
|
|
- gguf_writer = gguf.GGUFWriter(output_path, "mel_band_roformer")
|
|
|
+ gguf_writer = gguf.GGUFWriter(output_path, arch_name)
|
|
|
|
|
|
# =========================================================================
|
|
|
# 1. Write Standard GGUF Metadata
|
|
|
@@ -266,6 +450,14 @@ def convert(
|
|
|
|
|
|
gguf_writer.add_file_type(file_type_id)
|
|
|
|
|
|
+ # Write Architecture
|
|
|
+ # gguf_writer.add_string(f"{arch_name}.architecture", arch) # Redundant with general.architecture
|
|
|
+
|
|
|
+ if arch_name == "bs_roformer" and "freqs_per_bands_tuple" in buffers:
|
|
|
+ freqs_tuple = buffers["freqs_per_bands_tuple"]
|
|
|
+ # Must be list for GGUFWriter
|
|
|
+ gguf_writer.add_array(f"{arch_name}.freqs_per_bands", list(freqs_tuple))
|
|
|
+
|
|
|
# Quantization version (required when quantized)
|
|
|
if target_qtype != GGMLQuantizationType.F32:
|
|
|
gguf_writer.add_quantization_version(2)
|
|
|
@@ -286,59 +478,80 @@ def convert(
|
|
|
print("Writing hyperparameters...")
|
|
|
hparams = config_dict["model"]
|
|
|
|
|
|
+ # Load state dict directly (no model class dependency)
|
|
|
+ print(f"Loading checkpoint for architecture: {arch}")
|
|
|
+
|
|
|
+ raw_state_dict = None
|
|
|
+ if "state_dict" in checkpoint:
|
|
|
+ raw_state_dict = checkpoint["state_dict"]
|
|
|
+ elif "model" in checkpoint:
|
|
|
+ raw_state_dict = checkpoint["model"]
|
|
|
+ else:
|
|
|
+ raw_state_dict = checkpoint
|
|
|
+
|
|
|
+ if raw_state_dict is None:
|
|
|
+ raise ValueError("Could not find state_dict in checkpoint")
|
|
|
+
|
|
|
+ # Clean up state dict (handle DDP "module." prefix)
|
|
|
+ state_dict = {}
|
|
|
+ for k, v in raw_state_dict.items():
|
|
|
+ if k.startswith("module."):
|
|
|
+ k = k[7:]
|
|
|
+ state_dict[k] = v
|
|
|
+
|
|
|
# Architecture specific parameters
|
|
|
- gguf_writer.add_uint32("mel_band_roformer.dim", hparams["dim"])
|
|
|
- gguf_writer.add_uint32("mel_band_roformer.depth", hparams["depth"])
|
|
|
- gguf_writer.add_uint32("mel_band_roformer.num_bands", hparams["num_bands"])
|
|
|
+ gguf_writer.add_uint32(f"{arch_name}.dim", hparams["dim"])
|
|
|
+ gguf_writer.add_uint32(f"{arch_name}.depth", hparams["depth"])
|
|
|
+ # BS uses freqs_per_bands (no explicit num_bands), MelBand uses num_bands
|
|
|
+ num_bands = buffers.get("num_bands", hparams.get("num_bands", 60))
|
|
|
+ gguf_writer.add_uint32(f"{arch_name}.num_bands", num_bands)
|
|
|
|
|
|
# STFT parameters
|
|
|
- gguf_writer.add_uint32(
|
|
|
- "mel_band_roformer.stft_n_fft", hparams.get("stft_n_fft", 2048)
|
|
|
- )
|
|
|
+ gguf_writer.add_uint32(f"{arch_name}.stft_n_fft", hparams.get("stft_n_fft", 2048))
|
|
|
     # TODO: remove the 441 default for stft_hop_length; the key should be required (fail or warn when it is missing)
|
|
|
gguf_writer.add_uint32(
|
|
|
- "mel_band_roformer.stft_hop_length", hparams.get("stft_hop_length", 441)
|
|
|
+ f"{arch_name}.stft_hop_length", hparams.get("stft_hop_length", 441)
|
|
|
)
|
|
|
gguf_writer.add_uint32(
|
|
|
- "mel_band_roformer.stft_win_length", hparams.get("stft_win_length", 2048)
|
|
|
+ f"{arch_name}.stft_win_length", hparams.get("stft_win_length", 2048)
|
|
|
)
|
|
|
gguf_writer.add_bool(
|
|
|
- "mel_band_roformer.stft_normalized", hparams.get("stft_normalized", False)
|
|
|
+ f"{arch_name}.stft_normalized", hparams.get("stft_normalized", False)
|
|
|
)
|
|
|
gguf_writer.add_bool(
|
|
|
- "mel_band_roformer.zero_dc", hparams.get("zero_dc", True)
|
|
|
+ f"{arch_name}.zero_dc", hparams.get("zero_dc", True)
|
|
|
) # Defaults to True in reference implementation
|
|
|
|
|
|
# Architecture details
|
|
|
- gguf_writer.add_uint32("mel_band_roformer.num_stems", hparams.get("num_stems", 1))
|
|
|
- gguf_writer.add_bool("mel_band_roformer.stereo", hparams.get("stereo", False))
|
|
|
+ gguf_writer.add_uint32(f"{arch_name}.num_stems", hparams.get("num_stems", 1))
|
|
|
+ gguf_writer.add_bool(f"{arch_name}.stereo", hparams.get("stereo", False))
|
|
|
gguf_writer.add_uint32(
|
|
|
- "mel_band_roformer.sample_rate", hparams.get("sample_rate", 44100)
|
|
|
+ f"{arch_name}.sample_rate", hparams.get("sample_rate", 44100)
|
|
|
)
|
|
|
|
|
|
gguf_writer.add_uint32(
|
|
|
- "mel_band_roformer.time_transformer_depth",
|
|
|
+ f"{arch_name}.time_transformer_depth",
|
|
|
hparams.get("time_transformer_depth", 0),
|
|
|
)
|
|
|
gguf_writer.add_uint32(
|
|
|
- "mel_band_roformer.freq_transformer_depth",
|
|
|
+ f"{arch_name}.freq_transformer_depth",
|
|
|
hparams.get("freq_transformer_depth", 0),
|
|
|
)
|
|
|
gguf_writer.add_uint32(
|
|
|
- "mel_band_roformer.linear_transformer_depth",
|
|
|
+ f"{arch_name}.linear_transformer_depth",
|
|
|
hparams.get("linear_transformer_depth", 0),
|
|
|
)
|
|
|
|
|
|
gguf_writer.add_uint32(
|
|
|
- "mel_band_roformer.mask_estimator_depth", hparams.get("mask_estimator_depth", 1)
|
|
|
+ f"{arch_name}.mask_estimator_depth", hparams.get("mask_estimator_depth", 1)
|
|
|
)
|
|
|
- gguf_writer.add_uint32("mel_band_roformer.dim_head", hparams.get("dim_head", 64))
|
|
|
- gguf_writer.add_uint32("mel_band_roformer.heads", hparams.get("heads", 8))
|
|
|
+ gguf_writer.add_uint32(f"{arch_name}.dim_head", hparams.get("dim_head", 64))
|
|
|
+ gguf_writer.add_uint32(f"{arch_name}.heads", hparams.get("heads", 8))
|
|
|
gguf_writer.add_uint32(
|
|
|
- "mel_band_roformer.mlp_expansion_factor", hparams.get("mlp_expansion_factor", 4)
|
|
|
+ f"{arch_name}.mlp_expansion_factor", hparams.get("mlp_expansion_factor", 4)
|
|
|
)
|
|
|
gguf_writer.add_bool(
|
|
|
- "mel_band_roformer.skip_connection", hparams.get("skip_connection", False)
|
|
|
+ f"{arch_name}.skip_connection", hparams.get("skip_connection", False)
|
|
|
)
|
|
|
|
|
|
# =========================================================================
|
|
|
@@ -356,24 +569,31 @@ def convert(
|
|
|
# num_overlap: from inference section
|
|
|
default_num_overlap = inference_config.get("num_overlap", 0)
|
|
|
|
|
|
- gguf_writer.add_uint32("mel_band_roformer.default_chunk_size", default_chunk_size)
|
|
|
- gguf_writer.add_uint32("mel_band_roformer.default_num_overlap", default_num_overlap)
|
|
|
+ gguf_writer.add_uint32(f"{arch_name}.default_chunk_size", default_chunk_size)
|
|
|
+ gguf_writer.add_uint32(f"{arch_name}.default_num_overlap", default_num_overlap)
|
|
|
|
|
|
# =========================================================================
|
|
|
# 4. Write Buffers (Always FP32/I32)
|
|
|
# =========================================================================
|
|
|
print("Writing buffers...")
|
|
|
|
|
|
- # freq_indices (int32)
|
|
|
- gguf_writer.add_tensor("buffer_freq_indices", freq_indices.numpy().astype(np.int32))
|
|
|
+ # freq_indices (int32) - may be torch.Tensor (MelBand) or np.ndarray (BS)
|
|
|
+ fi = freq_indices.numpy() if hasattr(freq_indices, "numpy") else freq_indices
|
|
|
+ gguf_writer.add_tensor("buffer_freq_indices", fi.astype(np.int32))
|
|
|
# num_bands_per_freq (int32)
|
|
|
- gguf_writer.add_tensor(
|
|
|
- "buffer_num_bands_per_freq", num_bands_per_freq.numpy().astype(np.int32)
|
|
|
+ nbpf = (
|
|
|
+ num_bands_per_freq.numpy()
|
|
|
+ if hasattr(num_bands_per_freq, "numpy")
|
|
|
+ else num_bands_per_freq
|
|
|
)
|
|
|
+ gguf_writer.add_tensor("buffer_num_bands_per_freq", nbpf.astype(np.int32))
|
|
|
# num_freqs_per_band (int32)
|
|
|
- gguf_writer.add_tensor(
|
|
|
- "buffer_num_freqs_per_band", num_freqs_per_band.numpy().astype(np.int32)
|
|
|
+ nfpb = (
|
|
|
+ num_freqs_per_band.numpy()
|
|
|
+ if hasattr(num_freqs_per_band, "numpy")
|
|
|
+ else num_freqs_per_band
|
|
|
)
|
|
|
+ gguf_writer.add_tensor("buffer_num_freqs_per_band", nfpb.astype(np.int32))
|
|
|
|
|
|
# =========================================================================
|
|
|
# 5. Write Weights (Mixed Quantization)
|
|
|
@@ -384,6 +604,8 @@ def convert(
|
|
|
n_tensors = 0
|
|
|
n_quantized = 0
|
|
|
|
|
|
+ warnings_list = []
|
|
|
+
|
|
|
for key, tensor in state_dict.items():
|
|
|
new_key = map_key_name(key)
|
|
|
|
|
|
@@ -410,9 +632,8 @@ def convert(
|
|
|
is_quantized = True
|
|
|
n_quantized += 1
|
|
|
except Exception as e:
|
|
|
- print(
|
|
|
- f"Warning: Failed to quantize {new_key} to {target_qtype.name}, falling back to F32. Error: {e}"
|
|
|
- )
|
|
|
+ msg = f"Failed to quantize {new_key} to {target_qtype.name}, falling back to F32. Error: {e}"
|
|
|
+ warnings_list.append(msg)
|
|
|
gguf_writer.add_tensor(new_key, data)
|
|
|
else:
|
|
|
# Keep as F32
|
|
|
@@ -431,6 +652,15 @@ def convert(
|
|
|
gguf_writer.write_tensors_to_file()
|
|
|
gguf_writer.close()
|
|
|
|
|
|
+ if warnings_list:
|
|
|
+ print("\n" + "=" * 80)
|
|
|
+ print(
|
|
|
+ f"WARNING: {len(warnings_list)} tensors failed to quantize (fallback to F32):"
|
|
|
+ )
|
|
|
+ for msg in warnings_list:
|
|
|
+ print(f" - {msg}")
|
|
|
+ print("=" * 80 + "\n")
|
|
|
+
|
|
|
file_size = os.path.getsize(output_path)
|
|
|
print(f"\nDone! Converted {n_tensors} tensors ({n_quantized} quantized)")
|
|
|
print(f"Output file size: {file_size / 1024 / 1024:.2f} MB")
|
|
|
@@ -480,6 +710,20 @@ Examples:
|
|
|
default=None,
|
|
|
help="Model description (default: 'Audio source separation model for vocal extraction')",
|
|
|
)
|
|
|
+ parser.add_argument(
|
|
|
+ "--arch",
|
|
|
+ choices=["mel_band", "mel_band_roformer", "bs", "bs_roformer"],
|
|
|
+ default=None,
|
|
|
+ help="Architecture type (auto-detected if not specified)",
|
|
|
+ )
|
|
|
args = parser.parse_args()
|
|
|
|
|
|
- convert(args.ckpt, args.out, args.config, args.dtype, args.name, args.description)
|
|
|
+ convert(
|
|
|
+ args.ckpt,
|
|
|
+ args.out,
|
|
|
+ args.config,
|
|
|
+ args.dtype,
|
|
|
+ args.name,
|
|
|
+ args.description,
|
|
|
+ args.arch,
|
|
|
+ )
|