stft.h 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371
  1. #pragma once
/**
 * stft.h - STFT/ISTFT implementation
 *
 * Implements:
 * - Hann window generation
 * - Center padding (reflect mode)
 * - Frame extraction
 * - Radix-2 Cooley-Tukey FFT
 * - Real-to-complex FFT (rfft)
 * - Inverse FFT (irfft)
 * - Full STFT/ISTFT matching torch.stft/torch.istft
 */
  14. #include <cmath>
  15. #include <vector>
  16. #include <complex>
  17. #include <cstring>
  18. #include <algorithm> // for std::swap
  19. #ifndef M_PI
  20. #define M_PI 3.14159265358979323846
  21. #endif
  22. namespace stft {
  23. // Complex number type
  24. using Complex = std::complex<float>;
  25. //=============================================================================
  26. // Window Functions
  27. //=============================================================================
  28. /**
  29. * Generate Hann window matching torch.hann_window()
  30. * PyTorch uses periodic=True by default for STFT compatibility
  31. * Periodic formula: 0.5 * (1 - cos(2*pi*n / N))
  32. * Symmetric formula: 0.5 * (1 - cos(2*pi*n / (N-1)))
  33. */
  34. inline void hann_window(float* out, int size, bool periodic = true) {
  35. int divisor = periodic ? size : (size - 1);
  36. for (int i = 0; i < size; ++i) {
  37. out[i] = 0.5f * (1.0f - std::cos(2.0f * static_cast<float>(M_PI) * i / divisor));
  38. }
  39. }
  40. //=============================================================================
  41. // FFT Implementation (Cooley-Tukey Radix-2)
  42. //=============================================================================
  43. /**
  44. * Bit-reversal permutation for radix-2 FFT
  45. */
  46. inline void bit_reverse(Complex* data, int n) {
  47. int j = 0;
  48. for (int i = 0; i < n - 1; ++i) {
  49. if (i < j) {
  50. std::swap(data[i], data[j]);
  51. }
  52. int m = n >> 1;
  53. while (j >= m && m > 0) {
  54. j -= m;
  55. m >>= 1;
  56. }
  57. j += m;
  58. }
  59. }
  60. /**
  61. * In-place Cooley-Tukey radix-2 FFT
  62. * @param data Complex array of size n (must be power of 2)
  63. * @param n Size of array
  64. * @param inverse If true, compute inverse FFT
  65. */
  66. inline void fft_radix2(Complex* data, int n, bool inverse = false) {
  67. bit_reverse(data, n);
  68. // Danielson-Lanczos lemma
  69. for (int len = 2; len <= n; len <<= 1) {
  70. float angle = (inverse ? 2.0f : -2.0f) * static_cast<float>(M_PI) / len;
  71. Complex w_n(std::cos(angle), std::sin(angle));
  72. for (int i = 0; i < n; i += len) {
  73. Complex w(1.0f, 0.0f);
  74. for (int j = 0; j < len / 2; ++j) {
  75. Complex u = data[i + j];
  76. Complex t = w * data[i + j + len / 2];
  77. data[i + j] = u + t;
  78. data[i + j + len / 2] = u - t;
  79. w *= w_n;
  80. }
  81. }
  82. }
  83. // Normalize for inverse FFT
  84. if (inverse) {
  85. for (int i = 0; i < n; ++i) {
  86. data[i] /= static_cast<float>(n);
  87. }
  88. }
  89. }
  90. /**
  91. * Real-to-complex FFT (rfft) matching torch.fft.rfft
  92. * @param input Real input array of size n
  93. * @param output Complex output array of size n/2+1
  94. * @param n Size of input (must be power of 2)
  95. */
  96. inline void rfft(const float* input, Complex* output, int n) {
  97. // Copy to complex buffer
  98. std::vector<Complex> buffer(n);
  99. for (int i = 0; i < n; ++i) {
  100. buffer[i] = Complex(input[i], 0.0f);
  101. }
  102. // Compute full FFT
  103. fft_radix2(buffer.data(), n, false);
  104. // Extract first n/2+1 coefficients (one-sided)
  105. int n_out = n / 2 + 1;
  106. for (int i = 0; i < n_out; ++i) {
  107. output[i] = buffer[i];
  108. }
  109. }
  110. /**
  111. * Complex-to-real inverse FFT (irfft) matching torch.fft.irfft
  112. * @param input Complex input array of size n/2+1
  113. * @param output Real output array of size n
  114. * @param n_out Size of output (must be power of 2)
  115. */
  116. inline void irfft(const Complex* input, float* output, int n_out) {
  117. int n_freq = n_out / 2 + 1;
  118. // Reconstruct full spectrum (conjugate symmetry)
  119. std::vector<Complex> buffer(n_out);
  120. for (int i = 0; i < n_freq; ++i) {
  121. buffer[i] = input[i];
  122. }
  123. for (int i = n_freq; i < n_out; ++i) {
  124. buffer[i] = std::conj(buffer[n_out - i]);
  125. }
  126. // Compute inverse FFT
  127. fft_radix2(buffer.data(), n_out, true);
  128. // Extract real part
  129. for (int i = 0; i < n_out; ++i) {
  130. output[i] = buffer[i].real();
  131. }
  132. }
  133. //=============================================================================
  134. // STFT Implementation
  135. //=============================================================================
  136. /**
  137. * Short-Time Fourier Transform matching torch.stft
  138. *
  139. * @param audio Input audio [n_samples]
  140. * @param n_samples Number of samples
  141. * @param n_fft FFT size
  142. * @param hop_length Hop between frames
  143. * @param win_length Window length
  144. * @param window Window function [win_length]
  145. * @param center If true, pad signal on both sides
  146. * @param output Output complex spectrogram [n_freq, n_frames, 2] (real, imag pairs)
  147. * @param n_frames_out Output parameter: number of frames
  148. */
  149. inline void compute_stft(
  150. const float* audio,
  151. int n_samples,
  152. int n_fft,
  153. int hop_length,
  154. int win_length,
  155. const float* window,
  156. bool center,
  157. float* output,
  158. int* n_frames_out
  159. ) {
  160. // Center padding
  161. int pad_amount = center ? n_fft / 2 : 0;
  162. int padded_len = n_samples + 2 * pad_amount;
  163. std::vector<float> padded(padded_len);
  164. if (center) {
  165. // Reflect padding
  166. // Left pad (reflect)
  167. for (int i = 0; i < pad_amount; ++i) {
  168. int src_idx = pad_amount - i;
  169. if (src_idx >= n_samples) src_idx = n_samples - 1;
  170. padded[i] = audio[src_idx];
  171. }
  172. // Center (copy)
  173. if (n_samples > 0) {
  174. std::memcpy(padded.data() + pad_amount, audio, n_samples * sizeof(float));
  175. }
  176. // Right pad (reflect)
  177. for (int i = 0; i < pad_amount; ++i) {
  178. int src_idx = n_samples - 2 - i;
  179. if (src_idx < 0) src_idx = 0;
  180. padded[pad_amount + n_samples + i] = audio[src_idx];
  181. }
  182. } else {
  183. std::memcpy(padded.data(), audio, n_samples * sizeof(float));
  184. }
  185. // Calculate number of frames
  186. // PyTorch formula: (L - N) / H + 1
  187. int n_frames = 1 + (padded_len - n_fft) / hop_length;
  188. if (n_frames < 0) n_frames = 0;
  189. *n_frames_out = n_frames;
  190. // Number of output frequency bins
  191. int n_freq = n_fft / 2 + 1;
  192. // Prepare padded window if win_length < n_fft
  193. std::vector<float> window_padded(n_fft, 0.0f);
  194. if (win_length < n_fft) {
  195. int left = (n_fft - win_length) / 2;
  196. std::memcpy(window_padded.data() + left, window, win_length * sizeof(float));
  197. } else {
  198. std::memcpy(window_padded.data(), window, n_fft * sizeof(float));
  199. }
  200. // Buffers
  201. std::vector<float> frame(n_fft);
  202. std::vector<Complex> fft_out(n_freq);
  203. // Process each frame
  204. #ifdef USE_OPENMP
  205. #pragma omp parallel for
  206. #endif
  207. for (int f = 0; f < n_frames; ++f) {
  208. int start = f * hop_length;
  209. // Extract and window frame
  210. // Need private buffer for frame and fft_out if logical threads share memory?
  211. // Wait, std::vector inside loop is local to block, so essentially thread-private?
  212. // YES. Variables declared inside the loop are private to the iteration/thread.
  213. // However, we need to be careful about allocating vectors inside a loop in parallel (heap contention).
  214. // It's better to allocate buffers per thread or use raw arrays.
  215. // For simplicity and since n_fft is small (2048), stack array or thread_local vector is better.
  216. // But std::vector inside parallel for is safe but might allocate.
  217. // n_fft=2048 float is 8KB.
  218. std::vector<float> frame(n_fft); // Allocation!
  219. std::vector<Complex> fft_out(n_freq);
  220. for (int i = 0; i < n_fft; ++i) {
  221. frame[i] = padded[start + i] * window_padded[i];
  222. }
  223. // Compute FFT
  224. rfft(frame.data(), fft_out.data(), n_fft);
  225. // Store in output [n_freq, n_frames, 2] format
  226. for (int k = 0; k < n_freq; ++k) {
  227. // Note: Output layout is [Freq, Time, 2]
  228. output[(k * n_frames + f) * 2 + 0] = fft_out[k].real();
  229. output[(k * n_frames + f) * 2 + 1] = fft_out[k].imag();
  230. }
  231. }
  232. }
  233. /**
  234. * Inverse Short-Time Fourier Transform matching torch.istft
  235. *
  236. * @param stft_data Input complex spectrogram [n_freq, n_frames, 2]
  237. * @param n_freq Number of frequency bins
  238. * @param n_frames Number of frames
  239. * @param n_fft FFT size
  240. * @param hop_length Hop between frames
  241. * @param win_length Window length
  242. * @param window Window function [win_length]
  243. * @param center If true, signal was centered
  244. * @param length Expected output length (or 0 for auto)
  245. * @param output Output audio
  246. */
  247. inline void compute_istft(
  248. const float* stft_data,
  249. int n_freq,
  250. int n_frames,
  251. int n_fft,
  252. int hop_length,
  253. int win_length,
  254. const float* window,
  255. bool center,
  256. int length,
  257. float* output
  258. ) {
  259. // Calculate expected output signal length
  260. int expected_len = n_fft + hop_length * (n_frames - 1);
  261. int pad_amount = center ? n_fft / 2 : 0;
  262. int output_len = (length > 0) ? length : (expected_len - 2 * pad_amount);
  263. // Prepare padded window
  264. std::vector<float> window_padded(n_fft, 0.0f);
  265. if (win_length < n_fft) {
  266. int left = (n_fft - win_length) / 2;
  267. std::memcpy(window_padded.data() + left, window, win_length * sizeof(float));
  268. } else {
  269. std::memcpy(window_padded.data(), window, n_fft * sizeof(float));
  270. }
  271. // Overlap-add buffers
  272. // This is tricky for parallelization: race condition on y (overlap-add).
  273. // We CANNOT parallelize the write to 'y' easily without atomic float add (slow/hard) or reduction.
  274. // APPROACH:
  275. // 1. Parallel IFFT: Compute all frames' time-domain signals into a large buffer [n_frames, n_fft].
  276. // 2. Serial Overlap-Add: Add them up. (Overlap-add is O(N_Frames * N_FFT), same complexity, but memory bound).
  277. // Serial part might be fast enough if FFT is the heavy lifter.
  278. // FFT is O(N log N). Overlap add is O(N). FFT dominates.
  279. // Step 1: Compute all IFFTs in parallel
  280. std::vector<float> frames_time_domain(n_frames * n_fft);
  281. #ifdef USE_OPENMP
  282. #pragma omp parallel for
  283. #endif
  284. for (int f = 0; f < n_frames; ++f) {
  285. std::vector<Complex> fft_in(n_freq);
  286. std::vector<float> frame_out(n_fft);
  287. // Extract complex spectrum
  288. for (int k = 0; k < n_freq; ++k) {
  289. float re = stft_data[(k * n_frames + f) * 2 + 0];
  290. float im = stft_data[(k * n_frames + f) * 2 + 1];
  291. fft_in[k] = Complex(re, im);
  292. }
  293. // IFFT
  294. irfft(fft_in.data(), frame_out.data(), n_fft);
  295. // Store
  296. std::memcpy(&frames_time_domain[f * n_fft], frame_out.data(), n_fft * sizeof(float));
  297. }
  298. // Step 2: Overlap Add (Serial)
  299. std::vector<float> y(expected_len, 0.0f);
  300. std::vector<float> window_sum(expected_len, 0.0f);
  301. for (int f = 0; f < n_frames; ++f) {
  302. int start = f * hop_length;
  303. const float* frame_ptr = &frames_time_domain[f * n_fft];
  304. for (int i = 0; i < n_fft; ++i) {
  305. y[start + i] += frame_ptr[i] * window_padded[i];
  306. window_sum[start + i] += window_padded[i] * window_padded[i];
  307. }
  308. }
  309. // Normalize by window sum (avoid division by zero)
  310. for (int i = 0; i < expected_len; ++i) {
  311. if (window_sum[i] > 1e-8f) {
  312. y[i] /= window_sum[i];
  313. }
  314. }
  315. // Remove center padding and copy to output
  316. for (int i = 0; i < output_len; ++i) {
  317. if (pad_amount + i < expected_len) {
  318. output[i] = y[pad_amount + i];
  319. } else {
  320. output[i] = 0.0f;
  321. }
  322. }
  323. }
  324. } // namespace stft