|
@@ -1,22 +1,22 @@
|
|
|
#pragma once
|
|
#pragma once
|
|
|
/**
|
|
/**
|
|
|
- * stft.h - STFT/ISTFT implementation
|
|
|
|
|
|
|
+ * stft.h - STFT/ISTFT implementation (Optimized)
|
|
|
*
|
|
*
|
|
|
* Implements:
|
|
* Implements:
|
|
|
- * - Hann window generation
|
|
|
|
|
|
|
+ * - Table-based Hann window generation
|
|
|
|
|
+ * - Table-based Radix-2 FFT (Twiddle factors & Bit-reversal)
|
|
|
|
|
+ * - Per-thread scratch buffers reused across frames (STFTBuffer)
|
|
|
* - Center padding (reflect mode)
|
|
* - Center padding (reflect mode)
|
|
|
* - Frame extraction
|
|
* - Frame extraction
|
|
|
- * - Radix-2 Cooley-Tukey FFT
|
|
|
|
|
- * - Real-to-complex FFT (rfft)
|
|
|
|
|
- * - Inverse FFT (irfft)
|
|
|
|
|
- * - Full STFT/ISTFT matching torch.stft/torch.istft
|
|
|
|
|
*/
|
|
*/
|
|
|
|
|
|
|
|
#include <cmath>
|
|
#include <cmath>
|
|
|
#include <vector>
|
|
#include <vector>
|
|
|
#include <complex>
|
|
#include <complex>
|
|
|
#include <cstring>
|
|
#include <cstring>
|
|
|
-#include <algorithm> // for std::swap
|
|
|
|
|
|
|
+#include <algorithm>
|
|
|
|
|
+#include <memory>
|
|
|
|
|
+#include <mutex>
|
|
|
|
|
|
|
|
#ifdef USE_OPENMP
|
|
#ifdef USE_OPENMP
|
|
|
#include <omp.h>
|
|
#include <omp.h>
|
|
@@ -28,19 +28,44 @@
|
|
|
|
|
|
|
|
namespace stft {
|
|
namespace stft {
|
|
|
|
|
|
|
|
-// Complex number type
|
|
|
|
|
using Complex = std::complex<float>;
|
|
using Complex = std::complex<float>;
|
|
|
|
|
|
|
|
//=============================================================================
|
|
//=============================================================================
|
|
|
-// Window Functions
|
|
|
|
|
|
|
+// Memory Pooling
|
|
|
//=============================================================================
|
|
//=============================================================================
|
|
|
|
|
|
|
|
/**
|
|
/**
|
|
|
- * Generate Hann window matching torch.hann_window()
|
|
|
|
|
- * PyTorch uses periodic=True by default for STFT compatibility
|
|
|
|
|
- * Periodic formula: 0.5 * (1 - cos(2*pi*n / N))
|
|
|
|
|
- * Symmetric formula: 0.5 * (1 - cos(2*pi*n / (N-1)))
|
|
|
|
|
|
|
+ * Per-thread scratch buffers (one instance per thread, reused across frames) to avoid frequent allocations in STFT/ISTFT loops.
|
|
|
*/
|
|
*/
|
|
|
|
|
+struct STFTBuffer {
|
|
|
|
|
+ // FFT buffers
|
|
|
|
|
+ std::vector<Complex> fft_in;
|
|
|
|
|
+ std::vector<Complex> fft_out;
|
|
|
|
|
+ std::vector<Complex> fft_scratch;
|
|
|
|
|
+
|
|
|
|
|
+ // Frame buffers
|
|
|
|
|
+ std::vector<float> frame_in;
|
|
|
|
|
+ std::vector<float> frame_out;
|
|
|
|
|
+
|
|
|
|
|
+ // Window buffers
|
|
|
|
|
+ std::vector<float> window_padded;
|
|
|
|
|
+ std::vector<float> padded_audio;
|
|
|
|
|
+
|
|
|
|
|
+ void Resize(int n_fft, int padded_len = 0) {
|
|
|
|
|
+ if (fft_in.size() != n_fft) fft_in.resize(n_fft);
|
|
|
|
|
+ if (fft_out.size() != n_fft) fft_out.resize(n_fft);
|
|
|
|
|
+ if (fft_scratch.size() != n_fft) fft_scratch.resize(n_fft);
|
|
|
|
|
+ if (frame_in.size() != n_fft) frame_in.resize(n_fft);
|
|
|
|
|
+ if (frame_out.size() != n_fft) frame_out.resize(n_fft);
|
|
|
|
|
+ if (window_padded.size() != n_fft) window_padded.resize(n_fft);
|
|
|
|
|
+ if (padded_len > 0 && padded_audio.size() < padded_len) padded_audio.resize(padded_len);
|
|
|
|
|
+ }
|
|
|
|
|
+};
|
|
|
|
|
+
|
|
|
|
|
+//=============================================================================
|
|
|
|
|
+// Window Functions
|
|
|
|
|
+//=============================================================================
|
|
|
|
|
+
|
|
|
inline void hann_window(float* out, int size, bool periodic = true) {
|
|
inline void hann_window(float* out, int size, bool periodic = true) {
|
|
|
int divisor = periodic ? size : (size - 1);
|
|
int divisor = periodic ? size : (size - 1);
|
|
|
for (int i = 0; i < size; ++i) {
|
|
for (int i = 0; i < size; ++i) {
|
|
@@ -49,153 +74,151 @@ inline void hann_window(float* out, int size, bool periodic = true) {
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
//=============================================================================
|
|
//=============================================================================
|
|
|
-// FFT Implementation (Cooley-Tukey Radix-2)
|
|
|
|
|
|
|
+// FFT Implementation (Table-based Cooley-Tukey Radix-2)
|
|
|
//=============================================================================
|
|
//=============================================================================
|
|
|
|
|
|
|
|
-/**
|
|
|
|
|
- * Bit-reversal permutation for radix-2 FFT
|
|
|
|
|
- */
|
|
|
|
|
-inline void bit_reverse(Complex* data, int n) {
|
|
|
|
|
- int j = 0;
|
|
|
|
|
- for (int i = 0; i < n - 1; ++i) {
|
|
|
|
|
- if (i < j) {
|
|
|
|
|
- std::swap(data[i], data[j]);
|
|
|
|
|
- }
|
|
|
|
|
- int m = n >> 1;
|
|
|
|
|
- while (j >= m && m > 0) {
|
|
|
|
|
- j -= m;
|
|
|
|
|
- m >>= 1;
|
|
|
|
|
|
|
+class TableFFT {
|
|
|
|
|
+public:
|
|
|
|
|
+ static TableFFT& GetInstance(int n_fft) {
|
|
|
|
|
+ static std::mutex mtx;
|
|
|
|
|
+ static std::unique_ptr<TableFFT> instance;
|
|
|
|
|
+ static int current_n_fft = -1;
|
|
|
|
|
+
|
|
|
|
|
+ std::lock_guard<std::mutex> lock(mtx);
|
|
|
|
|
+ if (!instance || current_n_fft != n_fft) {
|
|
|
|
|
+ instance = std::make_unique<TableFFT>(n_fft);
|
|
|
|
|
+ current_n_fft = n_fft;
|
|
|
}
|
|
}
|
|
|
- j += m;
|
|
|
|
|
|
|
+ return *instance;
|
|
|
}
|
|
}
|
|
|
-}
|
|
|
|
|
|
|
|
|
|
-/**
|
|
|
|
|
- * In-place Cooley-Tukey radix-2 FFT
|
|
|
|
|
- * @param data Complex array of size n (must be power of 2)
|
|
|
|
|
- * @param n Size of array
|
|
|
|
|
- * @param inverse If true, compute inverse FFT
|
|
|
|
|
- */
|
|
|
|
|
-inline void fft_radix2(Complex* data, int n, bool inverse = false) {
|
|
|
|
|
- bit_reverse(data, n);
|
|
|
|
|
-
|
|
|
|
|
- // Danielson-Lanczos lemma
|
|
|
|
|
- for (int len = 2; len <= n; len <<= 1) {
|
|
|
|
|
- float angle = (inverse ? 2.0f : -2.0f) * static_cast<float>(M_PI) / len;
|
|
|
|
|
- Complex w_n(std::cos(angle), std::sin(angle));
|
|
|
|
|
-
|
|
|
|
|
- for (int i = 0; i < n; i += len) {
|
|
|
|
|
- Complex w(1.0f, 0.0f);
|
|
|
|
|
- for (int j = 0; j < len / 2; ++j) {
|
|
|
|
|
- Complex u = data[i + j];
|
|
|
|
|
- Complex t = w * data[i + j + len / 2];
|
|
|
|
|
- data[i + j] = u + t;
|
|
|
|
|
- data[i + j + len / 2] = u - t;
|
|
|
|
|
- w *= w_n;
|
|
|
|
|
- }
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ TableFFT(int n) : n_(n) {
|
|
|
|
|
+ Precomputetables();
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ void Forward(Complex* data) const {
|
|
|
|
|
+ BitReverse(data);
|
|
|
|
|
+ Compute(data, false);
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- // Normalize for inverse FFT
|
|
|
|
|
- if (inverse) {
|
|
|
|
|
- for (int i = 0; i < n; ++i) {
|
|
|
|
|
- data[i] /= static_cast<float>(n);
|
|
|
|
|
|
|
+ void Inverse(Complex* data) const {
|
|
|
|
|
+ BitReverse(data);
|
|
|
|
|
+ Compute(data, true);
|
|
|
|
|
+
|
|
|
|
|
+ // Normalize
|
|
|
|
|
+ float inv_n = 1.0f / n_;
|
|
|
|
|
+ for (int i = 0; i < n_; ++i) {
|
|
|
|
|
+ data[i] *= inv_n;
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
-}
|
|
|
|
|
|
|
|
|
|
-/**
|
|
|
|
|
- * Real-to-complex FFT (rfft) matching torch.fft.rfft
|
|
|
|
|
- * @param input Real input array of size n
|
|
|
|
|
- * @param output Complex output array of size n/2+1
|
|
|
|
|
- * @param n Size of input (must be power of 2)
|
|
|
|
|
- * @param buffer Temporary buffer of size n (optional, handled internally if null)
|
|
|
|
|
- */
|
|
|
|
|
-inline void rfft(const float* input, Complex* output, int n, std::vector<Complex>* buffer_ptr = nullptr) {
|
|
|
|
|
- // Copy to complex buffer
|
|
|
|
|
- // Use provided buffer to avoid allocation
|
|
|
|
|
- if (buffer_ptr) {
|
|
|
|
|
- if (buffer_ptr->size() < static_cast<size_t>(n)) buffer_ptr->resize(n);
|
|
|
|
|
- for (int i = 0; i < n; ++i) {
|
|
|
|
|
- (*buffer_ptr)[i] = Complex(input[i], 0.0f);
|
|
|
|
|
|
|
+private:
|
|
|
|
|
+ int n_;
|
|
|
|
|
+ std::vector<int> bit_reverse_indices_;
|
|
|
|
|
+ std::vector<Complex> twiddles_fwd_;
|
|
|
|
|
+ std::vector<Complex> twiddles_inv_;
|
|
|
|
|
+
|
|
|
|
|
+ void Precomputetables() {
|
|
|
|
|
+ // 1. Bit Reverse
|
|
|
|
|
+ bit_reverse_indices_.resize(n_);
|
|
|
|
|
+ int j = 0;
|
|
|
|
|
+ for (int i = 0; i < n_ - 1; ++i) {
|
|
|
|
|
+ bit_reverse_indices_[i] = (i < j) ? j : i; // Store swap target
|
|
|
|
|
+ int m = n_ >> 1;
|
|
|
|
|
+ while (j >= m && m > 0) {
|
|
|
|
|
+ j -= m;
|
|
|
|
|
+ m >>= 1;
|
|
|
|
|
+ }
|
|
|
|
|
+ j += m;
|
|
|
}
|
|
}
|
|
|
- fft_radix2(buffer_ptr->data(), n, false);
|
|
|
|
|
|
|
+ bit_reverse_indices_[n_ - 1] = n_ - 1;
|
|
|
|
|
+
|
|
|
|
|
+ // 2. Twiddles
|
|
|
|
|
+ // One table serves every stage: twiddles_fwd_[k] = W_N^k = exp(-2*pi*i*k/N) for k = 0..N/2-1.
|
|
|
|
|
+ // A stage with butterfly length `len` reads entry j * (N / len) for j = 0..len/2-1,
|
|
|
|
|
+ // i.e. it walks the table with step = N / len. The inverse table holds the conjugates.
|
|
|
|
|
+ twiddles_fwd_.resize(n_ / 2);
|
|
|
|
|
+ twiddles_inv_.resize(n_ / 2);
|
|
|
|
|
|
|
|
- int n_out = n / 2 + 1;
|
|
|
|
|
- for (int i = 0; i < n_out; ++i) {
|
|
|
|
|
- output[i] = (*buffer_ptr)[i];
|
|
|
|
|
|
|
+ for (int k = 0; k < n_ / 2; ++k) {
|
|
|
|
|
+ float angle = -2.0f * static_cast<float>(M_PI) * k / n_;
|
|
|
|
|
+ twiddles_fwd_[k] = Complex(std::cos(angle), std::sin(angle));
|
|
|
|
|
+ twiddles_inv_[k] = std::conj(twiddles_fwd_[k]);
|
|
|
}
|
|
}
|
|
|
- } else {
|
|
|
|
|
- std::vector<Complex> buffer(n);
|
|
|
|
|
- for (int i = 0; i < n; ++i) {
|
|
|
|
|
- buffer[i] = Complex(input[i], 0.0f);
|
|
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ void BitReverse(Complex* data) const {
|
|
|
|
|
+ for (int i = 0; i < n_; ++i) {
|
|
|
|
|
+ int j = bit_reverse_indices_[i];
|
|
|
|
|
+ if (i < j) {
|
|
|
|
|
+ std::swap(data[i], data[j]);
|
|
|
|
|
+ }
|
|
|
}
|
|
}
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ void Compute(Complex* data, bool inverse) const {
|
|
|
|
|
+ const auto& twiddles = inverse ? twiddles_inv_ : twiddles_fwd_;
|
|
|
|
|
|
|
|
- fft_radix2(buffer.data(), n, false);
|
|
|
|
|
-
|
|
|
|
|
- int n_out = n / 2 + 1;
|
|
|
|
|
- for (int i = 0; i < n_out; ++i) {
|
|
|
|
|
- output[i] = buffer[i];
|
|
|
|
|
|
|
+ for (int len = 2; len <= n_; len <<= 1) {
|
|
|
|
|
+ int half_len = len >> 1;
|
|
|
|
|
+ int step = n_ / len;
|
|
|
|
|
+
|
|
|
|
|
+ for (int i = 0; i < n_; i += len) {
|
|
|
|
|
+ for (int j = 0; j < half_len; ++j) {
|
|
|
|
|
+ Complex w = twiddles[j * step];
|
|
|
|
|
+ Complex u = data[i + j];
|
|
|
|
|
+ Complex t = w * data[i + j + half_len];
|
|
|
|
|
+ data[i + j] = u + t;
|
|
|
|
|
+ data[i + j + half_len] = u - t;
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
+};
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+//=============================================================================
|
|
|
|
|
+// STFT Wrapper (Optimized)
|
|
|
|
|
+//=============================================================================
|
|
|
|
|
+
|
|
|
|
|
+inline void rfft(const float* input, Complex* output, int n, STFTBuffer& buffer) {
|
|
|
|
|
+ // 1. Copy to complex buffer
|
|
|
|
|
+ for (int i = 0; i < n; ++i) {
|
|
|
|
|
+ buffer.fft_scratch[i] = Complex(input[i], 0.0f);
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // 2. FFT
|
|
|
|
|
+ TableFFT::GetInstance(n).Forward(buffer.fft_scratch.data());
|
|
|
|
|
+
|
|
|
|
|
+ // 3. Copy first N/2 + 1
|
|
|
|
|
+ int n_out = n / 2 + 1;
|
|
|
|
|
+ for (int i = 0; i < n_out; ++i) {
|
|
|
|
|
+ output[i] = buffer.fft_scratch[i];
|
|
|
|
|
+ }
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-/**
|
|
|
|
|
- * Complex-to-real inverse FFT (irfft) matching torch.fft.irfft
|
|
|
|
|
- * @param input Complex input array of size n/2+1
|
|
|
|
|
- * @param output Real output array of size n
|
|
|
|
|
- * @param n_out Size of output (must be power of 2)
|
|
|
|
|
- * @param buffer Temporary buffer of size n_out (optional)
|
|
|
|
|
- */
|
|
|
|
|
-inline void irfft(const Complex* input, float* output, int n_out, std::vector<Complex>* buffer_ptr = nullptr) {
|
|
|
|
|
|
|
+inline void irfft(const Complex* input, float* output, int n_out, STFTBuffer& buffer) {
|
|
|
int n_freq = n_out / 2 + 1;
|
|
int n_freq = n_out / 2 + 1;
|
|
|
|
|
|
|
|
- if (buffer_ptr) {
|
|
|
|
|
- if (buffer_ptr->size() < static_cast<size_t>(n_out)) buffer_ptr->resize(n_out);
|
|
|
|
|
- for (int i = 0; i < n_freq; ++i) {
|
|
|
|
|
- (*buffer_ptr)[i] = input[i];
|
|
|
|
|
- }
|
|
|
|
|
- for (int i = n_freq; i < n_out; ++i) {
|
|
|
|
|
- (*buffer_ptr)[i] = std::conj((*buffer_ptr)[n_out - i]);
|
|
|
|
|
- }
|
|
|
|
|
- fft_radix2(buffer_ptr->data(), n_out, true);
|
|
|
|
|
- for (int i = 0; i < n_out; ++i) {
|
|
|
|
|
- output[i] = (*buffer_ptr)[i].real();
|
|
|
|
|
- }
|
|
|
|
|
- } else {
|
|
|
|
|
- std::vector<Complex> buffer(n_out);
|
|
|
|
|
- for (int i = 0; i < n_freq; ++i) {
|
|
|
|
|
- buffer[i] = input[i];
|
|
|
|
|
- }
|
|
|
|
|
- for (int i = n_freq; i < n_out; ++i) {
|
|
|
|
|
- buffer[i] = std::conj(buffer[n_out - i]);
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- fft_radix2(buffer.data(), n_out, true);
|
|
|
|
|
-
|
|
|
|
|
- for (int i = 0; i < n_out; ++i) {
|
|
|
|
|
- output[i] = buffer[i].real();
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ // 1. Reconstruct full spectrum
|
|
|
|
|
+ for (int i = 0; i < n_freq; ++i) {
|
|
|
|
|
+ buffer.fft_scratch[i] = input[i];
|
|
|
|
|
+ }
|
|
|
|
|
+ for (int i = n_freq; i < n_out; ++i) {
|
|
|
|
|
+ buffer.fft_scratch[i] = std::conj(buffer.fft_scratch[n_out - i]);
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // 2. IFFT
|
|
|
|
|
+ TableFFT::GetInstance(n_out).Inverse(buffer.fft_scratch.data());
|
|
|
|
|
+
|
|
|
|
|
+ // 3. Real part
|
|
|
|
|
+ for (int i = 0; i < n_out; ++i) {
|
|
|
|
|
+ output[i] = buffer.fft_scratch[i].real();
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-//=============================================================================
|
|
|
|
|
-// STFT Implementation
|
|
|
|
|
-//=============================================================================
|
|
|
|
|
-
|
|
|
|
|
-/**
|
|
|
|
|
- * Short-Time Fourier Transform matching torch.stft
|
|
|
|
|
- *
|
|
|
|
|
- * @param audio Input audio [n_samples]
|
|
|
|
|
- * @param n_samples Number of samples
|
|
|
|
|
- * @param n_fft FFT size
|
|
|
|
|
- * @param hop_length Hop between frames
|
|
|
|
|
- * @param win_length Window length
|
|
|
|
|
- * @param window Window function [win_length]
|
|
|
|
|
- * @param center If true, pad signal on both sides
|
|
|
|
|
- * @param output Output complex spectrogram [n_freq, n_frames, 2] (real, imag pairs)
|
|
|
|
|
- * @param n_frames_out Output parameter: number of frames
|
|
|
|
|
- */
|
|
|
|
|
inline void compute_stft(
|
|
inline void compute_stft(
|
|
|
const float* audio,
|
|
const float* audio,
|
|
|
int n_samples,
|
|
int n_samples,
|
|
@@ -211,21 +234,33 @@ inline void compute_stft(
|
|
|
int pad_amount = center ? n_fft / 2 : 0;
|
|
int pad_amount = center ? n_fft / 2 : 0;
|
|
|
int padded_len = n_samples + 2 * pad_amount;
|
|
int padded_len = n_samples + 2 * pad_amount;
|
|
|
|
|
|
|
|
- std::vector<float> padded(padded_len);
|
|
|
|
|
|
|
+ // Calculate number of frames
|
|
|
|
|
+ // PyTorch formula: (L - N) / H + 1
|
|
|
|
|
+ int n_frames = 1 + (padded_len - n_fft) / hop_length;
|
|
|
|
|
+ if (n_frames < 0) n_frames = 0;
|
|
|
|
|
+ *n_frames_out = n_frames;
|
|
|
|
|
+
|
|
|
|
|
+ // Build the padded signal once, in a single heap buffer shared by all frames.
|
|
|
|
|
+ // It has to hold the WHOLE padded signal, so it does not belong in the per-thread
|
|
|
|
|
+ // scratch buffers: the previous implementation copied the full audio into a
|
|
|
|
|
+ // thread-local 'padded_audio', which is wasteful when several threads each keep a copy.
|
|
|
|
|
|
|
|
|
|
+ std::vector<float> padded(padded_len);
|
|
|
if (center) {
|
|
if (center) {
|
|
|
// Reflect padding
|
|
// Reflect padding
|
|
|
- // Left pad (reflect)
|
|
|
|
|
for (int i = 0; i < pad_amount; ++i) {
|
|
for (int i = 0; i < pad_amount; ++i) {
|
|
|
int src_idx = pad_amount - i;
|
|
int src_idx = pad_amount - i;
|
|
|
if (src_idx >= n_samples) src_idx = n_samples - 1;
|
|
if (src_idx >= n_samples) src_idx = n_samples - 1;
|
|
|
padded[i] = audio[src_idx];
|
|
padded[i] = audio[src_idx];
|
|
|
}
|
|
}
|
|
|
- // Center (copy)
|
|
|
|
|
if (n_samples > 0) {
|
|
if (n_samples > 0) {
|
|
|
std::memcpy(padded.data() + pad_amount, audio, n_samples * sizeof(float));
|
|
std::memcpy(padded.data() + pad_amount, audio, n_samples * sizeof(float));
|
|
|
}
|
|
}
|
|
|
- // Right pad (reflect)
|
|
|
|
|
for (int i = 0; i < pad_amount; ++i) {
|
|
for (int i = 0; i < pad_amount; ++i) {
|
|
|
int src_idx = n_samples - 2 - i;
|
|
int src_idx = n_samples - 2 - i;
|
|
|
if (src_idx < 0) src_idx = 0;
|
|
if (src_idx < 0) src_idx = 0;
|
|
@@ -234,17 +269,10 @@ inline void compute_stft(
|
|
|
} else {
|
|
} else {
|
|
|
std::memcpy(padded.data(), audio, n_samples * sizeof(float));
|
|
std::memcpy(padded.data(), audio, n_samples * sizeof(float));
|
|
|
}
|
|
}
|
|
|
-
|
|
|
|
|
- // Calculate number of frames
|
|
|
|
|
- // PyTorch formula: (L - N) / H + 1
|
|
|
|
|
- int n_frames = 1 + (padded_len - n_fft) / hop_length;
|
|
|
|
|
- if (n_frames < 0) n_frames = 0;
|
|
|
|
|
- *n_frames_out = n_frames;
|
|
|
|
|
-
|
|
|
|
|
- // Number of output frequency bins
|
|
|
|
|
|
|
+
|
|
|
int n_freq = n_fft / 2 + 1;
|
|
int n_freq = n_fft / 2 + 1;
|
|
|
|
|
|
|
|
- // Prepare padded window if win_length < n_fft
|
|
|
|
|
|
|
+ // Prepare window (Single copy)
|
|
|
std::vector<float> window_padded(n_fft, 0.0f);
|
|
std::vector<float> window_padded(n_fft, 0.0f);
|
|
|
if (win_length < n_fft) {
|
|
if (win_length < n_fft) {
|
|
|
int left = (n_fft - win_length) / 2;
|
|
int left = (n_fft - win_length) / 2;
|
|
@@ -253,16 +281,14 @@ inline void compute_stft(
|
|
|
std::memcpy(window_padded.data(), window, n_fft * sizeof(float));
|
|
std::memcpy(window_padded.data(), window, n_fft * sizeof(float));
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- // Pre-allocate thread-local buffers
|
|
|
|
|
|
|
+ // Prepare thread buffers
|
|
|
int max_threads = 1;
|
|
int max_threads = 1;
|
|
|
#ifdef USE_OPENMP
|
|
#ifdef USE_OPENMP
|
|
|
max_threads = omp_get_max_threads();
|
|
max_threads = omp_get_max_threads();
|
|
|
#endif
|
|
#endif
|
|
|
-
|
|
|
|
|
- std::vector<std::vector<float>> thread_frames(max_threads, std::vector<float>(n_fft));
|
|
|
|
|
- std::vector<std::vector<Complex>> thread_fft_outs(max_threads, std::vector<Complex>(n_freq));
|
|
|
|
|
- std::vector<std::vector<Complex>> thread_fft_buffers(max_threads, std::vector<Complex>(n_fft));
|
|
|
|
|
-
|
|
|
|
|
|
|
+ std::vector<STFTBuffer> thread_buffers(max_threads);
|
|
|
|
|
+ for(auto& buf : thread_buffers) buf.Resize(n_fft);
|
|
|
|
|
+
|
|
|
// Process each frame
|
|
// Process each frame
|
|
|
#ifdef USE_OPENMP
|
|
#ifdef USE_OPENMP
|
|
|
#pragma omp parallel for
|
|
#pragma omp parallel for
|
|
@@ -272,41 +298,29 @@ inline void compute_stft(
|
|
|
#ifdef USE_OPENMP
|
|
#ifdef USE_OPENMP
|
|
|
tid = omp_get_thread_num();
|
|
tid = omp_get_thread_num();
|
|
|
#endif
|
|
#endif
|
|
|
-
|
|
|
|
|
- std::vector<float>& frame = thread_frames[tid];
|
|
|
|
|
-
|
|
|
|
|
|
|
+ STFTBuffer& buffer = thread_buffers[tid];
|
|
|
|
|
+
|
|
|
|
|
+ std::vector<float>& frame = buffer.frame_in;
|
|
|
int start = f * hop_length;
|
|
int start = f * hop_length;
|
|
|
|
|
|
|
|
for (int i = 0; i < n_fft; ++i) {
|
|
for (int i = 0; i < n_fft; ++i) {
|
|
|
frame[i] = padded[start + i] * window_padded[i];
|
|
frame[i] = padded[start + i] * window_padded[i];
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- // Compute FFT using pre-allocated buffers
|
|
|
|
|
- rfft(frame.data(), thread_fft_outs[tid].data(), n_fft, &thread_fft_buffers[tid]);
|
|
|
|
|
|
|
+ // Compute FFT
|
|
|
|
|
+ // rfft writes the n_freq bins into this thread's fft_out buffer first;
|
|
|
|
|
+ // they are then scattered into the [n_freq, n_frames, 2] output layout below.
|
|
|
|
|
+
|
|
|
|
|
+ rfft(frame.data(), buffer.fft_out.data(), n_fft, buffer);
|
|
|
|
|
|
|
|
- // Store in output [n_freq, n_frames, 2] format
|
|
|
|
|
|
|
+ // Write to output
|
|
|
for (int k = 0; k < n_freq; ++k) {
|
|
for (int k = 0; k < n_freq; ++k) {
|
|
|
- // Note: Output layout is [Freq, Time, 2]
|
|
|
|
|
- output[(k * n_frames + f) * 2 + 0] = thread_fft_outs[tid][k].real();
|
|
|
|
|
- output[(k * n_frames + f) * 2 + 1] = thread_fft_outs[tid][k].imag();
|
|
|
|
|
|
|
+ output[(k * n_frames + f) * 2 + 0] = buffer.fft_out[k].real();
|
|
|
|
|
+ output[(k * n_frames + f) * 2 + 1] = buffer.fft_out[k].imag();
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-/**
|
|
|
|
|
- * Inverse Short-Time Fourier Transform matching torch.istft
|
|
|
|
|
- *
|
|
|
|
|
- * @param stft_data Input complex spectrogram [n_freq, n_frames, 2]
|
|
|
|
|
- * @param n_freq Number of frequency bins
|
|
|
|
|
- * @param n_frames Number of frames
|
|
|
|
|
- * @param n_fft FFT size
|
|
|
|
|
- * @param hop_length Hop between frames
|
|
|
|
|
- * @param win_length Window length
|
|
|
|
|
- * @param window Window function [win_length]
|
|
|
|
|
- * @param center If true, signal was centered
|
|
|
|
|
- * @param length Expected output length (or 0 for auto)
|
|
|
|
|
- * @param output Output audio
|
|
|
|
|
- */
|
|
|
|
|
inline void compute_istft(
|
|
inline void compute_istft(
|
|
|
const float* stft_data,
|
|
const float* stft_data,
|
|
|
int n_freq,
|
|
int n_freq,
|
|
@@ -333,18 +347,16 @@ inline void compute_istft(
|
|
|
std::memcpy(window_padded.data(), window, n_fft * sizeof(float));
|
|
std::memcpy(window_padded.data(), window, n_fft * sizeof(float));
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- // Step 1: Compute all IFFTs in parallel
|
|
|
|
|
- std::vector<float> frames_time_domain(n_frames * n_fft);
|
|
|
|
|
-
|
|
|
|
|
- // Pre-allocate thread-local buffers
|
|
|
|
|
|
|
+ // Prepare thread buffers
|
|
|
int max_threads = 1;
|
|
int max_threads = 1;
|
|
|
#ifdef USE_OPENMP
|
|
#ifdef USE_OPENMP
|
|
|
max_threads = omp_get_max_threads();
|
|
max_threads = omp_get_max_threads();
|
|
|
#endif
|
|
#endif
|
|
|
-
|
|
|
|
|
- std::vector<std::vector<Complex>> thread_fft_ins(max_threads, std::vector<Complex>(n_freq));
|
|
|
|
|
- std::vector<std::vector<float>> thread_frame_outs(max_threads, std::vector<float>(n_fft));
|
|
|
|
|
- std::vector<std::vector<Complex>> thread_fft_buffers(max_threads, std::vector<Complex>(n_fft));
|
|
|
|
|
|
|
+ std::vector<STFTBuffer> thread_buffers(max_threads);
|
|
|
|
|
+ for(auto& buf : thread_buffers) buf.Resize(n_fft);
|
|
|
|
|
+
|
|
|
|
|
+ // Step 1: Compute all IFFTs in parallel
|
|
|
|
|
+ std::vector<float> frames_time_domain(n_frames * n_fft);
|
|
|
|
|
|
|
|
#ifdef USE_OPENMP
|
|
#ifdef USE_OPENMP
|
|
|
#pragma omp parallel for
|
|
#pragma omp parallel for
|
|
@@ -354,9 +366,10 @@ inline void compute_istft(
|
|
|
#ifdef USE_OPENMP
|
|
#ifdef USE_OPENMP
|
|
|
tid = omp_get_thread_num();
|
|
tid = omp_get_thread_num();
|
|
|
#endif
|
|
#endif
|
|
|
|
|
+ STFTBuffer& buffer = thread_buffers[tid];
|
|
|
|
|
|
|
|
- std::vector<Complex>& fft_in = thread_fft_ins[tid];
|
|
|
|
|
- std::vector<float>& frame_out = thread_frame_outs[tid];
|
|
|
|
|
|
|
+ std::vector<Complex>& fft_in = buffer.fft_in;
|
|
|
|
|
+ std::vector<float>& frame_out = buffer.frame_out;
|
|
|
|
|
|
|
|
// Extract complex spectrum
|
|
// Extract complex spectrum
|
|
|
for (int k = 0; k < n_freq; ++k) {
|
|
for (int k = 0; k < n_freq; ++k) {
|
|
@@ -366,7 +379,7 @@ inline void compute_istft(
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
// IFFT
|
|
// IFFT
|
|
|
- irfft(fft_in.data(), frame_out.data(), n_fft, &thread_fft_buffers[tid]);
|
|
|
|
|
|
|
+ irfft(fft_in.data(), frame_out.data(), n_fft, buffer);
|
|
|
|
|
|
|
|
// Store
|
|
// Store
|
|
|
std::memcpy(&frames_time_domain[f * n_fft], frame_out.data(), n_fft * sizeof(float));
|
|
std::memcpy(&frames_time_domain[f * n_fft], frame_out.data(), n_fft * sizeof(float));
|