|
@@ -6,37 +6,39 @@ Supports quantization: FP32, FP16, Q8_0, Q4_0, Q4_1, Q5_0, Q5_1
|
|
|
Mixed Quantization: Keeps Norms/Biases as FP32 to avoid CUDA alignment issues.
|
|
Mixed Quantization: Keeps Norms/Biases as FP32 to avoid CUDA alignment issues.
|
|
|
"""
|
|
"""
|
|
|
|
|
|
|
|
-import os
|
|
|
|
|
import argparse
|
|
import argparse
|
|
|
-import torch
|
|
|
|
|
|
|
+import os
|
|
|
|
|
+
|
|
|
|
|
+import gguf
|
|
|
|
|
+import librosa
|
|
|
import numpy as np
|
|
import numpy as np
|
|
|
|
|
+import torch
|
|
|
import yaml
|
|
import yaml
|
|
|
-import librosa
|
|
|
|
|
-from einops import repeat, reduce, rearrange
|
|
|
|
|
-import gguf
|
|
|
|
|
-from gguf.quants import quantize, GGMLQuantizationType
|
|
|
|
|
|
|
+from einops import rearrange, reduce, repeat
|
|
|
|
|
+from gguf.quants import GGMLQuantizationType, quantize
|
|
|
|
|
+from safetensors.torch import load_file as load_safetensors
|
|
|
|
|
|
|
|
|
|
|
|
|
def detect_architecture(config_dict):
|
|
def detect_architecture(config_dict):
|
|
|
"""
|
|
"""
|
|
|
Detect architecture from config.
|
|
Detect architecture from config.
|
|
|
- Returns: 'bs_roformer' or 'mel_band_roformer'
|
|
|
|
|
|
|
+ Returns: 'bs_roformer', 'bs_roformer_v2', or 'mel_band_roformer'
|
|
|
"""
|
|
"""
|
|
|
-
|
|
|
|
|
- # Check structural signatures in 'model' section
|
|
|
|
|
- model_config = config_dict.get("model", {})
|
|
|
|
|
|
|
+ model_config = config_dict.get("model", config_dict)
|
|
|
|
|
|
|
|
has_freqs = "freqs_per_bands" in model_config
|
|
has_freqs = "freqs_per_bands" in model_config
|
|
|
|
|
+ has_freqs_out = "freqs_per_bands_out" in model_config
|
|
|
has_num_bands = "num_bands" in model_config
|
|
has_num_bands = "num_bands" in model_config
|
|
|
|
|
|
|
|
|
|
+ if has_freqs and has_freqs_out:
|
|
|
|
|
+ return "bs_roformer_v2"
|
|
|
if has_freqs:
|
|
if has_freqs:
|
|
|
return "bs_roformer"
|
|
return "bs_roformer"
|
|
|
if has_num_bands:
|
|
if has_num_bands:
|
|
|
return "mel_band_roformer"
|
|
return "mel_band_roformer"
|
|
|
-
|
|
|
|
|
- # 3. If neither found, fail
|
|
|
|
|
|
|
+
|
|
|
raise ValueError(
|
|
raise ValueError(
|
|
|
- "Auto-detection failed: Config missing 'freqs_per_bands' (BS) or 'num_bands' (Mel-Band). "
|
|
|
|
|
|
|
+ "Auto-detection failed: Config missing 'freqs_per_bands'/'freqs_per_bands_out' (BS_V2), 'freqs_per_bands' (BS), or 'num_bands' (Mel-Band). "
|
|
|
"Please specify --arch manually."
|
|
"Please specify --arch manually."
|
|
|
)
|
|
)
|
|
|
|
|
|
|
@@ -46,6 +48,7 @@ def normalize_arch(arch: str) -> str:
|
|
|
mapping = {
|
|
mapping = {
|
|
|
"bs": "bs_roformer",
|
|
"bs": "bs_roformer",
|
|
|
"bs_roformer": "bs_roformer",
|
|
"bs_roformer": "bs_roformer",
|
|
|
|
|
+ "bs_roformer_v2": "bs_roformer_v2",
|
|
|
"mel": "mel_band_roformer",
|
|
"mel": "mel_band_roformer",
|
|
|
"mel_band": "mel_band_roformer",
|
|
"mel_band": "mel_band_roformer",
|
|
|
"mel_band_roformer": "mel_band_roformer",
|
|
"mel_band_roformer": "mel_band_roformer",
|
|
@@ -178,7 +181,7 @@ def generate_buffers(hparams, arch="mel_band_roformer"):
|
|
|
hparams: Model hyperparameters
|
|
hparams: Model hyperparameters
|
|
|
arch: Architecture name ('bs_roformer' or 'mel_band_roformer')
|
|
arch: Architecture name ('bs_roformer' or 'mel_band_roformer')
|
|
|
"""
|
|
"""
|
|
|
- if arch == "bs_roformer":
|
|
|
|
|
|
|
+ if arch == "bs_roformer" or arch == "bs_roformer_v2":
|
|
|
return generate_buffers_bs(hparams)
|
|
return generate_buffers_bs(hparams)
|
|
|
|
|
|
|
|
# Mel-Band-Roformer Logic
|
|
# Mel-Band-Roformer Logic
|
|
@@ -234,9 +237,9 @@ def generate_buffers(hparams, arch="mel_band_roformer"):
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
-# ============================================================================
|
|
|
|
|
|
|
+# ============================================================================
|
|
|
# Quantization Helper
|
|
# Quantization Helper
|
|
|
-# ============================================================================
|
|
|
|
|
|
|
+# ============================================================================
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_target_quantization_type(dtype_str: str) -> GGMLQuantizationType:
|
|
def get_target_quantization_type(dtype_str: str) -> GGMLQuantizationType:
|
|
@@ -297,9 +300,9 @@ def should_quantize(name: str) -> bool:
|
|
|
return False
|
|
return False
|
|
|
|
|
|
|
|
|
|
|
|
|
-# ============================================================================
|
|
|
|
|
|
|
+# ============================================================================
|
|
|
# Key Name Mapping
|
|
# Key Name Mapping
|
|
|
-# ============================================================================
|
|
|
|
|
|
|
+# ============================================================================
|
|
|
|
|
|
|
|
|
|
|
|
|
def map_key_name(key: str) -> str:
|
|
def map_key_name(key: str) -> str:
|
|
@@ -378,9 +381,9 @@ def map_key_name(key: str) -> str:
|
|
|
return key.replace(".", "_")
|
|
return key.replace(".", "_")
|
|
|
|
|
|
|
|
|
|
|
|
|
-# ============================================================================
|
|
|
|
|
|
|
+# ============================================================================
|
|
|
# Main Conversion
|
|
# Main Conversion
|
|
|
-# ============================================================================
|
|
|
|
|
|
|
+# ============================================================================
|
|
|
|
|
|
|
|
|
|
|
|
|
def convert(
|
|
def convert(
|
|
@@ -396,18 +399,21 @@ def convert(
|
|
|
Convert PyTorch checkpoint to GGUF format.
|
|
Convert PyTorch checkpoint to GGUF format.
|
|
|
"""
|
|
"""
|
|
|
print(f"Loading checkpoint: {ckpt_path}")
|
|
print(f"Loading checkpoint: {ckpt_path}")
|
|
|
- checkpoint = torch.load(ckpt_path, map_location="cpu")
|
|
|
|
|
-
|
|
|
|
|
- if "state_dict" in checkpoint:
|
|
|
|
|
- state_dict = checkpoint["state_dict"]
|
|
|
|
|
- elif "model" in checkpoint:
|
|
|
|
|
- state_dict = checkpoint["model"]
|
|
|
|
|
|
|
+ if ckpt_path.endswith(".safetensors"):
|
|
|
|
|
+ state_dict = load_safetensors(ckpt_path)
|
|
|
else:
|
|
else:
|
|
|
- state_dict = checkpoint
|
|
|
|
|
|
|
+ checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)
|
|
|
|
|
+
|
|
|
|
|
+ if "state_dict" in checkpoint:
|
|
|
|
|
+ state_dict = checkpoint["state_dict"]
|
|
|
|
|
+ elif "model" in checkpoint:
|
|
|
|
|
+ state_dict = checkpoint["model"]
|
|
|
|
|
+ else:
|
|
|
|
|
+ state_dict = checkpoint
|
|
|
|
|
|
|
|
print(f"Loading config: {config_path}")
|
|
print(f"Loading config: {config_path}")
|
|
|
with open(config_path) as f:
|
|
with open(config_path) as f:
|
|
|
- config_dict = yaml.load(f, Loader=yaml.FullLoader)
|
|
|
|
|
|
|
+ config_dict = yaml.safe_load(f)
|
|
|
|
|
|
|
|
# Detect architecture
|
|
# Detect architecture
|
|
|
if arch is None:
|
|
if arch is None:
|
|
@@ -423,7 +429,7 @@ def convert(
|
|
|
|
|
|
|
|
# Generate buffers
|
|
# Generate buffers
|
|
|
print("Generating buffers (standalone)...")
|
|
print("Generating buffers (standalone)...")
|
|
|
- buffers = generate_buffers(config_dict["model"], arch=arch)
|
|
|
|
|
|
|
+ buffers = generate_buffers(config_dict, arch=arch)
|
|
|
freq_indices = buffers["freq_indices"]
|
|
freq_indices = buffers["freq_indices"]
|
|
|
num_bands_per_freq = buffers["num_bands_per_freq"]
|
|
num_bands_per_freq = buffers["num_bands_per_freq"]
|
|
|
num_freqs_per_band = buffers["num_freqs_per_band"]
|
|
num_freqs_per_band = buffers["num_freqs_per_band"]
|
|
@@ -439,7 +445,7 @@ def convert(
|
|
|
print("Writing metadata...")
|
|
print("Writing metadata...")
|
|
|
|
|
|
|
|
# General metadata
|
|
# General metadata
|
|
|
- model_name = name if name else "Mel-Band-Roformer Separator"
|
|
|
|
|
|
|
+ model_name = name if name else "BSRoformer Separator"
|
|
|
model_description = description if description else "Music source separation model"
|
|
model_description = description if description else "Music source separation model"
|
|
|
gguf_writer.add_name(model_name)
|
|
gguf_writer.add_name(model_name)
|
|
|
gguf_writer.add_description(model_description)
|
|
gguf_writer.add_description(model_description)
|
|
@@ -457,6 +463,11 @@ def convert(
|
|
|
freqs_tuple = buffers["freqs_per_bands_tuple"]
|
|
freqs_tuple = buffers["freqs_per_bands_tuple"]
|
|
|
# Must be list for GGUFWriter
|
|
# Must be list for GGUFWriter
|
|
|
gguf_writer.add_array(f"{arch_name}.freqs_per_bands", list(freqs_tuple))
|
|
gguf_writer.add_array(f"{arch_name}.freqs_per_bands", list(freqs_tuple))
|
|
|
|
|
+
|
|
|
|
|
+ if arch_name == "bs_roformer_v2":
|
|
|
|
|
+ gguf_writer.add_array(f"{arch_name}.freqs_per_bands", list(config_dict["freqs_per_bands"]))
|
|
|
|
|
+ gguf_writer.add_array(f"{arch_name}.freqs_per_bands_out", list(config_dict["freqs_per_bands_out"]))
|
|
|
|
|
+
|
|
|
|
|
|
|
|
# Quantization version (required when quantized)
|
|
# Quantization version (required when quantized)
|
|
|
if target_qtype != GGMLQuantizationType.F32:
|
|
if target_qtype != GGMLQuantizationType.F32:
|
|
@@ -476,18 +487,12 @@ def convert(
|
|
|
# 2. Write Hyperparameters
|
|
# 2. Write Hyperparameters
|
|
|
# =========================================================================
|
|
# =========================================================================
|
|
|
print("Writing hyperparameters...")
|
|
print("Writing hyperparameters...")
|
|
|
- hparams = config_dict["model"]
|
|
|
|
|
|
|
+ hparams = config_dict
|
|
|
|
|
|
|
|
# Load state dict directly (no model class dependency)
|
|
# Load state dict directly (no model class dependency)
|
|
|
print(f"Loading checkpoint for architecture: {arch}")
|
|
print(f"Loading checkpoint for architecture: {arch}")
|
|
|
|
|
|
|
|
- raw_state_dict = None
|
|
|
|
|
- if "state_dict" in checkpoint:
|
|
|
|
|
- raw_state_dict = checkpoint["state_dict"]
|
|
|
|
|
- elif "model" in checkpoint:
|
|
|
|
|
- raw_state_dict = checkpoint["model"]
|
|
|
|
|
- else:
|
|
|
|
|
- raw_state_dict = checkpoint
|
|
|
|
|
|
|
+ raw_state_dict = state_dict
|
|
|
|
|
|
|
|
if raw_state_dict is None:
|
|
if raw_state_dict is None:
|
|
|
raise ValueError("Could not find state_dict in checkpoint")
|
|
raise ValueError("Could not find state_dict in checkpoint")
|
|
@@ -500,10 +505,10 @@ def convert(
|
|
|
state_dict[k] = v
|
|
state_dict[k] = v
|
|
|
|
|
|
|
|
# Architecture specific parameters
|
|
# Architecture specific parameters
|
|
|
- gguf_writer.add_uint32(f"{arch_name}.dim", hparams["dim"])
|
|
|
|
|
- gguf_writer.add_uint32(f"{arch_name}.depth", hparams["depth"])
|
|
|
|
|
|
|
+ gguf_writer.add_uint32(f"{arch_name}.dim", hparams["hidden_size"])
|
|
|
|
|
+ gguf_writer.add_uint32(f"{arch_name}.depth", hparams["num_hidden_layers"])
|
|
|
# BS uses freqs_per_bands (no explicit num_bands), MelBand uses num_bands
|
|
# BS uses freqs_per_bands (no explicit num_bands), MelBand uses num_bands
|
|
|
- num_bands = buffers.get("num_bands", hparams.get("num_bands", 60))
|
|
|
|
|
|
|
+ num_bands = buffers.get("num_bands", len(hparams.get("freqs_per_bands", [])))
|
|
|
gguf_writer.add_uint32(f"{arch_name}.num_bands", num_bands)
|
|
gguf_writer.add_uint32(f"{arch_name}.num_bands", num_bands)
|
|
|
|
|
|
|
|
# STFT parameters
|
|
# STFT parameters
|
|
@@ -519,24 +524,50 @@ def convert(
|
|
|
f"{arch_name}.stft_normalized", hparams.get("stft_normalized", False)
|
|
f"{arch_name}.stft_normalized", hparams.get("stft_normalized", False)
|
|
|
)
|
|
)
|
|
|
gguf_writer.add_bool(
|
|
gguf_writer.add_bool(
|
|
|
- f"{arch_name}.zero_dc", hparams.get("zero_dc", True)
|
|
|
|
|
- ) # Defaults to True in reference implementation
|
|
|
|
|
|
|
+ f"{arch_name}.zero_dc", hparams.get("zero_dc", True) # Defaults to True in reference implementation
|
|
|
|
|
+ )
|
|
|
|
|
|
|
|
# Architecture details
|
|
# Architecture details
|
|
|
gguf_writer.add_uint32(f"{arch_name}.num_stems", hparams.get("num_stems", 1))
|
|
gguf_writer.add_uint32(f"{arch_name}.num_stems", hparams.get("num_stems", 1))
|
|
|
gguf_writer.add_bool(f"{arch_name}.stereo", hparams.get("stereo", False))
|
|
gguf_writer.add_bool(f"{arch_name}.stereo", hparams.get("stereo", False))
|
|
|
gguf_writer.add_uint32(
|
|
gguf_writer.add_uint32(
|
|
|
- f"{arch_name}.sample_rate", hparams.get("sample_rate", 44100)
|
|
|
|
|
|
|
+ f"{arch_name}.sample_rate", hparams.get("wave_sample_rate", 44100)
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
- gguf_writer.add_uint32(
|
|
|
|
|
- f"{arch_name}.time_transformer_depth",
|
|
|
|
|
- hparams.get("time_transformer_depth", 0),
|
|
|
|
|
- )
|
|
|
|
|
- gguf_writer.add_uint32(
|
|
|
|
|
- f"{arch_name}.freq_transformer_depth",
|
|
|
|
|
- hparams.get("freq_transformer_depth", 0),
|
|
|
|
|
- )
|
|
|
|
|
|
|
+ if arch_name == "bs_roformer_v2":
|
|
|
|
|
+ gguf_writer.add_uint32(
|
|
|
|
|
+ f"{arch_name}.time_transformer_depth",
|
|
|
|
|
+ hparams.get("time_transformer_depth", 1),
|
|
|
|
|
+ )
|
|
|
|
|
+ gguf_writer.add_uint32(
|
|
|
|
|
+ f"{arch_name}.freq_transformer_depth",
|
|
|
|
|
+ hparams.get("freq_transformer_depth", 1),
|
|
|
|
|
+ )
|
|
|
|
|
+ gguf_writer.add_uint32(
|
|
|
|
|
+ f"{arch_name}.num_key_value_heads", hparams.get("num_key_value_heads", 4)
|
|
|
|
|
+ )
|
|
|
|
|
+ gguf_writer.add_uint32(
|
|
|
|
|
+ f"{arch_name}.intermediate_size", hparams.get("intermediate_size", 1152)
|
|
|
|
|
+ )
|
|
|
|
|
+ gguf_writer.add_uint32(
|
|
|
|
|
+ f"{arch_name}.num_input_channels", hparams.get("num_input_channels", 2)
|
|
|
|
|
+ )
|
|
|
|
|
+ gguf_writer.add_uint32(
|
|
|
|
|
+ f"{arch_name}.band_proj_size", hparams.get("band_proj_size", 256)
|
|
|
|
|
+ )
|
|
|
|
|
+ gguf_writer.add_uint32(
|
|
|
|
|
+ f"{arch_name}.register_token_num", hparams.get("register_token_num", 4)
|
|
|
|
|
+ )
|
|
|
|
|
+ else:
|
|
|
|
|
+ gguf_writer.add_uint32(
|
|
|
|
|
+ f"{arch_name}.time_transformer_depth",
|
|
|
|
|
+ hparams.get("time_transformer_depth", 0),
|
|
|
|
|
+ )
|
|
|
|
|
+ gguf_writer.add_uint32(
|
|
|
|
|
+ f"{arch_name}.freq_transformer_depth",
|
|
|
|
|
+ hparams.get("freq_transformer_depth", 0),
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
gguf_writer.add_uint32(
|
|
gguf_writer.add_uint32(
|
|
|
f"{arch_name}.linear_transformer_depth",
|
|
f"{arch_name}.linear_transformer_depth",
|
|
|
hparams.get("linear_transformer_depth", 0),
|
|
hparams.get("linear_transformer_depth", 0),
|
|
@@ -545,8 +576,8 @@ def convert(
|
|
|
gguf_writer.add_uint32(
|
|
gguf_writer.add_uint32(
|
|
|
f"{arch_name}.mask_estimator_depth", hparams.get("mask_estimator_depth", 1)
|
|
f"{arch_name}.mask_estimator_depth", hparams.get("mask_estimator_depth", 1)
|
|
|
)
|
|
)
|
|
|
- gguf_writer.add_uint32(f"{arch_name}.dim_head", hparams.get("dim_head", 64))
|
|
|
|
|
- gguf_writer.add_uint32(f"{arch_name}.heads", hparams.get("heads", 8))
|
|
|
|
|
|
|
+ gguf_writer.add_uint32(f"{arch_name}.dim_head", hparams.get("head_dim", 64))
|
|
|
|
|
+ gguf_writer.add_uint32(f"{arch_name}.heads", hparams.get("num_attention_heads", 8))
|
|
|
gguf_writer.add_uint32(
|
|
gguf_writer.add_uint32(
|
|
|
f"{arch_name}.mlp_expansion_factor", hparams.get("mlp_expansion_factor", 4)
|
|
f"{arch_name}.mlp_expansion_factor", hparams.get("mlp_expansion_factor", 4)
|
|
|
)
|
|
)
|
|
@@ -563,11 +594,11 @@ def convert(
|
|
|
audio_config = config_dict.get("audio", {})
|
|
audio_config = config_dict.get("audio", {})
|
|
|
|
|
|
|
|
# chunk_size: prefer inference.chunk_size, fallback to audio.chunk_size
|
|
# chunk_size: prefer inference.chunk_size, fallback to audio.chunk_size
|
|
|
- default_chunk_size = inference_config.get(
|
|
|
|
|
- "chunk_size", audio_config.get("chunk_size", 352800)
|
|
|
|
|
|
|
+ default_chunk_size = hparams.get(
|
|
|
|
|
+ "wave_chunk_size", 352800
|
|
|
)
|
|
)
|
|
|
# num_overlap: from inference section
|
|
# num_overlap: from inference section
|
|
|
- default_num_overlap = inference_config.get("num_overlap", 0)
|
|
|
|
|
|
|
+ default_num_overlap = inference_config.get("num_overlap", 2)
|
|
|
|
|
|
|
|
gguf_writer.add_uint32(f"{arch_name}.default_chunk_size", default_chunk_size)
|
|
gguf_writer.add_uint32(f"{arch_name}.default_chunk_size", default_chunk_size)
|
|
|
gguf_writer.add_uint32(f"{arch_name}.default_num_overlap", default_num_overlap)
|
|
gguf_writer.add_uint32(f"{arch_name}.default_num_overlap", default_num_overlap)
|
|
@@ -679,7 +710,7 @@ Examples:
|
|
|
parser.add_argument(
|
|
parser.add_argument(
|
|
|
"--ckpt", type=str, required=True, help="Path to PyTorch checkpoint"
|
|
"--ckpt", type=str, required=True, help="Path to PyTorch checkpoint"
|
|
|
)
|
|
)
|
|
|
- parser.add_argument("--config", type=str, required=True, help="Path to YAML config")
|
|
|
|
|
|
|
+ parser.add_argument("--config", type=str, required=True, help="Path to YAML or JSON config")
|
|
|
parser.add_argument("--out", type=str, required=True, help="Output GGUF file path")
|
|
parser.add_argument("--out", type=str, required=True, help="Output GGUF file path")
|
|
|
parser.add_argument(
|
|
parser.add_argument(
|
|
|
"--dtype",
|
|
"--dtype",
|
|
@@ -702,7 +733,7 @@ Examples:
|
|
|
"--name",
|
|
"--name",
|
|
|
type=str,
|
|
type=str,
|
|
|
default=None,
|
|
default=None,
|
|
|
- help="Model name (default: 'Mel-Band-Roformer Vocal Separator')",
|
|
|
|
|
|
|
+ help="Model name (default: 'BSRoformer Vocal Separator')",
|
|
|
)
|
|
)
|
|
|
parser.add_argument(
|
|
parser.add_argument(
|
|
|
"--description",
|
|
"--description",
|
|
@@ -712,7 +743,7 @@ Examples:
|
|
|
)
|
|
)
|
|
|
parser.add_argument(
|
|
parser.add_argument(
|
|
|
"--arch",
|
|
"--arch",
|
|
|
- choices=["mel_band", "mel_band_roformer", "bs", "bs_roformer"],
|
|
|
|
|
|
|
+ choices=["mel_band", "mel_band_roformer", "bs", "bs_roformer", "bs_roformer_v2"],
|
|
|
default=None,
|
|
default=None,
|
|
|
help="Architecture type (auto-detected if not specified)",
|
|
help="Architecture type (auto-detected if not specified)",
|
|
|
)
|
|
)
|