lainlives před 1 měsícem
rodič
revize
63b9b68da9
12 změnil soubory, kde provedl 607 přidání a 190 odebrání
  1. 6 1
      .gitignore
  2. 6 0
      .gitmodules
  3. 26 28
      CMakeLists.txt
  4. 7 27
      cli/main.cpp
  5. 1 0
      ggml
  6. 10 7
      include/bs_roformer/audio.h
  7. 39 36
      include/bs_roformer/inference.h
  8. 1 0
      libav
  9. 91 60
      scripts/convert_to_gguf.py
  10. 214 20
      src/audio.cpp
  11. 184 11
      src/model.cpp
  12. 22 0
      src/model.h

+ 6 - 1
.gitignore

@@ -1,2 +1,7 @@
 build*
-tests/data
+tests/data
+testdata/*
+rel
+rel/*
+ggml/*
+libav/*

+ 6 - 0
.gitmodules

@@ -0,0 +1,6 @@
+[submodule "libav"]
+	path = libav
+	url = https://github.com/libav/libav
+[submodule "ggml"]
+	path = ggml
+	url = https://github.com/ggml-org/ggml/

+ 26 - 28
CMakeLists.txt

@@ -8,29 +8,30 @@ set(CMAKE_CXX_STANDARD_REQUIRED ON)
 # Build Options
 #================================================
 
-option(GGML_CUDA "Enable CUDA backend" ON)
-option(BSR_BUILD_TESTS "Build tests" OFF)
-option(BSR_BUILD_CLI "Build CLI application" ON)
+option(BSR_BUILD_CLI "Build CLI application" ON)  # signature is option(<var> "<help>" <default>); omitting the help string makes the default OFF
+option(BSR_BUILD_TESTS "Build tests" OFF)
 
+option(GGML_CUDA "Enable CUDA backend" OFF) # cuda
+option(GGML_CUDA_FA "Build the CUDA FlashAttention kernels" ON)  # leave it on
+option(GGML_CUDA_NO_VMM "Do not use CUDA virtual memory management" ON)  # not needed
+
+option(GGML_HIP "Enable HIP backend" OFF)   # rocm
+option(GGML_HIP_RCCL "Enable RCCL with the HIP backend" OFF) # not needed
+option(GGML_HIP_GRAPH "Enable HIP graphs" OFF) # definitely not needed
+option(GGML_HIP_ROCWMMA_FATTN "Use rocWMMA for FlashAttention on HIP" ON) # turn it off if you have rocm flash attention issues
+option(GGML_HIP_NO_VMM "Do not use HIP virtual memory management" ON) # Breaks support in most cards right now leave it on
+
+# NOTE(review): upstream ggml appears to spell this option GGML_CUDA_FA_ALL_QUANTS (plural);
+# as written it may have no effect - confirm against the pinned ggml submodule before renaming.
+option(GGML_CUDA_FA_ALL_QUANT "Build FlashAttention kernels for all quantization types" ON) # leave it on
 #================================================
 # Dependencies - GGML (Flexible Resolution)
 #================================================
 
-# Strategy: Allow ggml to be shared across multiple projects
-# 1. Check if ggml target already exists (e.g., from parent project like whisper.cpp)
-# 2. If not, try to find ggml via CMAKE_PREFIX_PATH or GGML_DIR
-# 3. If not found, use local ggml (submodule or sibling directory)
-
 if(NOT TARGET ggml)
-    # Try to find ggml package first (for system-wide or parent project installation)
     find_package(ggml QUIET CONFIG)
     
     if(NOT ggml_FOUND)
-        # ggml not found as package, look for source directory
-        # Priority 1: GGML_DIR variable (explicitly set by user or parent project)
-        # Priority 2: Submodule in ggml/
-        # Priority 3: Sibling directory ../ggml
-        
         if(DEFINED GGML_DIR)
             set(GGML_PATH "${GGML_DIR}")
             message(STATUS "Using GGML from GGML_DIR: ${GGML_PATH}")
@@ -50,7 +51,6 @@ if(NOT TARGET ggml)
             )
         endif()
         
-        # Add ggml as subdirectory
         add_subdirectory(${GGML_PATH} ggml EXCLUDE_FROM_ALL)
     else()
         message(STATUS "Using GGML from installed package")
@@ -81,8 +81,6 @@ target_include_directories(bs_roformer PRIVATE
 
 target_link_libraries(bs_roformer PUBLIC ggml)
 if(GGML_CUDA AND TARGET ggml-cuda)
-    # Fix for CI: Link against CUDA stubs if the driver is not present
-    # This prevents errors like "libcuda.so.1 needed by ... not found" during linking
     find_package(CUDAToolkit REQUIRED)
     if(TARGET CUDA::cuda_driver)
         target_link_libraries(bs_roformer PUBLIC CUDA::cuda_driver)
@@ -90,12 +88,7 @@ if(GGML_CUDA AND TARGET ggml-cuda)
     endif()
 endif()
 
-# Compiler options
-if(MSVC)
-    target_compile_options(bs_roformer PRIVATE /W3 /utf-8)
-else()
-    target_compile_options(bs_roformer PRIVATE -Wall -Wextra)
-endif()
+target_compile_options(bs_roformer PRIVATE "$<$<BOOL:${MSVC}>:/W3;/utf-8>" "$<$<NOT:$<BOOL:${MSVC}>>:-Wall;-Wextra>")  # keep the MSVC flag set; GCC-style -W flags only elsewhere
 
 # OpenMP support
 find_package(OpenMP)
@@ -157,7 +150,6 @@ endfunction()
 #================================================
 
 if(BSR_BUILD_CLI)
-    # audio.cpp implements AudioFile utilities (using dr_wav)
     add_executable(bs_roformer-cli 
         cli/main.cpp 
         src/audio.cpp
@@ -166,11 +158,17 @@ if(BSR_BUILD_CLI)
     target_include_directories(bs_roformer-cli PRIVATE 
         src 
         third_party
+        "${CMAKE_CURRENT_SOURCE_DIR}/libav"  # resolve against this project, not the top-level build (subproject-safe)
     )
     
-    if(MSVC)
-        target_compile_options(bs_roformer-cli PRIVATE /W3 /utf-8)
-    endif()
+    # Link the static libav archives expected inside the libav submodule tree (must already be built), plus zlib and bzip2 from the system
+    target_link_libraries(bs_roformer-cli PRIVATE
+        "${CMAKE_CURRENT_SOURCE_DIR}/libav/libavformat/libavformat.a"
+        "${CMAKE_CURRENT_SOURCE_DIR}/libav/libavcodec/libavcodec.a"
+        "${CMAKE_CURRENT_SOURCE_DIR}/libav/libavresample/libavresample.a"
+        "${CMAKE_CURRENT_SOURCE_DIR}/libav/libavutil/libavutil.a"
+        z bz2
+    )
     
     bsr_copy_ggml_runtime_dlls(bs_roformer-cli)
 endif()
@@ -185,4 +183,4 @@ if(BSR_BUILD_TESTS)
     message(STATUS "Tests: ENABLED")
 else()
     message(STATUS "Tests: DISABLED (use -DBSR_BUILD_TESTS=ON to enable)")
-endif()
+endif()

+ 7 - 27
cli/main.cpp

@@ -6,7 +6,10 @@
 #include <cstdlib>
 
 void print_usage(const char* program_name) {
-    std::cerr << "Usage: " << program_name << " <model.gguf> <input.wav> <output.wav> [options]" << std::endl;
+    std::cerr << "Usage: " << program_name << " <model.gguf> <input_audio> <output.wav> [options]" << std::endl;
+    std::cerr << std::endl;
+    std::cerr << "Input audio can be any common format (WAV, MP3, FLAC, OGG, etc.)" << std::endl;
+    std::cerr << "Audio is automatically resampled to 44100 Hz if needed." << std::endl;
     std::cerr << std::endl;
     std::cerr << "Options:" << std::endl;
     std::cerr << "  --chunk-size <N>   Chunk size in samples (default: from model, fallback 352800)" << std::endl;
@@ -94,31 +97,8 @@ int main(int argc, char* argv[]) {
                   << input_audio.channels << " channels, " 
                   << input_audio.sampleRate << " Hz" << std::endl;
 
-        // 1. Check Sample Rate
-        int required_sr = engine.GetSampleRate();
-        std::cout << "Model expects sample rate: " << required_sr << " Hz" << std::endl;
-
-        if (input_audio.sampleRate != required_sr) {
-            throw std::runtime_error("Input audio sample rate must be " + std::to_string(required_sr) + 
-                                     " Hz. Current: " + std::to_string(input_audio.sampleRate));
-        }
-
-        // 2. Check Channels & Auto-Expand Mono
-        if (input_audio.channels == 1) {
-             std::cout << "[Info] Input is Mono. Expanding to Stereo..." << std::endl;
-             std::vector<float> stereo_data(input_audio.samples * 2);
-             for(size_t i=0; i<input_audio.samples; ++i) {
-                 stereo_data[i*2 + 0] = input_audio.data[i];
-                 stereo_data[i*2 + 1] = input_audio.data[i];
-             }
-             input_audio.data = std::move(stereo_data);
-             input_audio.channels = 2;
-             input_audio.samples *= 2;
-        } else if (input_audio.channels != 2) {
-             // We can either reject or try to process first 2 channels? 
-             // Ideally reject to be safer, or warn.
-             throw std::runtime_error("Input audio must be Stereo (2 channels) or Mono (1 channel). Current: " + std::to_string(input_audio.channels));
-        }
+        // AudioFile::Load automatically resamples to 44100 Hz and converts to stereo
+        // No need for manual sample rate check or mono expansion
 
         std::cout << "Processing with chunk_size=" << chunk_size 
                   << ", overlap=" << num_overlap << std::endl;
@@ -167,7 +147,7 @@ int main(int argc, char* argv[]) {
             AudioBuffer output_audio_buf;
             output_audio_buf.data = std::move(output_stems[i]); // Move to avoid copy
             output_audio_buf.channels = 2; // Output is always stereo
-            output_audio_buf.sampleRate = required_sr;
+            output_audio_buf.sampleRate = 44100;
             output_audio_buf.samples = output_audio_buf.data.size();
             
             std::cout << "Saving output stem " << i << ": " << current_output_path << std::endl;

+ 1 - 0
ggml

@@ -0,0 +1 @@
+Subproject commit 57ea0bc119d722d74594196cc5b494a34dd87be4

+ 10 - 7
include/bs_roformer/audio.h

@@ -16,23 +16,26 @@ struct AudioBuffer {
 
 /**
  * Audio file I/O utilities.
- * Supports WAV format (via dr_wav).
+ * Supports any common audio format (WAV, MP3, FLAC, OGG, etc.) via FFmpeg/libav.
+ * Automatically resamples to target sample rate if needed.
  */
 class AudioFile {
 public:
     /**
-     * Load audio from a WAV file.
-     * @param path Path to the WAV file
+     * Load audio from any common audio file format.
+     * Audio is automatically resampled to target_sample_rate (44100 Hz by default) and converted to stereo float32.
+     * @param path Path to the audio file
+     * @param target_sample_rate Target sample rate (default: 44100 Hz)
      * @return AudioBuffer containing the loaded audio data
-     * @throws std::runtime_error if the file cannot be opened
+     * @throws std::runtime_error if the file cannot be opened or decoded
      */
-    static AudioBuffer Load(const std::string& path);
+    static AudioBuffer Load(const std::string& path, int target_sample_rate = 44100);
     
     /**
-     * Save audio to a WAV file.
+     * Save audio to a WAV file (PCM float32).
      * @param path Path to save the WAV file
      * @param buffer AudioBuffer containing audio data to save
      * @throws std::runtime_error if the file cannot be written
      */
     static void Save(const std::string& path, const AudioBuffer& buffer);
-};
+};

+ 39 - 36
include/bs_roformer/inference.h

@@ -1,14 +1,17 @@
 #pragma once
-
-#include <vector>
-#include <string>
-#include <memory>
+#include <cstdint>
 #include <functional>
+#include <memory>
+#include <string>
+#include <vector>
 // Forward declaration
 class BSRoformer;
 
 // Forward declaration
-namespace ggml { struct context; struct cgraph; }
+namespace ggml {
+struct context;
+struct cgraph;
+}
 
 class Inference {
 public:
@@ -21,11 +24,11 @@ public:
     // Uses overlap-add chunking to handle long files
     // Process a full audio track (interleaved stereo float32)
     // Returns a vector of stems, where each stem is an interleaved stereo float vector
-    std::vector<std::vector<float>> Process(const std::vector<float>& input_audio, 
-                               int chunk_size = 352800, 
-                               int num_overlap = 2,
-                               std::function<void(float)> progress_callback = nullptr,
-                               CancelCallback cancel_callback = nullptr);
+    std::vector<std::vector<float>> Process(const std::vector<float>& input_audio,
+        int chunk_size = 352800,
+        int num_overlap = 2,
+        std::function<void(float)> progress_callback = nullptr,
+        CancelCallback cancel_callback = nullptr);
 
     // Low-level chunk processing (public for testing)
     std::vector<std::vector<float>> ProcessChunk(const std::vector<float>& chunk_audio);
@@ -39,29 +42,29 @@ public:
     // Static helper for Overlap-Add logic (matches Python exactly)
     // model_func: input [samples], output [stems][samples] (interleaved stereo)
     using ModelCallback = std::function<std::vector<std::vector<float>>(const std::vector<float>&)>;
-    static std::vector<std::vector<float>> ProcessOverlapAdd(const std::vector<float>& input_audio, 
-                                                int chunk_size, 
-                                                int num_overlap,
-                                                ModelCallback model_func,
-                                                std::function<void(float)> progress_callback = nullptr,
-                                                CancelCallback cancel_callback = nullptr);
+    static std::vector<std::vector<float>> ProcessOverlapAdd(const std::vector<float>& input_audio,
+        int chunk_size,
+        int num_overlap,
+        ModelCallback model_func,
+        std::function<void(float)> progress_callback = nullptr,
+        CancelCallback cancel_callback = nullptr);
 
 private:
     // Pipelined Overlap-Add
-    std::vector<std::vector<float>> ProcessOverlapAddPipelined(const std::vector<float>& input_audio, 
-                                                  int chunk_size, 
-                                                  int num_overlap,
-                                                  std::function<void(float)> progress_callback,
-                                                  CancelCallback cancel_callback);
+    std::vector<std::vector<float>> ProcessOverlapAddPipelined(const std::vector<float>& input_audio,
+        int chunk_size,
+        int num_overlap,
+        std::function<void(float)> progress_callback,
+        CancelCallback cancel_callback);
 
 private:
     std::unique_ptr<BSRoformer> model_;
-    
+
     // Persistent Graph State
     struct ggml_context* ctx_ = nullptr;
     struct ggml_cgraph* gf_ = nullptr;
     struct ggml_gallocr* allocr_ = nullptr;
-    
+
     // Cached Input Tensors (owned by ctx_)
     struct ggml_tensor* input_tensor_ = nullptr;
     struct ggml_tensor* pos_time_ = nullptr;
@@ -78,30 +81,30 @@ private:
     // Pipelined State Data
     struct ChunkState {
         int id = -1;
-        std::vector<float> input_audio;       // Original chunk audio
-        std::vector<float> stft_flattened;    // [Prepared Input for GPU]
+        std::vector<float> input_audio; // Original chunk audio
+        std::vector<float> stft_flattened; // [Prepared Input for GPU]
         std::vector<std::vector<float>> stft_outputs; // Kept for reconstruction
         int n_frames = 0;
-        
-        std::vector<float> mask_output;       // Output from GPU
-        std::vector<std::vector<float>> final_audio;       // Result after ISTFT [stems][samples]
+
+        std::vector<float> mask_output; // Output from GPU
+        std::vector<std::vector<float>> final_audio; // Result after ISTFT [stems][samples]
     };
 
     // Helper to ensure graph is built for specific n_frames
     bool EnsureGraph(int n_frames);
 
     void ComputeSTFT(const std::vector<float>& input_audio,
-                     std::vector<std::vector<float>>& stft_outputs,
-                     int& n_frames);
-                     
+        std::vector<std::vector<float>>& stft_outputs,
+        int& n_frames);
+
     void PrepareModelInput(const std::vector<std::vector<float>>& stft_outputs,
-                           int n_frames,
-                           std::vector<float>& model_input_rearranged);
+        int n_frames,
+        std::vector<float>& model_input_rearranged);
 
     void PostProcessAndISTFT(const std::vector<float>& mask_output,
-                             const std::vector<std::vector<float>>& stft_outputs,
-                             int n_frames,
-                             std::vector<std::vector<float>>& output_audio);
+        const std::vector<std::vector<float>>& stft_outputs,
+        int n_frames,
+        std::vector<std::vector<float>>& output_audio);
 
     // Pipeline Steps
     std::shared_ptr<ChunkState> PreProcessChunk(const std::vector<float>& chunk_audio, int id);

+ 1 - 0
libav

@@ -0,0 +1 @@
+Subproject commit c4642788e83b0858bca449f9b6e71ddb015dfa5d

+ 91 - 60
scripts/convert_to_gguf.py

@@ -6,37 +6,39 @@ Supports quantization: FP32, FP16, Q8_0, Q4_0, Q4_1, Q5_0, Q5_1
 Mixed Quantization: Keeps Norms/Biases as FP32 to avoid CUDA alignment issues.
 """
 
-import os
 import argparse
-import torch
+import os
+
+import gguf
+import librosa
 import numpy as np
+import torch
 import yaml
-import librosa
-from einops import repeat, reduce, rearrange
-import gguf
-from gguf.quants import quantize, GGMLQuantizationType
+from einops import rearrange, reduce, repeat
+from gguf.quants import GGMLQuantizationType, quantize
+from safetensors.torch import load_file as load_safetensors
 
 
 def detect_architecture(config_dict):
     """
     Detect architecture from config.
-    Returns: 'bs_roformer' or 'mel_band_roformer'
+    Returns: 'bs_roformer', 'bs_roformer_v2', or 'mel_band_roformer'
     """
-
-    # Check structural signatures in 'model' section
-    model_config = config_dict.get("model", {})
+    model_config = config_dict.get("model", config_dict)
 
     has_freqs = "freqs_per_bands" in model_config
+    has_freqs_out = "freqs_per_bands_out" in model_config
     has_num_bands = "num_bands" in model_config
 
+    if has_freqs and has_freqs_out:
+        return "bs_roformer_v2"
     if has_freqs:
         return "bs_roformer"
     if has_num_bands:
         return "mel_band_roformer"
-
-    # 3. If neither found, fail
+    
     raise ValueError(
-        "Auto-detection failed: Config missing 'freqs_per_bands' (BS) or 'num_bands' (Mel-Band). "
+        "Auto-detection failed: Config missing 'freqs_per_bands'/'freqs_per_bands_out' (BS_V2), 'freqs_per_bands' (BS), or 'num_bands' (Mel-Band). "
         "Please specify --arch manually."
     )
 
@@ -46,6 +48,7 @@ def normalize_arch(arch: str) -> str:
     mapping = {
         "bs": "bs_roformer",
         "bs_roformer": "bs_roformer",
+        "bs_roformer_v2": "bs_roformer_v2",
         "mel": "mel_band_roformer",
         "mel_band": "mel_band_roformer",
         "mel_band_roformer": "mel_band_roformer",
@@ -178,7 +181,7 @@ def generate_buffers(hparams, arch="mel_band_roformer"):
         hparams: Model hyperparameters
         arch: Architecture name ('bs_roformer' or 'mel_band_roformer')
     """
-    if arch == "bs_roformer":
+    if arch == "bs_roformer" or arch == "bs_roformer_v2":
         return generate_buffers_bs(hparams)
 
     # Mel-Band-Roformer Logic
@@ -234,9 +237,9 @@ def generate_buffers(hparams, arch="mel_band_roformer"):
     }
 
 
-# ============================================================================
+# ============================================================================ 
 # Quantization Helper
-# ============================================================================
+# ============================================================================ 
 
 
 def get_target_quantization_type(dtype_str: str) -> GGMLQuantizationType:
@@ -297,9 +300,9 @@ def should_quantize(name: str) -> bool:
     return False
 
 
-# ============================================================================
+# ============================================================================ 
 # Key Name Mapping
-# ============================================================================
+# ============================================================================ 
 
 
 def map_key_name(key: str) -> str:
@@ -378,9 +381,9 @@ def map_key_name(key: str) -> str:
     return key.replace(".", "_")
 
 
-# ============================================================================
+# ============================================================================ 
 # Main Conversion
-# ============================================================================
+# ============================================================================ 
 
 
 def convert(
@@ -396,18 +399,21 @@ def convert(
     Convert PyTorch checkpoint to GGUF format.
     """
     print(f"Loading checkpoint: {ckpt_path}")
-    checkpoint = torch.load(ckpt_path, map_location="cpu")
-
-    if "state_dict" in checkpoint:
-        state_dict = checkpoint["state_dict"]
-    elif "model" in checkpoint:
-        state_dict = checkpoint["model"]
+    if ckpt_path.endswith(".safetensors"):
+        state_dict = load_safetensors(ckpt_path)
     else:
-        state_dict = checkpoint
+        checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)  # SECURITY: weights_only=False allows full unpickling; only convert checkpoints from trusted sources
+
+        if "state_dict" in checkpoint:  # unwrap a container that stores weights under "state_dict"
+            state_dict = checkpoint["state_dict"]
+        elif "model" in checkpoint:  # ... or under "model"
+            state_dict = checkpoint["model"]
+        else:
+            state_dict = checkpoint  # otherwise treat the loaded object itself as the state dict
 
     print(f"Loading config: {config_path}")
     with open(config_path) as f:
-        config_dict = yaml.load(f, Loader=yaml.FullLoader)
+        config_dict = yaml.safe_load(f)
 
     # Detect architecture
     if arch is None:
@@ -423,7 +429,7 @@ def convert(
 
     # Generate buffers
     print("Generating buffers (standalone)...")
-    buffers = generate_buffers(config_dict["model"], arch=arch)
+    buffers = generate_buffers(config_dict, arch=arch)
     freq_indices = buffers["freq_indices"]
     num_bands_per_freq = buffers["num_bands_per_freq"]
     num_freqs_per_band = buffers["num_freqs_per_band"]
@@ -439,7 +445,7 @@ def convert(
     print("Writing metadata...")
 
     # General metadata
-    model_name = name if name else "Mel-Band-Roformer Separator"
+    model_name = name if name else "BSRoformer Separator"
     model_description = description if description else "Music source separation model"
     gguf_writer.add_name(model_name)
     gguf_writer.add_description(model_description)
@@ -457,6 +463,11 @@ def convert(
         freqs_tuple = buffers["freqs_per_bands_tuple"]
         # Must be list for GGUFWriter
         gguf_writer.add_array(f"{arch_name}.freqs_per_bands", list(freqs_tuple))
+    
+    if arch_name == "bs_roformer_v2":
+        gguf_writer.add_array(f"{arch_name}.freqs_per_bands", list(config_dict["freqs_per_bands"]))
+        gguf_writer.add_array(f"{arch_name}.freqs_per_bands_out", list(config_dict["freqs_per_bands_out"]))
+
 
     # Quantization version (required when quantized)
     if target_qtype != GGMLQuantizationType.F32:
@@ -476,18 +487,12 @@ def convert(
     # 2. Write Hyperparameters
     # =========================================================================
     print("Writing hyperparameters...")
-    hparams = config_dict["model"]
+    hparams = config_dict
 
     # Load state dict directly (no model class dependency)
     print(f"Loading checkpoint for architecture: {arch}")
 
-    raw_state_dict = None
-    if "state_dict" in checkpoint:
-        raw_state_dict = checkpoint["state_dict"]
-    elif "model" in checkpoint:
-        raw_state_dict = checkpoint["model"]
-    else:
-        raw_state_dict = checkpoint
+    raw_state_dict = state_dict
 
     if raw_state_dict is None:
         raise ValueError("Could not find state_dict in checkpoint")
@@ -500,10 +505,10 @@ def convert(
         state_dict[k] = v
 
     # Architecture specific parameters
-    gguf_writer.add_uint32(f"{arch_name}.dim", hparams["dim"])
-    gguf_writer.add_uint32(f"{arch_name}.depth", hparams["depth"])
+    gguf_writer.add_uint32(f"{arch_name}.dim", hparams["hidden_size"])
+    gguf_writer.add_uint32(f"{arch_name}.depth", hparams["num_hidden_layers"])
     # BS uses freqs_per_bands (no explicit num_bands), MelBand uses num_bands
-    num_bands = buffers.get("num_bands", hparams.get("num_bands", 60))
+    num_bands = buffers.get("num_bands", len(hparams.get("freqs_per_bands", [])))
     gguf_writer.add_uint32(f"{arch_name}.num_bands", num_bands)
 
     # STFT parameters
@@ -519,24 +524,50 @@ def convert(
         f"{arch_name}.stft_normalized", hparams.get("stft_normalized", False)
     )
     gguf_writer.add_bool(
-        f"{arch_name}.zero_dc", hparams.get("zero_dc", True)
-    )  # Defaults to True in reference implementation
+        f"{arch_name}.zero_dc", hparams.get("zero_dc", True) # Defaults to True in reference implementation
+    )
 
     # Architecture details
     gguf_writer.add_uint32(f"{arch_name}.num_stems", hparams.get("num_stems", 1))
     gguf_writer.add_bool(f"{arch_name}.stereo", hparams.get("stereo", False))
     gguf_writer.add_uint32(
-        f"{arch_name}.sample_rate", hparams.get("sample_rate", 44100)
+        f"{arch_name}.sample_rate", hparams.get("wave_sample_rate", 44100)
     )
 
-    gguf_writer.add_uint32(
-        f"{arch_name}.time_transformer_depth",
-        hparams.get("time_transformer_depth", 0),
-    )
-    gguf_writer.add_uint32(
-        f"{arch_name}.freq_transformer_depth",
-        hparams.get("freq_transformer_depth", 0),
-    )
+    if arch_name == "bs_roformer_v2":
+        gguf_writer.add_uint32(
+            f"{arch_name}.time_transformer_depth",
+            hparams.get("time_transformer_depth", 1),
+        )
+        gguf_writer.add_uint32(
+            f"{arch_name}.freq_transformer_depth",
+            hparams.get("freq_transformer_depth", 1),
+        )
+        gguf_writer.add_uint32(
+            f"{arch_name}.num_key_value_heads", hparams.get("num_key_value_heads", 4)
+        )
+        gguf_writer.add_uint32(
+            f"{arch_name}.intermediate_size", hparams.get("intermediate_size", 1152)
+        )
+        gguf_writer.add_uint32(
+            f"{arch_name}.num_input_channels", hparams.get("num_input_channels", 2)
+        )
+        gguf_writer.add_uint32(
+            f"{arch_name}.band_proj_size", hparams.get("band_proj_size", 256)
+        )
+        gguf_writer.add_uint32(
+            f"{arch_name}.register_token_num", hparams.get("register_token_num", 4)
+        )
+    else:
+        gguf_writer.add_uint32(
+            f"{arch_name}.time_transformer_depth",
+            hparams.get("time_transformer_depth", 0),
+        )
+        gguf_writer.add_uint32(
+            f"{arch_name}.freq_transformer_depth",
+            hparams.get("freq_transformer_depth", 0),
+        )
+
     gguf_writer.add_uint32(
         f"{arch_name}.linear_transformer_depth",
         hparams.get("linear_transformer_depth", 0),
@@ -545,8 +576,8 @@ def convert(
     gguf_writer.add_uint32(
         f"{arch_name}.mask_estimator_depth", hparams.get("mask_estimator_depth", 1)
     )
-    gguf_writer.add_uint32(f"{arch_name}.dim_head", hparams.get("dim_head", 64))
-    gguf_writer.add_uint32(f"{arch_name}.heads", hparams.get("heads", 8))
+    gguf_writer.add_uint32(f"{arch_name}.dim_head", hparams.get("head_dim", 64))
+    gguf_writer.add_uint32(f"{arch_name}.heads", hparams.get("num_attention_heads", 8))
     gguf_writer.add_uint32(
         f"{arch_name}.mlp_expansion_factor", hparams.get("mlp_expansion_factor", 4)
     )
@@ -563,11 +594,11 @@ def convert(
     audio_config = config_dict.get("audio", {})
 
     # chunk_size: prefer inference.chunk_size, fallback to audio.chunk_size
-    default_chunk_size = inference_config.get(
-        "chunk_size", audio_config.get("chunk_size", 352800)
+    default_chunk_size = hparams.get(
+        "wave_chunk_size", 352800
     )
     # num_overlap: from inference section
-    default_num_overlap = inference_config.get("num_overlap", 0)
+    default_num_overlap = inference_config.get("num_overlap", 2)
 
     gguf_writer.add_uint32(f"{arch_name}.default_chunk_size", default_chunk_size)
     gguf_writer.add_uint32(f"{arch_name}.default_num_overlap", default_num_overlap)
@@ -679,7 +710,7 @@ Examples:
     parser.add_argument(
         "--ckpt", type=str, required=True, help="Path to PyTorch checkpoint"
     )
-    parser.add_argument("--config", type=str, required=True, help="Path to YAML config")
+    parser.add_argument("--config", type=str, required=True, help="Path to YAML or JSON config")
     parser.add_argument("--out", type=str, required=True, help="Output GGUF file path")
     parser.add_argument(
         "--dtype",
@@ -702,7 +733,7 @@ Examples:
         "--name",
         type=str,
         default=None,
-        help="Model name (default: 'Mel-Band-Roformer Vocal Separator')",
+        help="Model name (default: 'BSRoformer Vocal Separator')",
     )
     parser.add_argument(
         "--description",
@@ -712,7 +743,7 @@ Examples:
     )
     parser.add_argument(
         "--arch",
-        choices=["mel_band", "mel_band_roformer", "bs", "bs_roformer"],
+        choices=["mel_band", "mel_band_roformer", "bs", "bs_roformer", "bs_roformer_v2"],
         default=None,
         help="Architecture type (auto-detected if not specified)",
     )

+ 214 - 20
src/audio.cpp

@@ -1,31 +1,225 @@
-#define DR_WAV_IMPLEMENTATION
-#include "dr_libs/dr_wav.h"
 #include "bs_roformer/audio.h"
+#ifdef __cplusplus
+extern "C" {
+#endif
+#include <libavformat/avformat.h>
+#include <libavcodec/avcodec.h>
+#include <libavutil/avutil.h>
+#include <libavutil/samplefmt.h>
+#include <libavutil/channel_layout.h>
+#include <libavutil/frame.h>
+#include <libavutil/log.h>
+#include <libavutil/mem.h>
+#include <libavutil/dict.h>
+#ifdef __cplusplus
+}
+#endif
 #include <iostream>
+#include <vector>
 
-AudioBuffer AudioFile::Load(const std::string& path) {
+#define DR_WAV_IMPLEMENTATION
+#include "dr_libs/dr_wav.h"
+
+static void InitFFmpeg() {  // one-time libav setup: the static initializer below runs once (thread-safe since C++11)
+    static const bool ready = [] { av_log_set_level(AV_LOG_ERROR); av_register_all(); return true; }();
+    (void)ready;  // only the initializer's side effects matter; silences unused-variable warnings
+}
+
+AudioBuffer AudioFile::Load(const std::string& path, int target_sample_rate) {
     AudioBuffer buffer;
-    drwav_uint64 totalPCMFrames;
+    buffer.channels = 0;
+    buffer.sampleRate = 0;
+    buffer.samples = 0;
+
+    AVFormatContext* fmt_ctx = nullptr;
+    AVCodecContext* codec_ctx = nullptr;
+    int audio_stream_index = -1;
     
-    float* pData = drwav_open_file_and_read_pcm_frames_f32(
-        path.c_str(), &buffer.channels, &buffer.sampleRate, &totalPCMFrames, NULL);
-        
-    if (!pData) {
+    InitFFmpeg();
+    
+    if (avformat_open_input(&fmt_ctx, path.c_str(), nullptr, nullptr) < 0) {
         throw std::runtime_error("Failed to open audio file: " + path);
     }
     
-    buffer.samples = totalPCMFrames * buffer.channels;
-    buffer.data.assign(pData, pData + buffer.samples);
-    drwav_free(pData, NULL);
+    if (avformat_find_stream_info(fmt_ctx, nullptr) < 0) {
+        avformat_close_input(&fmt_ctx);
+        throw std::runtime_error("Failed to find stream info");
+    }
+    
+    for (unsigned int i = 0; i < fmt_ctx->nb_streams; i++) {
+        if (fmt_ctx->streams[i]->codecpar->codec_type == AVMEDIA_TYPE_AUDIO) {
+            audio_stream_index = i;
+            break;
+        }
+    }
+    
+    if (audio_stream_index == -1) {
+        avformat_close_input(&fmt_ctx);
+        throw std::runtime_error("No audio stream found in file: " + path);
+    }
+    
+    AVCodecParameters* codecpar = fmt_ctx->streams[audio_stream_index]->codecpar;
+    const AVCodec* codec = avcodec_find_decoder(codecpar->codec_id);
+    if (!codec) {
+        avformat_close_input(&fmt_ctx);
+        throw std::runtime_error("Codec not found for audio stream");
+    }
     
-    // Validation
-    if (buffer.sampleRate != 44100) {
-        std::cerr << "Warning: Input sample rate is " << buffer.sampleRate 
-                  << " Hz. Model expects 44100 Hz." << std::endl;
-        // In a full implementation, we would resample here.
-        // For now, we warn.
+    codec_ctx = avcodec_alloc_context3(codec);
+    if (!codec_ctx) {
+        avformat_close_input(&fmt_ctx);
+        throw std::runtime_error("Failed to allocate codec context");
     }
     
+    if (avcodec_parameters_to_context(codec_ctx, codecpar) < 0) {
+        avcodec_free_context(&codec_ctx);
+        avformat_close_input(&fmt_ctx);
+        throw std::runtime_error("Failed to copy codec parameters");
+    }
+    
+    if (avcodec_open2(codec_ctx, codec, nullptr) < 0) {
+        avcodec_free_context(&codec_ctx);
+        avformat_close_input(&fmt_ctx);
+        throw std::runtime_error("Failed to open codec");
+    }
+    
+    int input_sample_rate = codec_ctx->sample_rate;
+    int input_channels = codec_ctx->channels;
+    
+    std::vector<float> audio_data;
+    AVPacket* pkt = av_packet_alloc();
+    if (!pkt) {
+        avcodec_free_context(&codec_ctx);
+        avformat_close_input(&fmt_ctx);
+        throw std::runtime_error("Failed to allocate packet");
+    }
+    
+    AVFrame* decoded_frame = av_frame_alloc();
+    if (!decoded_frame) {
+        av_packet_free(&pkt);
+        avcodec_free_context(&codec_ctx);
+        avformat_close_input(&fmt_ctx);
+        throw std::runtime_error("Failed to allocate frame");
+    }
+    
+    int ret;
+    while ((ret = av_read_frame(fmt_ctx, pkt)) >= 0) {
+        if (pkt->stream_index != audio_stream_index) {
+            av_packet_unref(pkt);
+            continue;
+        }
+        
+        ret = avcodec_send_packet(codec_ctx, pkt);
+        if (ret < 0) {
+            av_packet_unref(pkt);
+            continue;
+        }
+        
+        ret = avcodec_receive_frame(codec_ctx, decoded_frame);
+        if (ret >= 0) {
+            int nb_samples = decoded_frame->nb_samples;
+            int nb_channels = av_get_channel_layout_nb_channels(decoded_frame->channel_layout);
+            if (nb_channels <= 0) {
+                nb_channels = input_channels;
+            }
+            
+            std::vector<float> channel_data(nb_channels * nb_samples);
+            
+            int sample_fmt = decoded_frame->format;
+            
+            if (sample_fmt == AV_SAMPLE_FMT_FLT) {
+                float* data = (float*)decoded_frame->data[0];
+                for (int i = 0; i < nb_samples * nb_channels; i++) {
+                    channel_data[i] = data[i];
+                }
+            } else if (sample_fmt == AV_SAMPLE_FMT_FLTP) {
+                for (int c = 0; c < nb_channels; c++) {
+                    float* channel = (float*)decoded_frame->data[c];
+                    for (int i = 0; i < nb_samples; i++) {
+                        channel_data[c * nb_samples + i] = channel[i];
+                    }
+                }
+            } else if (sample_fmt == AV_SAMPLE_FMT_S16) {
+                int16_t* data = (int16_t*)decoded_frame->data[0];
+                for (int i = 0; i < nb_samples * nb_channels; i++) {
+                    channel_data[i] = data[i] / 32768.0f;
+                }
+            } else if (sample_fmt == AV_SAMPLE_FMT_S16P) {
+                for (int c = 0; c < nb_channels; c++) {
+                    int16_t* channel = (int16_t*)decoded_frame->data[c];
+                    for (int i = 0; i < nb_samples; i++) {
+                        channel_data[c * nb_samples + i] = channel[i] / 32768.0f;
+                    }
+                }
+            } else if (sample_fmt == AV_SAMPLE_FMT_S32) {
+                int32_t* data = (int32_t*)decoded_frame->data[0];
+                for (int i = 0; i < nb_samples * nb_channels; i++) {
+                    channel_data[i] = data[i] / 2147483648.0f;
+                }
+            } else if (sample_fmt == AV_SAMPLE_FMT_S32P) {
+                for (int c = 0; c < nb_channels; c++) {
+                    int32_t* channel = (int32_t*)decoded_frame->data[c];
+                    for (int i = 0; i < nb_samples; i++) {
+                        channel_data[c * nb_samples + i] = channel[i] / 2147483648.0f;
+                    }
+                }
+            } else {
+                float* data = (float*)decoded_frame->data[0];
+                for (int i = 0; i < nb_samples * nb_channels; i++) {
+                    channel_data[i] = data[i];
+                }
+            }
+            
+            int resampled_samples = (nb_samples * target_sample_rate) / input_sample_rate;
+            std::vector<float> resampled_data(nb_channels * resampled_samples);
+            
+            for (int c = 0; c < nb_channels; c++) {
+                for (int i = 0; i < resampled_samples; i++) {
+                    float src_idx = (float)i * input_sample_rate / target_sample_rate;
+                    int src_idx_int = (int)src_idx;
+                    float frac = src_idx - src_idx_int;
+                    
+                    if (src_idx_int + 1 < nb_samples) {
+                        resampled_data[c * resampled_samples + i] = 
+                            channel_data[c * nb_samples + src_idx_int] * (1.0f - frac) +
+                            channel_data[c * nb_samples + src_idx_int + 1] * frac;
+                    } else if (src_idx_int < nb_samples) {
+                        resampled_data[c * resampled_samples + i] = 
+                            channel_data[c * nb_samples + src_idx_int];
+                    } else {
+                        resampled_data[c * resampled_samples + i] = 0.0f;
+                    }
+                }
+            }
+            
+            if (nb_channels >= 2) {
+                for (int i = 0; i < resampled_samples; i++) {
+                    audio_data.push_back(resampled_data[0 * resampled_samples + i]);
+                    audio_data.push_back(resampled_data[1 * resampled_samples + i]);
+                }
+            } else {
+                for (int i = 0; i < resampled_samples; i++) {
+                    audio_data.push_back(resampled_data[0 * resampled_samples + i]);
+                    audio_data.push_back(resampled_data[0 * resampled_samples + i]);
+                }
+            }
+            
+            av_frame_unref(decoded_frame);
+        }
+        
+        av_packet_unref(pkt);
+    }
+    
+    av_frame_free(&decoded_frame);
+    av_packet_free(&pkt);
+    avcodec_free_context(&codec_ctx);
+    avformat_close_input(&fmt_ctx);
+    
+    buffer.data = std::move(audio_data);
+    buffer.channels = 2;
+    buffer.sampleRate = target_sample_rate;
+    buffer.samples = buffer.data.size();
+    
     return buffer;
 }
 
@@ -38,7 +232,7 @@ void AudioFile::Save(const std::string& path, const AudioBuffer& buffer) {
     format.bitsPerSample = 32;
     
     drwav wav;
-    if (!drwav_init_file_write(&wav, path.c_str(), &format, NULL)) {
+    if (!drwav_init_file_write(&wav, path.c_str(), &format, nullptr)) {
         throw std::runtime_error("Failed to open file for writing: " + path);
     }
     
@@ -46,6 +240,6 @@ void AudioFile::Save(const std::string& path, const AudioBuffer& buffer) {
     drwav_uninit(&wav);
     
     if (framesWritten != buffer.samples / buffer.channels) {
-         throw std::runtime_error("Failed to write all samples to " + path);
+        throw std::runtime_error("Failed to write all samples to " + path);
     }
-}
+}

+ 184 - 11
src/model.cpp

@@ -66,6 +66,9 @@ void BSRoformer::LoadWeights(const std::string& path) {
     if (architecture_ == "bs_roformer") {
         has_final_norm_ = true;
         transformer_norm_output_ = false;
+    } else if (architecture_ == "bs_roformer_v2") {
+        is_v2_model_ = true;
+        // V2-specific logic can be added here if needed
     } else {
         // mel_band_roformer
         has_final_norm_ = false;
@@ -113,6 +116,29 @@ void BSRoformer::LoadWeights(const std::string& path) {
 
     kv_idx = gguf_find_key(ctx_gguf, (kp + "sample_rate").c_str());
     if (kv_idx >= 0) sample_rate_ = (int)gguf_get_val_u32(ctx_gguf, kv_idx);
+
+    if (is_v2_model_) {
+        kv_idx = gguf_find_key(ctx_gguf, (kp + "time_transformer_depth").c_str());
+        if (kv_idx >= 0) time_transformer_depth_ = (int)gguf_get_val_u32(ctx_gguf, kv_idx);
+        
+        kv_idx = gguf_find_key(ctx_gguf, (kp + "freq_transformer_depth").c_str());
+        if (kv_idx >= 0) freq_transformer_depth_ = (int)gguf_get_val_u32(ctx_gguf, kv_idx);
+
+        kv_idx = gguf_find_key(ctx_gguf, (kp + "num_key_value_heads").c_str());
+        if (kv_idx >= 0) num_key_value_heads_ = (int)gguf_get_val_u32(ctx_gguf, kv_idx);
+
+        kv_idx = gguf_find_key(ctx_gguf, (kp + "intermediate_size").c_str());
+        if (kv_idx >= 0) intermediate_size_ = (int)gguf_get_val_u32(ctx_gguf, kv_idx);
+
+        kv_idx = gguf_find_key(ctx_gguf, (kp + "num_input_channels").c_str());
+        if (kv_idx >= 0) num_input_channels_ = (int)gguf_get_val_u32(ctx_gguf, kv_idx);
+
+        kv_idx = gguf_find_key(ctx_gguf, (kp + "band_proj_size").c_str());
+        if (kv_idx >= 0) band_proj_size_ = (int)gguf_get_val_u32(ctx_gguf, kv_idx);
+
+        kv_idx = gguf_find_key(ctx_gguf, (kp + "register_token_num").c_str());
+        if (kv_idx >= 0) register_token_num_ = (int)gguf_get_val_u32(ctx_gguf, kv_idx);
+    }
     
     // Inference defaults (optional, fallback to hardcoded values)
     kv_idx = gguf_find_key(ctx_gguf, (kp + "default_chunk_size").c_str());
@@ -229,7 +255,13 @@ std::vector<int> BSRoformer::GetDimInputs() const {
 }
 
 int BSRoformer::GetTotalDimInput() const {
-    if (architecture_ == "bs") {
+    if (is_v2_model_) {
+        int total = 0;
+        for (int i = 0; i < num_bands_; ++i) {
+            total += num_freqs_per_band_[i] * 2;
+        }
+        return total;
+    } else if (architecture_ == "bs_roformer") {
         // BS: All frequencies * stereo * complex
         int n_freq = n_fft_ / 2 + 1;
         return n_freq * 2 * 2;  // freq * stereo * complex
@@ -242,7 +274,7 @@ int BSRoformer::GetTotalDimInput() const {
     return total;
 }
 
-// ========== Graph Building Functions ==========
+// ========== Graph Building Functions ========== 
 
 ggml_tensor* BSRoformer::BuildBandSplitGraph(
     ggml_context* ctx,
@@ -251,6 +283,46 @@ ggml_tensor* BSRoformer::BuildBandSplitGraph(
     int n_frames,
     int batch
 ) {
+    if (is_v2_model_) {
+        // V2 model band split
+        ggml_tensor* x = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, band_proj_size_, num_bands_, n_frames, batch);
+
+        size_t offset_elements = 0;
+        for (int i = 0; i < num_bands_; ++i) {
+            int dim_in = num_freqs_per_band_[i] * 2; // V2 uses 2 instead of 4
+
+            ggml_tensor* band_input = ggml_view_3d(ctx, input,
+                                                   dim_in, n_frames, batch,
+                                                   input->nb[1], input->nb[2],
+                                                   offset_elements * sizeof(float));
+            
+            std::string norm_name = "band_split." + std::to_string(i) + ".norm.weight";
+            ggml_tensor* norm_w = GetWeight(norm_name);
+            if (!norm_w) { std::cerr << "Missing weight: " << norm_name << std::endl; return nullptr; }
+
+            ggml_tensor* normed = ggml_rms_norm(ctx, band_input, 1e-6f);
+            normed = ggml_mul(ctx, normed, norm_w);
+
+            std::string linear_w_name = "band_split." + std::to_string(i) + ".linear.weight";
+            std::string linear_b_name = "band_split." + std::to_string(i) + ".linear.bias";
+            ggml_tensor* linear_w = GetWeight(linear_w_name);
+            ggml_tensor* linear_b = GetWeight(linear_b_name);
+            if (!linear_w || !linear_b) { std::cerr << "Missing weights for band " << i << std::endl; return nullptr; }
+
+            ggml_tensor* projected = ggml_mul_mat(ctx, linear_w, normed);
+            projected = ggml_add(ctx, projected, linear_b);
+
+            ggml_tensor* out_slice = ggml_view_3d(ctx, x,
+                                                  band_proj_size_, n_frames, batch,
+                                                  x->nb[2], x->nb[3],
+                                                  i * x->nb[1]);
+            
+            ggml_build_forward_expand(gf, ggml_cpy(ctx, projected, out_slice));
+            offset_elements += dim_in;
+        }
+        return x;
+    }
+
     // Following test_10_full_model.cpp implementation
     // Input: [total_dim_input, n_frames, batch]
     // Output: [dim, num_bands, n_frames, batch]
@@ -321,6 +393,9 @@ ggml_tensor* BSRoformer::BuildTransformersGraph(
     int n_frames,
     int batch
 ) {
+    if (is_v2_model_) {
+        return BuildTransformersGraphV2(ctx, input, gf, pos_time_exp, pos_freq_exp, n_frames, batch);
+    }
     // Following test_10_full_model.cpp implementation
     // Input: [dim, num_bands, n_frames, batch]
     
@@ -341,7 +416,7 @@ ggml_tensor* BSRoformer::BuildTransformersGraph(
                 x = ggml_add(ctx, x, s);
             }
         }
-        // ========== TIME TRANSFORMER ==========
+        // ========== TIME TRANSFORMER ========== 
         // Permute: [D, F, T, B] -> [D, T, F, B]
         x = ggml_permute(ctx, x, 0, 2, 1, 3);
         x = ggml_cont(ctx, x);
@@ -493,7 +568,7 @@ ggml_tensor* BSRoformer::BuildTransformersGraph(
         x = ggml_permute(ctx, x, 0, 2, 1, 3);
         x = ggml_cont(ctx, x);
         
-        // ========== FREQ TRANSFORMER ==========
+        // ========== FREQ TRANSFORMER ========== 
         int tb = T * B;
         ggml_tensor* x_freq_packed = ggml_reshape_3d(ctx, x, D, F, tb);
         
@@ -558,12 +633,8 @@ ggml_tensor* BSRoformer::BuildTransformersGraph(
         ggml_tensor* fV_fa = fV; // fV is contiguous [DIM_HEAD, F, HEADS, tb]
 
         // float scale is already defined in scope (Time Transformer block) or re-define if shadowed loop?
-        // Actually 'scale' was defined inside the Time Transformer loop, so it persists? 
-        // No, Freq Transformer is in the same loop logic? 
-        // Let's check scope. It's in the same 'layer' loop.
-        // But previously I removed the definition line in Time Transformer too? No, I added it back above.
-        // Wait, best to redeclare or rely on scope? 
-        // Time Transformer code block vs Freq Transformer.
+        // Actually 'scale' was defined inside the Time Transformer loop, so it persists?
+        // No, Freq Transformer is in the same loop logic?
         // Let's just use the value. 
         // Re-reading Freq Block:
         // Need to be safe. Redefine 'scale' if needed or ensuring it's available.
@@ -707,7 +778,7 @@ ggml_tensor* BSRoformer::BuildMaskEstimatorGraph(
         total_out_dim += band_out_dims[b];
     }
     
-    ggml_tensor* x = input;  // [D, F, T, B]
+    ggml_tensor* x = input;  // [D, F, T, B] 
     
     // Create mask_output tensor: [total_out_dim, num_stems, n_frames, batch]
     ggml_tensor* mask_output = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, total_out_dim, NUM_STEMS, n_frames, batch);
@@ -800,3 +871,105 @@ ggml_tensor* BSRoformer::BuildMaskEstimatorGraph(
     
     return mask_check;
 }
+
+// Builds the transformer stack used when is_v2_model_ is set.
+//
+// For each of the depth_ outer layers this appends time_transformer_depth_
+// "time" blocks followed by freq_transformer_depth_ "freq" blocks. Every block
+// is pre-norm attention (RMS norm, fused QKV projection, RoPE on Q and K,
+// flash attention, output projection, residual add) followed by a pre-norm
+// GELU feed-forward with its own residual add.
+//
+// Parameters:
+//   ctx          - ggml context in which the new nodes are created
+//   input        - tensor the first block reads from
+//   gf           - unused here; forwarded by BuildTransformersGraph
+//   pos_time_exp - RoPE position tensor for the time blocks
+//   pos_freq_exp - RoPE position tensor for the frequency blocks
+//   n_frames     - row count of the Q/K/V views in the time blocks
+//   batch        - unused here; forwarded by BuildTransformersGraph
+//
+// Returns the output tensor of the last block, or nullptr (after logging the
+// tensor name to stderr) when a required weight is missing.
+ggml_tensor* BSRoformer::BuildTransformersGraphV2(
+    ggml_context* ctx,
+    ggml_tensor* input,
+    ggml_cgraph* gf,
+    ggml_tensor* pos_time_exp,
+    ggml_tensor* pos_freq_exp,
+    int n_frames,
+    int batch
+) {
+    // Neither parameter is needed to create the nodes below.
+    (void)gf;
+    (void)batch;
+
+    // Attention scaling is identical for every block, so compute it once.
+    const float attn_scale = 1.0f / sqrtf((float)dim_head_);
+
+    // Looks a weight up by name and reports it on stderr when it is absent.
+    auto require_weight = [this](const std::string& name) -> ggml_tensor* {
+        ggml_tensor* w = GetWeight(name);
+        if (!w) {
+            std::cerr << "Missing weight: " << name << std::endl;
+        }
+        return w;
+    };
+
+    // Appends one attention + feed-forward block to `h`.
+    // Time and frequency blocks differ only in their weight-name prefixes,
+    // the row count of the Q/K/V views and the RoPE position tensor.
+    auto build_block = [&](ggml_tensor* h,
+                           const std::string& attn_prefix,
+                           const std::string& ff_prefix,
+                           int seq_len,
+                           ggml_tensor* pos) -> ggml_tensor* {
+        // Resolve every weight first so that all missing names are reported
+        // and no node is ever built on top of a null tensor.
+        ggml_tensor* attn_norm_w = require_weight(attn_prefix + ".norm.weight");
+        ggml_tensor* qkv_w       = require_weight(attn_prefix + ".qkv.weight");
+        ggml_tensor* out_w       = require_weight(attn_prefix + ".out.weight");
+        ggml_tensor* ff_norm_w   = require_weight(ff_prefix + ".norm.weight");
+        ggml_tensor* ff_in_w     = require_weight(ff_prefix + ".in.weight");
+        ggml_tensor* ff_out_w    = require_weight(ff_prefix + ".out.weight");
+        if (!attn_norm_w || !qkv_w || !out_w || !ff_norm_w || !ff_in_w || !ff_out_w) {
+            return nullptr;
+        }
+
+        // Pre-norm + fused QKV projection
+        ggml_tensor* h_norm = ggml_rms_norm(ctx, h, 1e-6f);
+        h_norm = ggml_mul(ctx, h_norm, attn_norm_w);
+        ggml_tensor* qkv_out = ggml_mul_mat(ctx, qkv_w, h_norm);
+
+        // Split Q, K, V: three dim_-wide column ranges of each qkv_out row.
+        // NOTE(review): these are 2-D views of seq_len rows; confirm this
+        // matches the rank/layout of the tensor actually passed as `input`.
+        ggml_tensor* Q = ggml_view_2d(ctx, qkv_out, dim_, seq_len, qkv_out->nb[1], 0);
+        ggml_tensor* K = ggml_view_2d(ctx, qkv_out, dim_, seq_len, qkv_out->nb[1], dim_ * sizeof(float));
+        ggml_tensor* V = ggml_view_2d(ctx, qkv_out, dim_, seq_len, qkv_out->nb[1], dim_ * 2 * sizeof(float));
+
+        // RoPE
+        Q = ggml_rope_ext(ctx, Q, pos, nullptr, dim_head_, GGML_ROPE_TYPE_NORMAL, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f);
+        K = ggml_rope_ext(ctx, K, pos, nullptr, dim_head_, GGML_ROPE_TYPE_NORMAL, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f);
+
+        // Attention, output projection, residual
+        ggml_tensor* attn = ggml_flash_attn_ext(ctx, Q, K, V, nullptr, attn_scale, 0.0f, 0.0f);
+        attn = ggml_mul_mat(ctx, out_w, attn);
+        h = ggml_add(ctx, h, attn);
+
+        // MLP: norm -> in projection -> GELU -> out projection, residual
+        ggml_tensor* h_mlp = ggml_rms_norm(ctx, h, 1e-6f);
+        h_mlp = ggml_mul(ctx, h_mlp, ff_norm_w);
+        h_mlp = ggml_mul_mat(ctx, ff_in_w, h_mlp);
+        h_mlp = ggml_gelu(ctx, h_mlp);
+        h_mlp = ggml_mul_mat(ctx, ff_out_w, h_mlp);
+
+        return ggml_add(ctx, h, h_mlp);
+    };
+
+    ggml_tensor* x = input;
+
+    for (int layer = 0; layer < depth_; ++layer) {
+        const std::string blk = "blk." + std::to_string(layer);
+
+        // Time Transformer
+        for (int time_layer = 0; time_layer < time_transformer_depth_; ++time_layer) {
+            const std::string idx = std::to_string(time_layer);
+            x = build_block(x, blk + ".time_attn." + idx, blk + ".time_ff." + idx, n_frames, pos_time_exp);
+            if (!x) {
+                return nullptr;
+            }
+        }
+
+        // Freq Transformer
+        for (int freq_layer = 0; freq_layer < freq_transformer_depth_; ++freq_layer) {
+            const std::string idx = std::to_string(freq_layer);
+            x = build_block(x, blk + ".freq_attn." + idx, blk + ".freq_ff." + idx, num_bands_, pos_freq_exp);
+            if (!x) {
+                return nullptr;
+            }
+        }
+    }
+
+    return x;
+}

+ 22 - 0
src/model.h

@@ -58,6 +58,8 @@ public:
     const std::string& GetArchitecture() const { return architecture_; }
     bool HasFinalNorm() const { return has_final_norm_; }
     bool GetTransformerNormOutput() const { return transformer_norm_output_; }
+
+    bool IsV2Model() const { return is_v2_model_; }
     
     // Inference defaults (from GGUF, can be overridden at runtime)
     int GetDefaultChunkSize() const { return default_chunk_size_; }
@@ -158,6 +160,16 @@ private:
     int mlp_num_layers_ = 3;                 // Detected from weights (BS=2 for depth=2)
     int sample_rate_ = 44100;
 
+    // V2 Params: LoadWeights overrides these from GGUF metadata only when is_v2_model_ is set; the initializers are the fallbacks
+    bool is_v2_model_ = false;               // set by LoadWeights when architecture_ == "bs_roformer_v2"
+    int time_transformer_depth_ = 1;         // GGUF key suffix "time_transformer_depth"; time blocks per outer layer
+    int freq_transformer_depth_ = 1;         // GGUF key suffix "freq_transformer_depth"; freq blocks per outer layer
+    int num_key_value_heads_ = 4;            // GGUF key suffix "num_key_value_heads"
+    int intermediate_size_ = 1152;           // GGUF key suffix "intermediate_size"
+    int num_input_channels_ = 2;             // GGUF key suffix "num_input_channels"
+    int band_proj_size_ = 256;               // GGUF key suffix "band_proj_size"; per-band width of the V2 band-split output
+    int register_token_num_ = 4;             // GGUF key suffix "register_token_num"
+
     // Inference defaults
     int default_chunk_size_ = 352800;
     int default_num_overlap_ = 2;
@@ -169,4 +181,14 @@ private:
     
     // Helper to load GGUF
     void LoadWeights(const std::string& path);
+
+    ggml_tensor* BuildTransformersGraphV2(   // V2 transformer stack; BuildTransformersGraph forwards here when is_v2_model_ is set
+        ggml_context* ctx,                   // context in which the graph nodes are created
+        ggml_tensor* input,                  // tensor the first block reads from
+        ggml_cgraph* gf,                     // currently unused by the definition
+        ggml_tensor* pos_time_exp,           // RoPE positions passed to the time blocks
+        ggml_tensor* pos_freq_exp,           // RoPE positions passed to the frequency blocks
+        int n_frames,                        // row count of the Q/K/V views in the time blocks
+        int batch = 1                        // currently unused by the definition
+    );
 };