// test_component_stft.cpp

#include <cmath>
#include <cstddef>
#include <iomanip>
#include <iostream>
#include <vector>

#include "../src/stft.h"
  6. int main() {
  7. std::cout << "Test: Component STFT/ISTFT" << std::endl;
  8. // Parameters
  9. const int sample_rate = 44100;
  10. const int n_fft = 2048;
  11. const int hop_length = 441;
  12. const int win_length = 2048;
  13. const int n_freq = n_fft / 2 + 1;
  14. const int n_samples = 44100 * 2; // 2 seconds
  15. // 1. Generate Signal (Sine wave mixture)
  16. std::vector<float> input(n_samples);
  17. for (int i = 0; i < n_samples; ++i) {
  18. float t = static_cast<float>(i) / sample_rate;
  19. input[i] = std::sin(2.0f * M_PI * 440.0f * t) +
  20. 0.5f * std::sin(2.0f * M_PI * 880.0f * t);
  21. }
  22. // 2. Generate Window
  23. std::vector<float> window(win_length);
  24. stft::hann_window(window.data(), win_length);
  25. // 3. Compute STFT
  26. int n_frames = 0;
  27. // Estimate size: n_freq * estimated_frames * 2, give some buffer
  28. std::vector<float> stft_out(n_freq * 500 * 2);
  29. stft::compute_stft(
  30. input.data(),
  31. n_samples,
  32. n_fft,
  33. hop_length,
  34. win_length,
  35. window.data(),
  36. true, // center
  37. stft_out.data(),
  38. &n_frames
  39. );
  40. std::cout << "STFT Computed: " << n_frames << " frames" << std::endl;
  41. if (n_frames == 0) {
  42. std::cerr << "Failed: 0 frames" << std::endl;
  43. return 1;
  44. }
  45. // 4. Compute ISTFT
  46. std::vector<float> output(n_samples);
  47. stft::compute_istft(
  48. stft_out.data(),
  49. n_freq,
  50. n_frames,
  51. n_fft,
  52. hop_length,
  53. win_length,
  54. window.data(),
  55. true, // center
  56. n_samples,
  57. output.data()
  58. );
  59. // 5. Verify Reconstruction (MSE/MAE)
  60. float max_diff = 0.0f;
  61. float mae = 0.0f;
  62. for (int i = 0; i < n_samples; ++i) {
  63. float diff = std::abs(input[i] - output[i]);
  64. if (diff > max_diff) max_diff = diff;
  65. mae += diff;
  66. }
  67. mae /= n_samples;
  68. std::cout << "Reconstruction Error:" << std::endl;
  69. std::cout << " Max Diff: " << max_diff << std::endl;
  70. std::cout << " MAE: " << mae << std::endl;
  71. // STFT/ISTFT with Hann window and overlap >= 50% should be near perfect
  72. // COLA constraint check: 2048/441 = ~4.6 overlaps, excellent.
  73. if (max_diff > 1e-4) {
  74. std::cerr << "FAILED: Reconstruction error too high (> 1e-4)" << std::endl;
  75. return 1;
  76. }
  77. std::cout << "PASSED" << std::endl;
  78. return 0;
  79. }