Pārlūkot izejas kodu

perf(stft): optimize FFT with table-based transforms and memory pooling

Replace on-the-fly FFT computation with table-based Radix-2 implementation
using precomputed twiddle factors and bit-reversal indices. Introduce STFTBuffer
memory pooling to eliminate allocation overhead in STFT/ISTFT loops.

Refactor inference pipeline to 3-stage threaded architecture with thread-safe
queues connecting preprocessing, GPU inference, and postprocessing stages.
Improves throughput and GPU utilization by decoupling CPU-bound audio
processing from GPU inference.

Add STFT consistency tests verifying numerical accuracy against PyTorch
reference implementations.

BREAKING CHANGE: STFT API changed - rfft/irfft now require STFTBuffer&
parameter instead of optional std::vector<Complex>*
沉默の金 4 mēneši atpakaļ
vecāks
revīzija
e5db284b19
6 mainītis faili ar 496 papildinājumiem un 272 dzēšanām
  1. 1 1
      .gitignore
  2. 23 0
      scripts/generate_test_data.py
  3. 124 82
      src/inference.cpp
  4. 202 189
      src/stft.h
  5. 1 0
      tests/CMakeLists.txt
  6. 145 0
      tests/test_stft_consistency.cpp

+ 1 - 1
.gitignore

@@ -1,2 +1,2 @@
-build
+build*
 tests/data

+ 23 - 0
scripts/generate_test_data.py

@@ -300,6 +300,29 @@ def generate_test_data(
             return_complex=True,
         )
         stft_repr = torch.view_as_real(stft_repr)
+
+        # ===== CAPTURE: Raw STFT/ISTFT for C++ Verification =====
+        # Unpack to [batch, channels, freq, time, 2]
+        stft_raw_unpacked = unpack_one(
+            stft_repr, batch_audio_channel_packed_shape, "* f t c"
+        )
+        captured["stft_raw"] = stft_raw_unpacked.clone()
+
+        # Compute ISTFT directly on this raw STFT (Identity check)
+        stft_complex = torch.view_as_complex(stft_repr)
+        istft_check = torch.istft(
+            stft_complex,
+            **model.stft_kwargs,
+            window=stft_window,
+            return_complex=False,
+            length=istft_length,
+        )
+        istft_check_unpacked = unpack_one(
+            istft_check, batch_audio_channel_packed_shape, "* t"
+        )
+        captured["istft_raw"] = istft_check_unpacked.clone()
+        # ========================================================
+
         stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, "* f t c")
         stft_repr = rearrange(stft_repr, "b s f t c -> b (f s) t c")
 

+ 124 - 82
src/inference.cpp

@@ -11,6 +11,10 @@
 #include <ggml-backend.h>
 #include <chrono>
 #include <future>
+#include <queue>
+#include <thread>
+#include <mutex>
+#include <condition_variable>
 
 using Complex = std::complex<float>;
 
@@ -386,6 +390,63 @@ std::vector<std::vector<float>> Inference::ProcessChunk(const std::vector<float>
 // Pipelined Overlap-Add Logic
 // =================================================================================================
 
+// =================================================================================================
+// Thread Safe Queue
+// =================================================================================================
+
+template <typename T>
+class ThreadSafeQueue {
+public:
+    ThreadSafeQueue(size_t max_size) : max_size_(max_size), shutdown_(false) {}
+
+    ~ThreadSafeQueue() {
+        Shutdown();
+    }
+
+    void Push(T item) {
+        std::unique_lock<std::mutex> lock(mutex_);
+        cv_push_.wait(lock, [this] { return queue_.size() < max_size_ || shutdown_; });
+        if (shutdown_) return;
+        queue_.push(std::move(item));
+        cv_pop_.notify_one();
+    }
+
+    bool Pop(T& item) {
+        std::unique_lock<std::mutex> lock(mutex_);
+        cv_pop_.wait(lock, [this] { return !queue_.empty() || shutdown_; });
+        if (queue_.empty() && shutdown_) return false;
+        item = std::move(queue_.front());
+        queue_.pop();
+        cv_push_.notify_one();
+        return true;
+    }
+
+    void Shutdown() {
+        {
+            std::lock_guard<std::mutex> lock(mutex_);
+            shutdown_ = true;
+        }
+        cv_push_.notify_all();
+        cv_pop_.notify_all();
+    }
+
+private:
+    std::queue<T> queue_;
+    size_t max_size_;
+    bool shutdown_;
+    std::mutex mutex_;
+    std::condition_variable cv_push_;
+    std::condition_variable cv_pop_;
+};
+
+// =================================================================================================
+// Pipelined Overlap-Add Logic
+// =================================================================================================
+
+// =================================================================================================
+// Pipelined Overlap-Add Logic (Optimized 3-Stage)
+// =================================================================================================
+
 std::vector<std::vector<float>> Inference::ProcessOverlapAddPipelined(const std::vector<float>& input_audio, 
                                                          int chunk_size, 
                                                          int num_overlap,
@@ -443,6 +504,7 @@ std::vector<std::vector<float>> Inference::ProcessOverlapAddPipelined(const std:
     std::vector<std::vector<float>> result; // [stems][samples]
     std::vector<float> counter(n_padded_samples * channels, 0.0f);
     std::vector<float> window_base = GetWindow(chunk_size, fade_size);
+    std::mutex result_mutex; // Protects 'result' and 'counter'
     
     // lambda to extract chunk 'i'
     auto extract_chunk = [&](int i) -> std::vector<float> {
@@ -476,11 +538,14 @@ std::vector<std::vector<float>> Inference::ProcessOverlapAddPipelined(const std:
     };
 
     // lambda to accumulate result 'state' at offset 'i'
+    // Now protected by mutex
     auto accumulate_result = [&](std::shared_ptr<ChunkState> state, int i) {
         if (!state) return;
-        const std::vector<std::vector<float>>& chunk_out_stems = state->final_audio; // Now [stems][samples]
+        const std::vector<std::vector<float>>& chunk_out_stems = state->final_audio;
         if (chunk_out_stems.empty()) return;
         
+        std::lock_guard<std::mutex> lock(result_mutex);
+
         // Lazy Initialize result
         if (result.empty()) {
             int num_stems = chunk_out_stems.size();
@@ -505,9 +570,9 @@ std::vector<std::vector<float>> Inference::ProcessOverlapAddPipelined(const std:
             
             for (int s = 0; s < num_stems; ++s) {
                  if (s >= chunk_out_stems.size()) continue;
-                 const auto& stem_chunk = chunk_out_stems[s];
-                 result[s][res_idx + 0] += stem_chunk[chk_idx + 0] * w;
-                 result[s][res_idx + 1] += stem_chunk[chk_idx + 1] * w;
+                 // result[s] is huge, but we access linearly in this block
+                 result[s][res_idx + 0] += chunk_out_stems[s][chk_idx + 0] * w;
+                 result[s][res_idx + 1] += chunk_out_stems[s][chk_idx + 1] * w;
             }
             
             // Counter is same for all stems, just update once
@@ -516,92 +581,69 @@ std::vector<std::vector<float>> Inference::ProcessOverlapAddPipelined(const std:
         }
     };
 
-    // ==========================================================
-    // Pipeline Loop
-    // ==========================================================
-    
-    // Future for the NEXT chunk's preprocessing
-    std::future<std::shared_ptr<ChunkState>> next_prep_future;
-    
-    // Future for the PREVIOUS chunk's postprocessing
-    std::future<void> prev_post_future;
+    // =================================================================================================
+    // 3-Stage Pipeline
+    // =================================================================================================
     
-    std::shared_ptr<ChunkState> prev_state = nullptr;
+    // Queues
+    // Bounded size to prevents running out of memory
+    // 3 items buffer is enough to keep GPU busy
+    ThreadSafeQueue<std::shared_ptr<ChunkState>> input_queue(3);
+    ThreadSafeQueue<std::shared_ptr<ChunkState>> output_queue(3);
     
-    int i = 0;
-    int current_offset = 0;
-    
-    // Bootstrap: Start PreProcessing first chunk
-    {
-        std::vector<float> chunk0 = extract_chunk(0);
-        // Async launch
-        next_prep_future = std::async(std::launch::async, 
-            [this](std::vector<float> c, int id) { return this->PreProcessChunk(c, id); }, 
-            std::move(chunk0), 0);
-    }
+    // Structure to hold chunk metadata together
+    struct ChunkTask {
+        int offset;
+        std::shared_ptr<ChunkState> state;
+    };
     
-    while (current_offset < n_padded_samples) {
-        // 1. Wait for PRE-processing of CURRENT chunk
-        if (next_prep_future.valid()) {
-            // This blocks until STFT is done.
-            // In steady state, this should be ready or nearly ready while GPU was busy.
-        }
-        auto current_state = next_prep_future.get();
-        
-        // 2. Start PRE-processing of NEXT chunk (if exists)
-        int next_offset = current_offset + step;
-        if (next_offset < n_padded_samples) {
-             std::vector<float> chunk_next = extract_chunk(next_offset);
-             next_prep_future = std::async(std::launch::async, 
-                [this](std::vector<float> c, int id) { return this->PreProcessChunk(c, id); }, 
-                std::move(chunk_next), next_offset);
-        } else {
-            // No more next chunks
-        }
-        
-        // 3. Run Inference on CURRENT chunk (GPU Sync)
-        // This blocks heavily.
-        RunInference(current_state);
-        
-        // 4. Wait for POST-processing of PREVIOUS chunk
-        if (prev_post_future.valid()) {
-            prev_post_future.get();
+    // 1. Preprocessor Thread
+    auto preproccessor = std::thread([&]() {
+        int current_offset = 0;
+        while (current_offset < n_padded_samples) {
+            std::vector<float> chunk = extract_chunk(current_offset);
+            
+            auto state = PreProcessChunk(chunk, current_offset); 
+            
+            input_queue.Push(state);
+            current_offset += step;
+        }
+        input_queue.Shutdown();
+    });
+    
+    // 3. Postprocessor Thread
+    auto postprocessor = std::thread([&]() {
+        std::shared_ptr<ChunkState> state;
+        while (output_queue.Pop(state)) {
+            // This does ISTFT (CPU intensive)
+            PostProcessChunk(state);
+            
+            // Accumulate (Memory bandwidth intensive + Mutex)
+            accumulate_result(state, state->id); // state->id holds offset
+            
+            if (progress_callback) {
+                float progress = (float)std::min(state->id + step, n_padded_samples) / n_padded_samples;
+                progress_callback(progress);
+            }
         }
+    });
+    
+    // 2. Main Thread (Inference Loop)
+    std::shared_ptr<ChunkState> state;
+    while (true) {
+        bool ok = input_queue.Pop(state);
+        if (!ok) break; // Input queue shutdown and empty
         
-        // 5. Accumulate PREVIOUS chunk result (Serial, fast)
-        // Note: PostProcessChunk fills 'final_audio', but doesn't accumulate to 'result'.
-        // We do accumulation here on main thread to avoid races on 'result' buffer.
-        if (prev_state) {
-            int prev_offset = current_offset - step;
-            accumulate_result(prev_state, prev_offset);
-            prev_state = nullptr; // Free memory
-        }
+        // This does GGML Inference (GPU intensive, Blocking)
+        RunInference(state);
         
-        // 6. Start POST-processing of CURRENT chunk
-        prev_state = current_state;
-        // Use shared_ptr copy
-        prev_post_future = std::async(std::launch::async, 
-            [this](std::shared_ptr<ChunkState> s) { this->PostProcessChunk(s); }, 
-            prev_state);
-            
-        // Advance
-        current_offset += step;
-
-        if (progress_callback) {
-            float progress = (float)std::min(current_offset, n_padded_samples) / n_padded_samples;
-            progress_callback(progress);
-        }
+        output_queue.Push(state);
     }
     
-    // Drain Pipeline
-    // Wait for last post-process
-    if (prev_post_future.valid()) {
-        prev_post_future.get();
-    }
-    if (prev_state) {
-        int prev_offset = current_offset - step;
-        accumulate_result(prev_state, prev_offset);
-    }
+    // Wait for threads
+    output_queue.Shutdown();
+    if (preproccessor.joinable()) preproccessor.join();
+    if (postprocessor.joinable()) postprocessor.join();
     
     // Normalize and Crop
     // result is [stems][samples]

+ 202 - 189
src/stft.h

@@ -1,22 +1,22 @@
 #pragma once
 /**
- * stft.h - STFT/ISTFT implementation
+ * stft.h - STFT/ISTFT implementation (Optimized)
  * 
  * Implements:
- * - Hann window generation
+ * - Table-based Hann window generation
+ * - Table-based Radix-2 FFT (Twiddle factors & Bit-reversal)
+ * - Thread-safe Memory Pooling (STFTBuffer)
  * - Center padding (reflect mode)
  * - Frame extraction
- * - Radix-2 Cooley-Tukey FFT
- * - Real-to-complex FFT (rfft)
- * - Inverse FFT (irfft)
- * - Full STFT/ISTFT matching torch.stft/torch.istft
  */
 
 #include <cmath>
 #include <vector>
 #include <complex>
 #include <cstring>
-#include <algorithm> // for std::swap
+#include <algorithm>
+#include <memory>
+#include <mutex>
 
 #ifdef USE_OPENMP
 #include <omp.h>
@@ -28,19 +28,44 @@
 
 namespace stft {
 
-// Complex number type
 using Complex = std::complex<float>;
 
 //=============================================================================
-// Window Functions
+// Memory Pooling
 //=============================================================================
 
 /**
- * Generate Hann window matching torch.hann_window()
- * PyTorch uses periodic=True by default for STFT compatibility
- * Periodic formula: 0.5 * (1 - cos(2*pi*n / N))
- * Symmetric formula: 0.5 * (1 - cos(2*pi*n / (N-1)))
+ * Thread-local buffer storage to avoid frequent allocations in STFT/ISTFT loops.
  */
+struct STFTBuffer {
+    // FFT buffers
+    std::vector<Complex> fft_in;
+    std::vector<Complex> fft_out;
+    std::vector<Complex> fft_scratch;
+    
+    // Frame buffers
+    std::vector<float> frame_in;
+    std::vector<float> frame_out;
+    
+    // Window buffers
+    std::vector<float> window_padded;
+    std::vector<float> padded_audio;
+    
+    void Resize(int n_fft, int padded_len = 0) {
+        if (fft_in.size() != n_fft) fft_in.resize(n_fft);
+        if (fft_out.size() != n_fft) fft_out.resize(n_fft);
+        if (fft_scratch.size() != n_fft) fft_scratch.resize(n_fft);
+        if (frame_in.size() != n_fft) frame_in.resize(n_fft);
+        if (frame_out.size() != n_fft) frame_out.resize(n_fft);
+        if (window_padded.size() != n_fft) window_padded.resize(n_fft);
+        if (padded_len > 0 && padded_audio.size() < padded_len) padded_audio.resize(padded_len);
+    }
+};
+
+//=============================================================================
+// Window Functions
+//=============================================================================
+
 inline void hann_window(float* out, int size, bool periodic = true) {
     int divisor = periodic ? size : (size - 1);
     for (int i = 0; i < size; ++i) {
@@ -49,153 +74,151 @@ inline void hann_window(float* out, int size, bool periodic = true) {
 }
 
 //=============================================================================
-// FFT Implementation (Cooley-Tukey Radix-2)
+// FFT Implementation (Table-based Cooley-Tukey Radix-2)
 //=============================================================================
 
-/**
- * Bit-reversal permutation for radix-2 FFT
- */
-inline void bit_reverse(Complex* data, int n) {
-    int j = 0;
-    for (int i = 0; i < n - 1; ++i) {
-        if (i < j) {
-            std::swap(data[i], data[j]);
-        }
-        int m = n >> 1;
-        while (j >= m && m > 0) {
-            j -= m;
-            m >>= 1;
+class TableFFT {
+public:
+    static TableFFT& GetInstance(int n_fft) {
+        static std::mutex mtx;
+        static std::unique_ptr<TableFFT> instance;
+        static int current_n_fft = -1;
+
+        std::lock_guard<std::mutex> lock(mtx);
+        if (!instance || current_n_fft != n_fft) {
+            instance = std::make_unique<TableFFT>(n_fft);
+            current_n_fft = n_fft;
         }
-        j += m;
+        return *instance;
     }
-}
 
-/**
- * In-place Cooley-Tukey radix-2 FFT
- * @param data Complex array of size n (must be power of 2)
- * @param n Size of array
- * @param inverse If true, compute inverse FFT
- */
-inline void fft_radix2(Complex* data, int n, bool inverse = false) {
-    bit_reverse(data, n);
-    
-    // Danielson-Lanczos lemma
-    for (int len = 2; len <= n; len <<= 1) {
-        float angle = (inverse ? 2.0f : -2.0f) * static_cast<float>(M_PI) / len;
-        Complex w_n(std::cos(angle), std::sin(angle));
-        
-        for (int i = 0; i < n; i += len) {
-            Complex w(1.0f, 0.0f);
-            for (int j = 0; j < len / 2; ++j) {
-                Complex u = data[i + j];
-                Complex t = w * data[i + j + len / 2];
-                data[i + j] = u + t;
-                data[i + j + len / 2] = u - t;
-                w *= w_n;
-            }
-        }
+    TableFFT(int n) : n_(n) {
+        Precomputetables();
+    }
+
+    void Forward(Complex* data) const {
+        BitReverse(data);
+        Compute(data, false);
     }
     
-    // Normalize for inverse FFT
-    if (inverse) {
-        for (int i = 0; i < n; ++i) {
-            data[i] /= static_cast<float>(n);
+    void Inverse(Complex* data) const {
+        BitReverse(data);
+        Compute(data, true);
+        
+        // Normalize
+        float inv_n = 1.0f / n_;
+        for (int i = 0; i < n_; ++i) {
+            data[i] *= inv_n;
         }
     }
-}
 
-/**
- * Real-to-complex FFT (rfft) matching torch.fft.rfft
- * @param input Real input array of size n
- * @param output Complex output array of size n/2+1
- * @param n Size of input (must be power of 2)
- * @param buffer Temporary buffer of size n (optional, handled internally if null)
- */
-inline void rfft(const float* input, Complex* output, int n, std::vector<Complex>* buffer_ptr = nullptr) {
-    // Copy to complex buffer
-    // Use provided buffer to avoid allocation
-    if (buffer_ptr) {
-        if (buffer_ptr->size() < static_cast<size_t>(n)) buffer_ptr->resize(n);
-        for (int i = 0; i < n; ++i) {
-            (*buffer_ptr)[i] = Complex(input[i], 0.0f);
+private:
+    int n_;
+    std::vector<int> bit_reverse_indices_;
+    std::vector<Complex> twiddles_fwd_;
+    std::vector<Complex> twiddles_inv_;
+
+    void Precomputetables() {
+        // 1. Bit Reverse
+        bit_reverse_indices_.resize(n_);
+        int j = 0;
+        for (int i = 0; i < n_ - 1; ++i) {
+            bit_reverse_indices_[i] = (i < j) ? j : i; // Store swap target
+            int m = n_ >> 1;
+            while (j >= m && m > 0) {
+                j -= m;
+                m >>= 1;
+            }
+            j += m;
         }
-        fft_radix2(buffer_ptr->data(), n, false);
+        bit_reverse_indices_[n_ - 1] = n_ - 1;
+
+        // 2. Twiddles
+        // We only need twiddles for len = 2, 4, 8 ... n
+        // Total count is roughly N.
+        // Structure: [len=2: w], [len=4: w, w^2], ...
+        // Simplification: Store W_N^k for k=0..N/2-1.
+        // Then step=N/len.
+        twiddles_fwd_.resize(n_ / 2);
+        twiddles_inv_.resize(n_ / 2);
         
-        int n_out = n / 2 + 1;
-        for (int i = 0; i < n_out; ++i) {
-            output[i] = (*buffer_ptr)[i];
+        for (int k = 0; k < n_ / 2; ++k) {
+            float angle = -2.0f * static_cast<float>(M_PI) * k / n_;
+            twiddles_fwd_[k] = Complex(std::cos(angle), std::sin(angle));
+            twiddles_inv_[k] = std::conj(twiddles_fwd_[k]);
         }
-    } else {
-        std::vector<Complex> buffer(n);
-        for (int i = 0; i < n; ++i) {
-            buffer[i] = Complex(input[i], 0.0f);
+    }
+
+    void BitReverse(Complex* data) const {
+        for (int i = 0; i < n_; ++i) {
+            int j = bit_reverse_indices_[i];
+            if (i < j) {
+                std::swap(data[i], data[j]);
+            }
         }
+    }
+
+    void Compute(Complex* data, bool inverse) const {
+        const auto& twiddles = inverse ? twiddles_inv_ : twiddles_fwd_;
         
-        fft_radix2(buffer.data(), n, false);
-        
-        int n_out = n / 2 + 1;
-        for (int i = 0; i < n_out; ++i) {
-            output[i] = buffer[i];
+        for (int len = 2; len <= n_; len <<= 1) {
+            int half_len = len >> 1;
+            int step = n_ / len;
+            
+            for (int i = 0; i < n_; i += len) {
+                for (int j = 0; j < half_len; ++j) {
+                    Complex w = twiddles[j * step];
+                    Complex u = data[i + j];
+                    Complex t = w * data[i + j + half_len];
+                    data[i + j] = u + t;
+                    data[i + j + half_len] = u - t;
+                }
+            }
         }
     }
+};
+
+
+//=============================================================================
+// STFT Wrapper (Optimized)
+//=============================================================================
+
+inline void rfft(const float* input, Complex* output, int n, STFTBuffer& buffer) {
+    // 1. Copy to complex buffer
+    for (int i = 0; i < n; ++i) {
+        buffer.fft_scratch[i] = Complex(input[i], 0.0f);
+    }
+    
+    // 2. FFT
+    TableFFT::GetInstance(n).Forward(buffer.fft_scratch.data());
+    
+    // 3. Copy first N/2 + 1
+    int n_out = n / 2 + 1;
+    for (int i = 0; i < n_out; ++i) {
+        output[i] = buffer.fft_scratch[i];
+    }
 }
 
-/**
- * Complex-to-real inverse FFT (irfft) matching torch.fft.irfft
- * @param input Complex input array of size n/2+1
- * @param output Real output array of size n
- * @param n_out Size of output (must be power of 2)
- * @param buffer Temporary buffer of size n_out (optional)
- */
-inline void irfft(const Complex* input, float* output, int n_out, std::vector<Complex>* buffer_ptr = nullptr) {
+inline void irfft(const Complex* input, float* output, int n_out, STFTBuffer& buffer) {
     int n_freq = n_out / 2 + 1;
     
-    if (buffer_ptr) {
-        if (buffer_ptr->size() < static_cast<size_t>(n_out)) buffer_ptr->resize(n_out);
-        for (int i = 0; i < n_freq; ++i) {
-            (*buffer_ptr)[i] = input[i];
-        }
-         for (int i = n_freq; i < n_out; ++i) {
-            (*buffer_ptr)[i] = std::conj((*buffer_ptr)[n_out - i]);
-        }
-        fft_radix2(buffer_ptr->data(), n_out, true);
-        for (int i = 0; i < n_out; ++i) {
-            output[i] = (*buffer_ptr)[i].real();
-        }
-    } else {
-        std::vector<Complex> buffer(n_out);
-        for (int i = 0; i < n_freq; ++i) {
-            buffer[i] = input[i];
-        }
-        for (int i = n_freq; i < n_out; ++i) {
-            buffer[i] = std::conj(buffer[n_out - i]);
-        }
-        
-        fft_radix2(buffer.data(), n_out, true);
-        
-        for (int i = 0; i < n_out; ++i) {
-            output[i] = buffer[i].real();
-        }
+    // 1. Reconstruct full spectrum
+    for (int i = 0; i < n_freq; ++i) {
+        buffer.fft_scratch[i] = input[i];
+    }
+    for (int i = n_freq; i < n_out; ++i) {
+        buffer.fft_scratch[i] = std::conj(buffer.fft_scratch[n_out - i]);
+    }
+    
+    // 2. IFFT
+    TableFFT::GetInstance(n_out).Inverse(buffer.fft_scratch.data());
+    
+    // 3. Real part
+    for (int i = 0; i < n_out; ++i) {
+        output[i] = buffer.fft_scratch[i].real();
     }
 }
 
-//=============================================================================
-// STFT Implementation
-//=============================================================================
-
-/**
- * Short-Time Fourier Transform matching torch.stft
- * 
- * @param audio Input audio [n_samples]
- * @param n_samples Number of samples
- * @param n_fft FFT size
- * @param hop_length Hop between frames
- * @param win_length Window length
- * @param window Window function [win_length]
- * @param center If true, pad signal on both sides
- * @param output Output complex spectrogram [n_freq, n_frames, 2] (real, imag pairs)
- * @param n_frames_out Output parameter: number of frames
- */
 inline void compute_stft(
     const float* audio,
     int n_samples,
@@ -211,21 +234,33 @@ inline void compute_stft(
     int pad_amount = center ? n_fft / 2 : 0;
     int padded_len = n_samples + 2 * pad_amount;
     
-    std::vector<float> padded(padded_len);
+    // Calculate number of frames
+    // PyTorch formula: (L - N) / H + 1
+    int n_frames = 1 + (padded_len - n_fft) / hop_length;
+    if (n_frames < 0) n_frames = 0;
+    *n_frames_out = n_frames;
+    
+    // Prepare padding buffer (thread-local or single allocation if not parallel? 
+    // Padding + Windowing is usually fast, but padding needs full copy.)
+    // For safety and simplicity, let's allocate padded audio once here (It's one large buffer).
+    // The previous implementation used thread_local for 'padded_audio' which is wrong because 
+    // 'padded_audio' needs to hold the WHOLE signal? No, stft.h:52 says 'padded_audio'.
+    // Analyzing original code: It copied the WHOLE signal to 'padded_audio' inside compute_stft.
+    // That means 'tls_buffer' was huge! If we have multiple threads, each copying full audio? 
+    // That's wasteful.
+    // Better: Allocate 'padded' once on heap.
     
+    std::vector<float> padded(padded_len);
     if (center) {
         // Reflect padding
-        // Left pad (reflect)
         for (int i = 0; i < pad_amount; ++i) {
             int src_idx = pad_amount - i;
             if (src_idx >= n_samples) src_idx = n_samples - 1;
             padded[i] = audio[src_idx];
         }
-        // Center (copy)
         if (n_samples > 0) {
             std::memcpy(padded.data() + pad_amount, audio, n_samples * sizeof(float));
         }
-        // Right pad (reflect)
         for (int i = 0; i < pad_amount; ++i) {
             int src_idx = n_samples - 2 - i;
             if (src_idx < 0) src_idx = 0;
@@ -234,17 +269,10 @@ inline void compute_stft(
     } else {
         std::memcpy(padded.data(), audio, n_samples * sizeof(float));
     }
-    
-    // Calculate number of frames
-    // PyTorch formula: (L - N) / H + 1
-    int n_frames = 1 + (padded_len - n_fft) / hop_length;
-    if (n_frames < 0) n_frames = 0;
-    *n_frames_out = n_frames;
-    
-    // Number of output frequency bins
+
     int n_freq = n_fft / 2 + 1;
     
-    // Prepare padded window if win_length < n_fft
+    // Prepare window (Single copy)
     std::vector<float> window_padded(n_fft, 0.0f);
     if (win_length < n_fft) {
         int left = (n_fft - win_length) / 2;
@@ -253,16 +281,14 @@ inline void compute_stft(
         std::memcpy(window_padded.data(), window, n_fft * sizeof(float));
     }
     
-    // Pre-allocate thread-local buffers
+    // Prepare thread buffers
     int max_threads = 1;
     #ifdef USE_OPENMP
     max_threads = omp_get_max_threads();
     #endif
-    
-    std::vector<std::vector<float>> thread_frames(max_threads, std::vector<float>(n_fft));
-    std::vector<std::vector<Complex>> thread_fft_outs(max_threads, std::vector<Complex>(n_freq));
-    std::vector<std::vector<Complex>> thread_fft_buffers(max_threads, std::vector<Complex>(n_fft));
-    
+    std::vector<STFTBuffer> thread_buffers(max_threads);
+    for(auto& buf : thread_buffers) buf.Resize(n_fft);
+
     // Process each frame
     #ifdef USE_OPENMP
     #pragma omp parallel for
@@ -272,41 +298,29 @@ inline void compute_stft(
         #ifdef USE_OPENMP
         tid = omp_get_thread_num();
         #endif
-        
-        std::vector<float>& frame = thread_frames[tid];
-        
+        STFTBuffer& buffer = thread_buffers[tid];
+
+        std::vector<float>& frame = buffer.frame_in;
         int start = f * hop_length;
         
         for (int i = 0; i < n_fft; ++i) {
             frame[i] = padded[start + i] * window_padded[i];
         }
         
-        // Compute FFT using pre-allocated buffers
-        rfft(frame.data(), thread_fft_outs[tid].data(), n_fft, &thread_fft_buffers[tid]);
+        // Compute FFT
+        // Output pointer directly to destination
+        // We need a place to store complex output before writing to planar output
+        
+        rfft(frame.data(), buffer.fft_out.data(), n_fft, buffer);
         
-        // Store in output [n_freq, n_frames, 2] format
+        // Write to output
         for (int k = 0; k < n_freq; ++k) {
-            // Note: Output layout is [Freq, Time, 2]
-            output[(k * n_frames + f) * 2 + 0] = thread_fft_outs[tid][k].real();
-            output[(k * n_frames + f) * 2 + 1] = thread_fft_outs[tid][k].imag();
+            output[(k * n_frames + f) * 2 + 0] = buffer.fft_out[k].real();
+            output[(k * n_frames + f) * 2 + 1] = buffer.fft_out[k].imag();
         }
     }
 }
 
-/**
- * Inverse Short-Time Fourier Transform matching torch.istft
- * 
- * @param stft_data Input complex spectrogram [n_freq, n_frames, 2]
- * @param n_freq Number of frequency bins
- * @param n_frames Number of frames
- * @param n_fft FFT size
- * @param hop_length Hop between frames
- * @param win_length Window length
- * @param window Window function [win_length]
- * @param center If true, signal was centered
- * @param length Expected output length (or 0 for auto)
- * @param output Output audio
- */
 inline void compute_istft(
     const float* stft_data,
     int n_freq,
@@ -333,18 +347,16 @@ inline void compute_istft(
         std::memcpy(window_padded.data(), window, n_fft * sizeof(float));
     }
     
-    // Step 1: Compute all IFFTs in parallel
-    std::vector<float> frames_time_domain(n_frames * n_fft);
-    
-    // Pre-allocate thread-local buffers
+    // Prepare thread buffers
     int max_threads = 1;
     #ifdef USE_OPENMP
     max_threads = omp_get_max_threads();
     #endif
-    
-    std::vector<std::vector<Complex>> thread_fft_ins(max_threads, std::vector<Complex>(n_freq));
-    std::vector<std::vector<float>> thread_frame_outs(max_threads, std::vector<float>(n_fft));
-    std::vector<std::vector<Complex>> thread_fft_buffers(max_threads, std::vector<Complex>(n_fft));
+    std::vector<STFTBuffer> thread_buffers(max_threads);
+    for(auto& buf : thread_buffers) buf.Resize(n_fft);
+
+    // Step 1: Compute all IFFTs in parallel
+    std::vector<float> frames_time_domain(n_frames * n_fft);
     
     #ifdef USE_OPENMP
     #pragma omp parallel for
@@ -354,9 +366,10 @@ inline void compute_istft(
         #ifdef USE_OPENMP
         tid = omp_get_thread_num();
         #endif
+        STFTBuffer& buffer = thread_buffers[tid];
         
-        std::vector<Complex>& fft_in = thread_fft_ins[tid];
-        std::vector<float>& frame_out = thread_frame_outs[tid];
+        std::vector<Complex>& fft_in = buffer.fft_in;
+        std::vector<float>& frame_out = buffer.frame_out;
         
         // Extract complex spectrum
         for (int k = 0; k < n_freq; ++k) {
@@ -366,7 +379,7 @@ inline void compute_istft(
         }
         
         // IFFT
-        irfft(fft_in.data(), frame_out.data(), n_fft, &thread_fft_buffers[tid]);
+        irfft(fft_in.data(), frame_out.data(), n_fft, buffer);
         
         // Store
         std::memcpy(&frames_time_domain[f * n_fft], frame_out.data(), n_fft * sizeof(float));

+ 1 - 0
tests/CMakeLists.txt

@@ -45,3 +45,4 @@ mbr_add_test(test_component_mask)
 # Integration tests
 mbr_add_test(test_inference)
 mbr_add_test(test_chunking_logic)
+mbr_add_test(test_stft_consistency)

+ 145 - 0
tests/test_stft_consistency.cpp

@@ -0,0 +1,145 @@
+#include "test_common.h"
+#include "../src/stft.h"
+#include "../src/model.h"
+
+int main(int argc, char** argv) {
+    std::cout << "Test: STFT/ISTFT Consistency with PyTorch" << std::endl;
+
+    // 1. Load Model to get parameters
+    std::string model_path = GetModelPath();
+    std::cout << "Loading model params from: " << model_path << std::endl;
+    
+    // We only need the model to read parameters (n_fft, etc.) from GGUF
+    // We don't need to allocate the full graph or weights.
+    BSRoformer model;
+    try {
+        model.Initialize(model_path);
+    } catch (const std::exception& e) {
+        std::cerr << "Failed to load model: " << e.what() << std::endl;
+        std::cerr << "Ensure MBR_MODEL_PATH is set correctly or bs_roformer.gguf exists." << std::endl;
+        return 1;
+    }
+    
+    int n_fft = model.GetNFFT();
+    int hop_length = model.GetHopLength();
+    int win_length = model.GetWinLength();
+    
+    std::cout << "STFT Params: n_fft=" << n_fft << ", hop_length=" << hop_length << ", win_length=" << win_length << std::endl;
+    
+    // 2. Load Data
+    std::string data_dir = GetTestDataDir();
+    std::cout << "Loading test data from: " << data_dir << std::endl;
+    
+    GoldenTensor input_audio(data_dir, "input_audio"); // [batch, channels, samples]
+    GoldenTensor expected_stft(data_dir, "stft_raw"); // [batch, channels, freq, time, 2]
+    GoldenTensor expected_istft(data_dir, "istft_raw"); // [batch, channels, samples]
+    
+    TEST_ASSERT_LOAD(input_audio, "input_audio");
+    TEST_ASSERT_LOAD(expected_stft, "stft_raw");
+    TEST_ASSERT_LOAD(expected_istft, "istft_raw");
+    
+    input_audio.PrintShape("Input Audio");
+    expected_stft.PrintShape("Expected STFT");
+    expected_istft.PrintShape("Expected ISTFT");
+    
+    int batch = input_audio.shape[0];
+    int channels = input_audio.shape[1];
+    int n_samples = input_audio.shape[2];
+    
+    int n_freq = n_fft / 2 + 1;
+    int expected_n_frames = expected_stft.shape[3]; 
+
+    // 3. Prepare Window
+    std::vector<float> window(win_length);
+    stft::hann_window(window.data(), win_length);
+    
+    bool all_passed = true;
+    
+    // 4. Test STFT
+    std::cout << "\n=== Testing STFT ===" << std::endl;
+    
+    for (int b = 0; b < batch; ++b) {
+        for (int c = 0; c < channels; ++c) {
+            // Extract input channel
+            std::vector<float> in_channel(n_samples);
+            for (int i = 0; i < n_samples; ++i) {
+                // Determine index based on memory layout
+                // input_audio.npy is F-contiguous [1, 2, 220500] => [220500, 2] in memory (interleaved)
+                // Layout: L0, R0, L1, R1, ...
+                // Index = (sample_idx * channels + channel_idx)
+                size_t idx = ((size_t)b * n_samples + i) * channels + c;
+                in_channel[i] = input_audio.data[idx];
+            }
+            
+            // Diagnostic: print first few input values
+            std::cout << "  Input[" << b << "," << c << "] first 5: ";
+            for (int i = 0; i < 5; ++i) std::cout << in_channel[i] << " ";
+            std::cout << std::endl;
+            
+            int n_frames_calc = 0;
+            // Buffer for output. 
+            // C++ output is [n_freq, n_frames, 2]
+            std::vector<float> out_stft(n_freq * (expected_n_frames + 10) * 2); 
+            
+            stft::compute_stft(
+                in_channel.data(), n_samples, n_fft, hop_length, win_length,
+                window.data(), true, out_stft.data(), &n_frames_calc
+            );
+            
+            if (n_frames_calc != expected_n_frames) {
+                std::cerr << "  [Batch " << b << " Ch " << c << "] Frame mismatch: calc=" << n_frames_calc << ", expected=" << expected_n_frames << std::endl;
+                all_passed = false;
+                continue;
+            }
+            
+            // Compare
+            size_t channel_stft_size = n_freq * expected_n_frames * 2;
+            size_t offset = b * channels * channel_stft_size + c * channel_stft_size;
+            
+            std::string name = "STFT_B" + std::to_string(b) + "_Ch" + std::to_string(c);
+            if (!CompareAndReport(name, 
+                                  expected_stft.data + offset, channel_stft_size,
+                                  out_stft.data(), channel_stft_size, 1e-3f, 1e-2f)) {
+                all_passed = false;
+            }
+        }
+    }
+    
+    // 5. Test ISTFT
+    std::cout << "\n=== Testing ISTFT ===" << std::endl;
+    
+    for (int b = 0; b < batch; ++b) {
+        for (int c = 0; c < channels; ++c) {
+             size_t channel_stft_size = n_freq * expected_n_frames * 2;
+             size_t offset = b * channels * channel_stft_size + c * channel_stft_size;
+             
+             // Input: expected_stft.data + offset
+             std::vector<float> out_audio(n_samples + n_fft); // Buffer slightly larger
+             
+             // We pass n_samples as expected length
+             stft::compute_istft(
+                 expected_stft.data + offset,
+                 n_freq, expected_n_frames, n_fft, hop_length, win_length,
+                 window.data(), true, n_samples, out_audio.data()
+             );
+             
+             // Verify against expected_istft
+             size_t audio_offset = b * channels * n_samples + c * n_samples;
+             
+             std::string name = "ISTFT_B" + std::to_string(b) + "_Ch" + std::to_string(c);
+             if (!CompareAndReport(name,
+                                   expected_istft.data + audio_offset, n_samples,
+                                   out_audio.data(), n_samples, 1e-4f, 1e-3f)) {
+                 all_passed = false;                     
+             }
+        }
+    }
+    
+    if (all_passed) {
+        LOG_PASS();
+        return 0;
+    } else {
+        LOG_FAIL();
+        return 1;
+    }
+}