Sfoglia il codice sorgente

perf(inference): cache buffers to avoid repeated allocations

Cache position data vectors in inference and use thread-local buffers for
FFT operations to prevent heap contention during parallel processing. This
eliminates repeated memory allocations in hot paths, improving performance
for multi-threaded workloads.
沉默の金 4 mesi fa
parent
commit
993ad971b9
6 ha cambiato i file con 129 aggiunte e 75 eliminazioni
  1. 2 2
      README.md
  2. 2 2
      README.zh.md
  3. 4 0
      include/bs_roformer/inference.h
  4. 21 7
      src/inference.cpp
  5. 5 4
      src/model.cpp
  6. 95 60
      src/stft.h

+ 2 - 2
README.md

@@ -2,11 +2,11 @@
 
 [中文](README.zh.md) | English
 
-High-performance C++ inference implementation for the Mel-Band-Roformer audio source separation model.
+High-performance C++ inference implementation for the **BS Roformer** and **Mel-Band-Roformer** audio source separation model.
 
 ## 📖 Introduction
 
-This project is a pure C++ inference engine for the **Mel-Band-Roformer** and **BS Roformer** audio source separation models, built on the [GGML](https://github.com/ggerganov/ggml) tensor library. It primarily used for extracting vocals or accompaniment from music.
+This project is a pure C++ inference engine for the **BS Roformer** and **Mel-Band-Roformer** audio source separation models, built on the [GGML](https://github.com/ggerganov/ggml) tensor library. It primarily used for extracting vocals or accompaniment from music.
 
 ### ✨ Key Features
 

+ 2 - 2
README.zh.md

@@ -2,11 +2,11 @@
 
 中文 | [English](README.md)
 
-Mel-Band-Roformer 音频源分离模型的高性能 C++ 推理实现。
+**BS Roformer** 和 **Mel-Band-Roformer** 音频源分离模型的高性能 C++ 推理实现。
 
 ## 📖 简介
 
-本项目是 **Mel-Band-Roformer** 和 **BS Roformer** 音频源分离模型的纯 C++ 推理引擎,基于 [GGML](https://github.com/ggerganov/ggml) 张量库构建。主要用于从音乐中提取人声或伴奏。
+本项目是 **BS Roformer** 和 **Mel-Band-Roformer** 音频源分离模型的纯 C++ 推理引擎,基于 [GGML](https://github.com/ggerganov/ggml) 张量库构建。主要用于从音乐中提取人声或伴奏。
 
 ### ✨ 主要特性
 

+ 4 - 0
include/bs_roformer/inference.h

@@ -62,6 +62,10 @@ private:
     struct ggml_tensor* pos_freq_ = nullptr;
     struct ggml_tensor* mask_out_tensor_ = nullptr;
 
+    // Cached Host Data (to avoid reallocation)
+    std::vector<int32_t> pos_time_data_;
+    std::vector<int32_t> pos_freq_data_;
+
     // Current config state
     int cached_n_frames_ = -1;
 

+ 21 - 7
src/inference.cpp

@@ -324,22 +324,36 @@ void Inference::RunInference(std::shared_ptr<ChunkState> state) {
     int n_frames = state->n_frames;
 
     // Prepare position data
-    // TODO: Cache these to avoid allocation every frame if size is constant
-    std::vector<int32_t> pos_time_data(n_frames * n_bands);
-    for(int i=0; i < n_frames * n_bands; ++i) pos_time_data[i] = i % n_frames;
+    // Use cached vectors to avoid allocation
+    int required_time_size = n_frames * n_bands;
+    if (pos_time_data_.size() != required_time_size) {
+        pos_time_data_.resize(required_time_size);
+        for(int i=0; i < required_time_size; ++i) pos_time_data_[i] = i % n_frames;
+    }
     
-    std::vector<int32_t> pos_freq_data(n_bands * n_frames);
-    for(int i=0; i < n_bands * n_frames; ++i) pos_freq_data[i] = i % n_bands;
+    int required_freq_size = n_bands * n_frames;
+    // Note: pos_freq logic (i % n_bands) depends on n_bands (constant) and total size.
+    // If n_frames changes, size changes, and values might depend on n_frames?
+    // Wait, pos_freq_data[i] = i % n_bands. 
+    // This is valid regardless of n_frames as long as size is correct.
+    // But we should regenerate if size changes.
+    if (pos_freq_data_.size() != required_freq_size) {
+        pos_freq_data_.resize(required_freq_size);
+        for(int i=0; i < required_freq_size; ++i) pos_freq_data_[i] = i % n_bands;
+    }
 
     // 4. Host -> Device
     ggml_backend_tensor_set(input_tensor_, state->stft_flattened.data(), 0, ggml_nbytes(input_tensor_));
-    ggml_backend_tensor_set(pos_time_, pos_time_data.data(), 0, ggml_nbytes(pos_time_));
-    ggml_backend_tensor_set(pos_freq_, pos_freq_data.data(), 0, ggml_nbytes(pos_freq_));
+    ggml_backend_tensor_set(pos_time_, pos_time_data_.data(), 0, ggml_nbytes(pos_time_));
+    ggml_backend_tensor_set(pos_freq_, pos_freq_data_.data(), 0, ggml_nbytes(pos_freq_));
 
     // 5. Compute
     ggml_backend_graph_compute(model_->GetBackend(), gf_);
 
     // 6. Device -> Host
+    // Avoid reallocation if size roughly matches? 
+    // ggml_nelements(mask_out_tensor_) is fixed for a given n_frames.
+    // state->mask_output is a vector. resize handles it (no op if same size).
     state->mask_output.resize(ggml_nelements(mask_out_tensor_));
     ggml_backend_tensor_get(mask_out_tensor_, state->mask_output.data(), 0, ggml_nbytes(mask_out_tensor_));
 }

+ 5 - 4
src/model.cpp

@@ -58,16 +58,16 @@ void BSRoformer::LoadWeights(const std::string& path) {
     
     // Normalization for legacy models (if any) or simplified internal handling
     if (architecture_ == "bs") architecture_ = "bs_roformer";
-    if (architecture_ == "mel_band") architecture_ = "bs_roformer";
+    if (architecture_ == "mel_band") architecture_ = "mel_band_roformer";
 
-    std::string kp = architecture_ + "."; // key prefix, e.g. "bs_roformer." or "bs_roformer."
+    std::string kp = architecture_ + "."; // key prefix, e.g. "bs_roformer." or "mel_band_roformer."
 
     // Set internal flags based on architecture
     if (architecture_ == "bs_roformer") {
         has_final_norm_ = true;
         transformer_norm_output_ = false;
     } else {
-        // bs_roformer
+        // mel_band_roformer
         has_final_norm_ = false;
         transformer_norm_output_ = true;
     }
@@ -201,7 +201,8 @@ void BSRoformer::LoadWeights(const std::string& path) {
     // Dynamic MLP detection
     // Try to find mask_est.0.freq.0.mlp.{N}.weight
     mlp_num_layers_ = 0;
-    for (int idx = 0; idx <= 20; idx += 2) {  // Check indices 0, 2, 4... up to 10 layers
+    const int MAX_MLP_LAYERS = 20;
+    for (int idx = 0; idx <= MAX_MLP_LAYERS; idx += 2) {  // Check indices 0, 2, 4... up to MAX
         std::string probe = "mask_est.0.freq.0.mlp." + std::to_string(idx) + ".weight";
         if (GetWeight(probe) != nullptr) {
             mlp_num_layers_++;

+ 95 - 60
src/stft.h

@@ -18,6 +18,10 @@
 #include <cstring>
 #include <algorithm> // for std::swap
 
+#ifdef USE_OPENMP
+#include <omp.h>
+#endif
+
 #ifndef M_PI
 #define M_PI 3.14159265358979323846
 #endif
@@ -105,21 +109,34 @@ inline void fft_radix2(Complex* data, int n, bool inverse = false) {
  * @param input Real input array of size n
  * @param output Complex output array of size n/2+1
  * @param n Size of input (must be power of 2)
+ * @param buffer Temporary buffer of size n (optional, handled internally if null)
  */
-inline void rfft(const float* input, Complex* output, int n) {
+inline void rfft(const float* input, Complex* output, int n, std::vector<Complex>* buffer_ptr = nullptr) {
     // Copy to complex buffer
-    std::vector<Complex> buffer(n);
-    for (int i = 0; i < n; ++i) {
-        buffer[i] = Complex(input[i], 0.0f);
-    }
-    
-    // Compute full FFT
-    fft_radix2(buffer.data(), n, false);
-    
-    // Extract first n/2+1 coefficients (one-sided)
-    int n_out = n / 2 + 1;
-    for (int i = 0; i < n_out; ++i) {
-        output[i] = buffer[i];
+    // Use provided buffer to avoid allocation
+    if (buffer_ptr) {
+        if (buffer_ptr->size() < static_cast<size_t>(n)) buffer_ptr->resize(n);
+        for (int i = 0; i < n; ++i) {
+            (*buffer_ptr)[i] = Complex(input[i], 0.0f);
+        }
+        fft_radix2(buffer_ptr->data(), n, false);
+        
+        int n_out = n / 2 + 1;
+        for (int i = 0; i < n_out; ++i) {
+            output[i] = (*buffer_ptr)[i];
+        }
+    } else {
+        std::vector<Complex> buffer(n);
+        for (int i = 0; i < n; ++i) {
+            buffer[i] = Complex(input[i], 0.0f);
+        }
+        
+        fft_radix2(buffer.data(), n, false);
+        
+        int n_out = n / 2 + 1;
+        for (int i = 0; i < n_out; ++i) {
+            output[i] = buffer[i];
+        }
     }
 }
 
@@ -128,25 +145,37 @@ inline void rfft(const float* input, Complex* output, int n) {
  * @param input Complex input array of size n/2+1
  * @param output Real output array of size n
  * @param n_out Size of output (must be power of 2)
+ * @param buffer Temporary buffer of size n_out (optional)
  */
-inline void irfft(const Complex* input, float* output, int n_out) {
+inline void irfft(const Complex* input, float* output, int n_out, std::vector<Complex>* buffer_ptr = nullptr) {
     int n_freq = n_out / 2 + 1;
     
-    // Reconstruct full spectrum (conjugate symmetry)
-    std::vector<Complex> buffer(n_out);
-    for (int i = 0; i < n_freq; ++i) {
-        buffer[i] = input[i];
-    }
-    for (int i = n_freq; i < n_out; ++i) {
-        buffer[i] = std::conj(buffer[n_out - i]);
-    }
-    
-    // Compute inverse FFT
-    fft_radix2(buffer.data(), n_out, true);
-    
-    // Extract real part
-    for (int i = 0; i < n_out; ++i) {
-        output[i] = buffer[i].real();
+    if (buffer_ptr) {
+        if (buffer_ptr->size() < static_cast<size_t>(n_out)) buffer_ptr->resize(n_out);
+        for (int i = 0; i < n_freq; ++i) {
+            (*buffer_ptr)[i] = input[i];
+        }
+         for (int i = n_freq; i < n_out; ++i) {
+            (*buffer_ptr)[i] = std::conj((*buffer_ptr)[n_out - i]);
+        }
+        fft_radix2(buffer_ptr->data(), n_out, true);
+        for (int i = 0; i < n_out; ++i) {
+            output[i] = (*buffer_ptr)[i].real();
+        }
+    } else {
+        std::vector<Complex> buffer(n_out);
+        for (int i = 0; i < n_freq; ++i) {
+            buffer[i] = input[i];
+        }
+        for (int i = n_freq; i < n_out; ++i) {
+            buffer[i] = std::conj(buffer[n_out - i]);
+        }
+        
+        fft_radix2(buffer.data(), n_out, true);
+        
+        for (int i = 0; i < n_out; ++i) {
+            output[i] = buffer[i].real();
+        }
     }
 }
 
@@ -224,42 +253,42 @@ inline void compute_stft(
         std::memcpy(window_padded.data(), window, n_fft * sizeof(float));
     }
     
-    // Buffers
-    std::vector<float> frame(n_fft);
-    std::vector<Complex> fft_out(n_freq);
+    // Pre-allocate thread-local buffers
+    int max_threads = 1;
+    #ifdef USE_OPENMP
+    max_threads = omp_get_max_threads();
+    #endif
+    
+    std::vector<std::vector<float>> thread_frames(max_threads, std::vector<float>(n_fft));
+    std::vector<std::vector<Complex>> thread_fft_outs(max_threads, std::vector<Complex>(n_freq));
+    std::vector<std::vector<Complex>> thread_fft_buffers(max_threads, std::vector<Complex>(n_fft));
     
     // Process each frame
     #ifdef USE_OPENMP
     #pragma omp parallel for
     #endif
     for (int f = 0; f < n_frames; ++f) {
-        int start = f * hop_length;
+        int tid = 0;
+        #ifdef USE_OPENMP
+        tid = omp_get_thread_num();
+        #endif
         
-        // Extract and window frame
-        // Need private buffer for frame and fft_out if logical threads share memory?
-        // Wait, std::vector inside loop is local to block, so essentially thread-private?
-        // YES. Variables declared inside the loop are private to the iteration/thread.
+        std::vector<float>& frame = thread_frames[tid];
         
-        // However, we need to be careful about allocating vectors inside a loop in parallel (heap contention).
-        // It's better to allocate buffers per thread or use raw arrays.
-        // For simplicity and since n_fft is small (2048), stack array or thread_local vector is better.
-        // But std::vector inside parallel for is safe but might allocate.
-        // n_fft=2048 float is 8KB. 
-        std::vector<float> frame(n_fft); // Allocation!
-        std::vector<Complex> fft_out(n_freq);
+        int start = f * hop_length;
         
         for (int i = 0; i < n_fft; ++i) {
             frame[i] = padded[start + i] * window_padded[i];
         }
         
-        // Compute FFT
-        rfft(frame.data(), fft_out.data(), n_fft);
+        // Compute FFT using pre-allocated buffers
+        rfft(frame.data(), thread_fft_outs[tid].data(), n_fft, &thread_fft_buffers[tid]);
         
         // Store in output [n_freq, n_frames, 2] format
         for (int k = 0; k < n_freq; ++k) {
             // Note: Output layout is [Freq, Time, 2]
-            output[(k * n_frames + f) * 2 + 0] = fft_out[k].real();
-            output[(k * n_frames + f) * 2 + 1] = fft_out[k].imag();
+            output[(k * n_frames + f) * 2 + 0] = thread_fft_outs[tid][k].real();
+            output[(k * n_frames + f) * 2 + 1] = thread_fft_outs[tid][k].imag();
         }
     }
 }
@@ -304,24 +333,30 @@ inline void compute_istft(
         std::memcpy(window_padded.data(), window, n_fft * sizeof(float));
     }
     
-    // Overlap-add buffers
-    // This is tricky for parallelization: race condition on y (overlap-add).
-    // We CANNOT parallelize the write to 'y' easily without atomic float add (slow/hard) or reduction.
-    // APPROACH:
-    // 1. Parallel IFFT: Compute all frames' time-domain signals into a large buffer [n_frames, n_fft].
-    // 2. Serial Overlap-Add: Add them up. (Overlap-add is O(N_Frames * N_FFT), same complexity, but memory bound).
-    // Serial part might be fast enough if FFT is the heavy lifter.
-    // FFT is O(N log N). Overlap add is O(N). FFT dominates.
-    
     // Step 1: Compute all IFFTs in parallel
     std::vector<float> frames_time_domain(n_frames * n_fft);
     
+    // Pre-allocate thread-local buffers
+    int max_threads = 1;
+    #ifdef USE_OPENMP
+    max_threads = omp_get_max_threads();
+    #endif
+    
+    std::vector<std::vector<Complex>> thread_fft_ins(max_threads, std::vector<Complex>(n_freq));
+    std::vector<std::vector<float>> thread_frame_outs(max_threads, std::vector<float>(n_fft));
+    std::vector<std::vector<Complex>> thread_fft_buffers(max_threads, std::vector<Complex>(n_fft));
+    
     #ifdef USE_OPENMP
     #pragma omp parallel for
     #endif
     for (int f = 0; f < n_frames; ++f) {
-        std::vector<Complex> fft_in(n_freq);
-        std::vector<float> frame_out(n_fft);
+        int tid = 0;
+        #ifdef USE_OPENMP
+        tid = omp_get_thread_num();
+        #endif
+        
+        std::vector<Complex>& fft_in = thread_fft_ins[tid];
+        std::vector<float>& frame_out = thread_frame_outs[tid];
         
         // Extract complex spectrum
         for (int k = 0; k < n_freq; ++k) {
@@ -331,7 +366,7 @@ inline void compute_istft(
         }
         
         // IFFT
-        irfft(fft_in.data(), frame_out.data(), n_fft);
+        irfft(fft_in.data(), frame_out.data(), n_fft, &thread_fft_buffers[tid]);
         
         // Store
         std::memcpy(&frames_time_domain[f * n_fft], frame_out.data(), n_fft * sizeof(float));