stft.h 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419
  1. #pragma once
  2. /**
  3. * stft.h - STFT/ISTFT implementation (Optimized)
  4. *
  5. * Implements:
  6. * - Table-based Hann window generation
  7. * - Table-based Radix-2 FFT (Twiddle factors & Bit-reversal)
  8. * - Thread-safe Memory Pooling (STFTBuffer)
  9. * - Center padding (reflect mode)
  10. * - Frame extraction
  11. */
  #include <algorithm>
  #include <cmath>
  #include <complex>
  #include <cstring>
  #include <memory>
  #include <mutex>
  #include <stdexcept>
  #include <vector>
  19. #ifdef USE_OPENMP
  20. #include <omp.h>
  21. #endif
  22. #ifndef M_PI
  23. #define M_PI 3.14159265358979323846
  24. #endif
  25. namespace stft {
  26. using Complex = std::complex<float>;
  27. //=============================================================================
  28. // Memory Pooling
  29. //=============================================================================
  /**
   * Scratch storage for one worker while it processes STFT/ISTFT frames.
   *
   * The struct performs no synchronization of its own: callers are expected to
   * give every thread its own instance (compute_stft / compute_istft allocate
   * one per OpenMP thread) so that the vectors can be reused across frames
   * without reallocating.
   */
  struct STFTBuffer {
      // FFT buffers
      std::vector<Complex> fft_in;
      std::vector<Complex> fft_out;
      std::vector<Complex> fft_scratch;
      // Frame buffers
      std::vector<float> frame_in;
      std::vector<float> frame_out;
      // Window buffers
      std::vector<float> window_padded;
      std::vector<float> padded_audio;

      /**
       * Sizes every per-frame buffer to exactly n_fft elements.
       *
       * @param n_fft       Frame / FFT length. Values <= 0 are treated as 0
       *                    (previously a negative value was converted to a huge
       *                    unsigned size and made resize() throw).
       * @param padded_len  Minimum size for padded_audio. The buffer only ever
       *                    grows; 0 (the default) leaves it untouched.
       */
      void Resize(int n_fft, int padded_len = 0) {
          const std::size_t frame_len = n_fft > 0 ? static_cast<std::size_t>(n_fft) : 0;

          // vector::resize is already a no-op when the size matches, so no
          // explicit "size differs" check is needed.
          fft_in.resize(frame_len);
          fft_out.resize(frame_len);
          fft_scratch.resize(frame_len);
          frame_in.resize(frame_len);
          frame_out.resize(frame_len);
          window_padded.resize(frame_len);

          if (padded_len > 0) {
              const std::size_t wanted = static_cast<std::size_t>(padded_len);
              if (padded_audio.size() < wanted) {
                  padded_audio.resize(wanted);
              }
          }
      }
  };
  54. //=============================================================================
  55. // Window Functions
  56. //=============================================================================
  /**
   * Fills out[0 .. size-1] with a Hann window.
   *
   * @param out       Destination for `size` floats. Nothing is written when it
   *                  is null or when size <= 0.
   * @param size      Number of window samples.
   * @param periodic  true  -> w[i] = 0.5 * (1 - cos(2*pi*i / size))
   *                  false -> w[i] = 0.5 * (1 - cos(2*pi*i / (size - 1)))
   *
   * A one-sample window is defined as {1}. The symmetric formula would divide
   * by zero there (yielding NaN) and the periodic one would yield an all-zero
   * window; returning 1 follows the convention of torch.hann_window and
   * numpy.hanning.
   */
  inline void hann_window(float* out, int size, bool periodic = true) {
      if (out == nullptr || size <= 0) {
          return;
      }
      if (size == 1) {
          out[0] = 1.0f;
          return;
      }

      // Local constant so the function does not depend on the M_PI macro.
      constexpr double kTwoPi = 6.283185307179586476925286766559;
      const double divisor = periodic ? size : size - 1;

      // Evaluate in double and round once, so long windows do not accumulate
      // float error in the phase term.
      for (int i = 0; i < size; ++i) {
          out[i] = static_cast<float>(0.5 * (1.0 - std::cos(kTwoPi * i / divisor)));
      }
  }
  63. //=============================================================================
  64. // FFT Implementation (Table-based Cooley-Tukey Radix-2)
  65. //=============================================================================
  /**
   * In-place radix-2 Cooley-Tukey FFT driven by precomputed tables.
   *
   * An instance is immutable after construction: Forward() and Inverse() are
   * const and only read the tables, so one instance can be used from several
   * threads at once as long as each thread transforms its own data array.
   */
  class TableFFT {
  public:
      /**
       * Returns the shared transform for size n_fft, building it on first use.
       *
       * Tables for every size requested so far are kept alive, so a returned
       * reference remains valid for the rest of the program. (The earlier
       * single-slot cache destroyed the previous instance whenever a different
       * size was requested, leaving dangling references behind.) Only powers of
       * two are accepted, so the cache holds at most one entry per bit width.
       *
       * @throws std::invalid_argument if n_fft is not a positive power of two.
       */
      static TableFFT& GetInstance(int n_fft) {
          static std::mutex mtx;
          static std::vector<std::unique_ptr<TableFFT>> cache;

          std::lock_guard<std::mutex> lock(mtx);
          // Linear scan: the cache is tiny, so this beats any associative lookup.
          for (const auto& entry : cache) {
              if (entry->n_ == n_fft) {
                  return *entry;
              }
          }
          // If construction throws, nothing is added to the cache.
          cache.push_back(std::make_unique<TableFFT>(n_fft));
          return *cache.back();
      }

      /**
       * Builds the bit-reversal and twiddle tables for an n-point transform.
       *
       * @throws std::invalid_argument if n is not a positive power of two; the
       *         radix-2 butterflies would otherwise index outside the data.
       */
      explicit TableFFT(int n) : n_(n) {
          if (n <= 0 || (n & (n - 1)) != 0) {
              throw std::invalid_argument("TableFFT: size must be a positive power of two");
          }
          PrecomputeTables();
      }

      /** Forward DFT of data[0 .. n-1], in place, unnormalized. */
      void Forward(Complex* data) const {
          BitReverse(data);
          Compute(data, false);
      }

      /** Inverse DFT of data[0 .. n-1], in place, scaled by 1/n. */
      void Inverse(Complex* data) const {
          BitReverse(data);
          Compute(data, true);
          const float inv_n = 1.0f / static_cast<float>(n_);
          for (int i = 0; i < n_; ++i) {
              data[i] *= inv_n;
          }
      }

  private:
      int n_;
      std::vector<int> bit_reverse_indices_;  // bit_reverse_indices_[i] = i with its log2(n) bits reversed
      std::vector<Complex> twiddles_fwd_;     // exp(-2*pi*i*k/n), k = 0 .. n/2-1
      std::vector<Complex> twiddles_inv_;     // complex conjugates of twiddles_fwd_

      void PrecomputeTables() {
          const int half = n_ >> 1;

          // 1. Bit reversal: rev(i) is rev(i >> 1) shifted right by one, with
          //    the lowest bit of i moved to the top position.
          bit_reverse_indices_.assign(static_cast<std::size_t>(n_), 0);
          for (int i = 1; i < n_; ++i) {
              bit_reverse_indices_[i] =
                  (bit_reverse_indices_[i >> 1] >> 1) | ((i & 1) ? half : 0);
          }

          // 2. Twiddles: store W_n^k for k = 0 .. n/2-1 once. A stage of
          //    butterfly length `len` reads them with stride n / len.
          //    Angles are evaluated in double and rounded once to float.
          constexpr double kTwoPi = 6.283185307179586476925286766559;
          twiddles_fwd_.resize(static_cast<std::size_t>(half));
          twiddles_inv_.resize(static_cast<std::size_t>(half));
          for (int k = 0; k < half; ++k) {
              const double angle = -kTwoPi * k / n_;
              twiddles_fwd_[k] = Complex(static_cast<float>(std::cos(angle)),
                                         static_cast<float>(std::sin(angle)));
              twiddles_inv_[k] = std::conj(twiddles_fwd_[k]);
          }
      }

      /** Permutes data into bit-reversed order; each pair is swapped once. */
      void BitReverse(Complex* data) const {
          for (int i = 0; i < n_; ++i) {
              const int j = bit_reverse_indices_[i];
              if (i < j) {
                  std::swap(data[i], data[j]);
              }
          }
      }

      /** Runs the log2(n) butterfly stages on bit-reversed input. */
      void Compute(Complex* data, bool inverse) const {
          const std::vector<Complex>& twiddles = inverse ? twiddles_inv_ : twiddles_fwd_;
          for (int len = 2; len <= n_; len <<= 1) {
              const int half_len = len >> 1;
              const int step = n_ / len;
              for (int base = 0; base < n_; base += len) {
                  for (int j = 0; j < half_len; ++j) {
                      const Complex w = twiddles[j * step];
                      const Complex u = data[base + j];
                      const Complex t = w * data[base + j + half_len];
                      data[base + j] = u + t;
                      data[base + j + half_len] = u - t;
                  }
              }
          }
      }
  };
  153. //=============================================================================
  154. // STFT Wrapper (Optimized)
  155. //=============================================================================
  156. inline void rfft(const float* input, Complex* output, int n, STFTBuffer& buffer) {
  157. // 1. Copy to complex buffer
  158. for (int i = 0; i < n; ++i) {
  159. buffer.fft_scratch[i] = Complex(input[i], 0.0f);
  160. }
  161. // 2. FFT
  162. TableFFT::GetInstance(n).Forward(buffer.fft_scratch.data());
  163. // 3. Copy first N/2 + 1
  164. int n_out = n / 2 + 1;
  165. for (int i = 0; i < n_out; ++i) {
  166. output[i] = buffer.fft_scratch[i];
  167. }
  168. }
  169. inline void irfft(const Complex* input, float* output, int n_out, STFTBuffer& buffer) {
  170. int n_freq = n_out / 2 + 1;
  171. // 1. Reconstruct full spectrum
  172. for (int i = 0; i < n_freq; ++i) {
  173. buffer.fft_scratch[i] = input[i];
  174. }
  175. for (int i = n_freq; i < n_out; ++i) {
  176. buffer.fft_scratch[i] = std::conj(buffer.fft_scratch[n_out - i]);
  177. }
  178. // 2. IFFT
  179. TableFFT::GetInstance(n_out).Inverse(buffer.fft_scratch.data());
  180. // 3. Real part
  181. for (int i = 0; i < n_out; ++i) {
  182. output[i] = buffer.fft_scratch[i].real();
  183. }
  184. }
  185. inline void compute_stft(
  186. const float* audio,
  187. int n_samples,
  188. int n_fft,
  189. int hop_length,
  190. int win_length,
  191. const float* window,
  192. bool center,
  193. float* output,
  194. int* n_frames_out
  195. ) {
  196. // Center padding
  197. int pad_amount = center ? n_fft / 2 : 0;
  198. int padded_len = n_samples + 2 * pad_amount;
  199. // Calculate number of frames
  200. // PyTorch formula: (L - N) / H + 1
  201. int n_frames = 1 + (padded_len - n_fft) / hop_length;
  202. if (n_frames < 0) n_frames = 0;
  203. *n_frames_out = n_frames;
  204. // Prepare padding buffer (thread-local or single allocation if not parallel?
  205. // Padding + Windowing is usually fast, but padding needs full copy.)
  206. // For safety and simplicity, let's allocate padded audio once here (It's one large buffer).
  207. // The previous implementation used thread_local for 'padded_audio' which is wrong because
  208. // 'padded_audio' needs to hold the WHOLE signal? No, stft.h:52 says 'padded_audio'.
  209. // Analyzing original code: It copied the WHOLE signal to 'padded_audio' inside compute_stft.
  210. // That means 'tls_buffer' was huge! If we have multiple threads, each copying full audio?
  211. // That's wasteful.
  212. // Better: Allocate 'padded' once on heap.
  213. std::vector<float> padded(padded_len);
  214. if (center) {
  215. // Reflect padding
  216. for (int i = 0; i < pad_amount; ++i) {
  217. int src_idx = pad_amount - i;
  218. if (src_idx >= n_samples) src_idx = n_samples - 1;
  219. padded[i] = audio[src_idx];
  220. }
  221. if (n_samples > 0) {
  222. std::memcpy(padded.data() + pad_amount, audio, n_samples * sizeof(float));
  223. }
  224. for (int i = 0; i < pad_amount; ++i) {
  225. int src_idx = n_samples - 2 - i;
  226. if (src_idx < 0) src_idx = 0;
  227. padded[pad_amount + n_samples + i] = audio[src_idx];
  228. }
  229. } else {
  230. std::memcpy(padded.data(), audio, n_samples * sizeof(float));
  231. }
  232. int n_freq = n_fft / 2 + 1;
  233. // Prepare window (Single copy)
  234. std::vector<float> window_padded(n_fft, 0.0f);
  235. if (win_length < n_fft) {
  236. int left = (n_fft - win_length) / 2;
  237. std::memcpy(window_padded.data() + left, window, win_length * sizeof(float));
  238. } else {
  239. std::memcpy(window_padded.data(), window, n_fft * sizeof(float));
  240. }
  241. // Prepare thread buffers
  242. int max_threads = 1;
  243. #ifdef USE_OPENMP
  244. max_threads = omp_get_max_threads();
  245. #endif
  246. std::vector<STFTBuffer> thread_buffers(max_threads);
  247. for(auto& buf : thread_buffers) buf.Resize(n_fft);
  248. // Process each frame
  249. #ifdef USE_OPENMP
  250. #pragma omp parallel for
  251. #endif
  252. for (int f = 0; f < n_frames; ++f) {
  253. int tid = 0;
  254. #ifdef USE_OPENMP
  255. tid = omp_get_thread_num();
  256. #endif
  257. STFTBuffer& buffer = thread_buffers[tid];
  258. std::vector<float>& frame = buffer.frame_in;
  259. int start = f * hop_length;
  260. for (int i = 0; i < n_fft; ++i) {
  261. frame[i] = padded[start + i] * window_padded[i];
  262. }
  263. // Compute FFT
  264. // Output pointer directly to destination
  265. // We need a place to store complex output before writing to planar output
  266. rfft(frame.data(), buffer.fft_out.data(), n_fft, buffer);
  267. // Write to output
  268. for (int k = 0; k < n_freq; ++k) {
  269. output[(k * n_frames + f) * 2 + 0] = buffer.fft_out[k].real();
  270. output[(k * n_frames + f) * 2 + 1] = buffer.fft_out[k].imag();
  271. }
  272. }
  273. }
  274. inline void compute_istft(
  275. const float* stft_data,
  276. int n_freq,
  277. int n_frames,
  278. int n_fft,
  279. int hop_length,
  280. int win_length,
  281. const float* window,
  282. bool center,
  283. int length,
  284. float* output
  285. ) {
  286. // Calculate expected output signal length
  287. int expected_len = n_fft + hop_length * (n_frames - 1);
  288. int pad_amount = center ? n_fft / 2 : 0;
  289. int output_len = (length > 0) ? length : (expected_len - 2 * pad_amount);
  290. // Prepare padded window
  291. std::vector<float> window_padded(n_fft, 0.0f);
  292. if (win_length < n_fft) {
  293. int left = (n_fft - win_length) / 2;
  294. std::memcpy(window_padded.data() + left, window, win_length * sizeof(float));
  295. } else {
  296. std::memcpy(window_padded.data(), window, n_fft * sizeof(float));
  297. }
  298. // Prepare thread buffers
  299. int max_threads = 1;
  300. #ifdef USE_OPENMP
  301. max_threads = omp_get_max_threads();
  302. #endif
  303. std::vector<STFTBuffer> thread_buffers(max_threads);
  304. for(auto& buf : thread_buffers) buf.Resize(n_fft);
  305. // Step 1: Compute all IFFTs in parallel
  306. std::vector<float> frames_time_domain(n_frames * n_fft);
  307. #ifdef USE_OPENMP
  308. #pragma omp parallel for
  309. #endif
  310. for (int f = 0; f < n_frames; ++f) {
  311. int tid = 0;
  312. #ifdef USE_OPENMP
  313. tid = omp_get_thread_num();
  314. #endif
  315. STFTBuffer& buffer = thread_buffers[tid];
  316. std::vector<Complex>& fft_in = buffer.fft_in;
  317. std::vector<float>& frame_out = buffer.frame_out;
  318. // Extract complex spectrum
  319. for (int k = 0; k < n_freq; ++k) {
  320. float re = stft_data[(k * n_frames + f) * 2 + 0];
  321. float im = stft_data[(k * n_frames + f) * 2 + 1];
  322. fft_in[k] = Complex(re, im);
  323. }
  324. // IFFT
  325. irfft(fft_in.data(), frame_out.data(), n_fft, buffer);
  326. // Store
  327. std::memcpy(&frames_time_domain[f * n_fft], frame_out.data(), n_fft * sizeof(float));
  328. }
  329. // Step 2: Overlap Add (Serial)
  330. std::vector<float> y(expected_len, 0.0f);
  331. std::vector<float> window_sum(expected_len, 0.0f);
  332. for (int f = 0; f < n_frames; ++f) {
  333. int start = f * hop_length;
  334. const float* frame_ptr = &frames_time_domain[f * n_fft];
  335. for (int i = 0; i < n_fft; ++i) {
  336. y[start + i] += frame_ptr[i] * window_padded[i];
  337. window_sum[start + i] += window_padded[i] * window_padded[i];
  338. }
  339. }
  340. // Normalize by window sum (avoid division by zero)
  341. for (int i = 0; i < expected_len; ++i) {
  342. if (window_sum[i] > 1e-8f) {
  343. y[i] /= window_sum[i];
  344. }
  345. }
  346. // Remove center padding and copy to output
  347. for (int i = 0; i < output_len; ++i) {
  348. if (pad_amount + i < expected_len) {
  349. output[i] = y[pad_amount + i];
  350. } else {
  351. output[i] = 0.0f;
  352. }
  353. }
  354. }
  355. } // namespace stft