test_stft_consistency.cpp 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
#include "test_common.h"
#include "../src/stft.h"
#include "../src/model.h"

#include <algorithm>
#include <cstddef>
#include <exception>
#include <iostream>
#include <string>
#include <vector>
  4. int main(int argc, char** argv) {
  5. std::cout << "Test: STFT/ISTFT Consistency with PyTorch" << std::endl;
  6. // 1. Load Model to get parameters
  7. std::string model_path = GetModelPath();
  8. std::cout << "Loading model params from: " << model_path << std::endl;
  9. // We only need the model to read parameters (n_fft, etc.) from GGUF
  10. // We don't need to allocate the full graph or weights.
  11. BSRoformer model;
  12. try {
  13. model.Initialize(model_path);
  14. } catch (const std::exception& e) {
  15. std::cerr << "Failed to load model: " << e.what() << std::endl;
  16. std::cerr << "Ensure MBR_MODEL_PATH is set correctly or bs_roformer.gguf exists." << std::endl;
  17. return 1;
  18. }
  19. int n_fft = model.GetNFFT();
  20. int hop_length = model.GetHopLength();
  21. int win_length = model.GetWinLength();
  22. std::cout << "STFT Params: n_fft=" << n_fft << ", hop_length=" << hop_length << ", win_length=" << win_length << std::endl;
  23. // 2. Load Data
  24. std::string data_dir = GetTestDataDir();
  25. std::cout << "Loading test data from: " << data_dir << std::endl;
  26. GoldenTensor input_audio(data_dir, "input_audio"); // [batch, channels, samples]
  27. GoldenTensor expected_stft(data_dir, "stft_raw"); // [batch, channels, freq, time, 2]
  28. GoldenTensor expected_istft(data_dir, "istft_raw"); // [batch, channels, samples]
  29. TEST_ASSERT_LOAD(input_audio, "input_audio");
  30. TEST_ASSERT_LOAD(expected_stft, "stft_raw");
  31. TEST_ASSERT_LOAD(expected_istft, "istft_raw");
  32. input_audio.PrintShape("Input Audio");
  33. expected_stft.PrintShape("Expected STFT");
  34. expected_istft.PrintShape("Expected ISTFT");
  35. int batch = input_audio.shape[0];
  36. int channels = input_audio.shape[1];
  37. int n_samples = input_audio.shape[2];
  38. int n_freq = n_fft / 2 + 1;
  39. int expected_n_frames = expected_stft.shape[3];
  40. // 3. Prepare Window
  41. std::vector<float> window(win_length);
  42. stft::hann_window(window.data(), win_length);
  43. bool all_passed = true;
  44. // 4. Test STFT
  45. std::cout << "\n=== Testing STFT ===" << std::endl;
  46. for (int b = 0; b < batch; ++b) {
  47. for (int c = 0; c < channels; ++c) {
  48. // Extract input channel
  49. std::vector<float> in_channel(n_samples);
  50. for (int i = 0; i < n_samples; ++i) {
  51. // Determine index based on memory layout
  52. // input_audio.npy is F-contiguous [1, 2, 220500] => [220500, 2] in memory (interleaved)
  53. // Layout: L0, R0, L1, R1, ...
  54. // Index = (sample_idx * channels + channel_idx)
  55. size_t idx = ((size_t)b * n_samples + i) * channels + c;
  56. in_channel[i] = input_audio.data[idx];
  57. }
  58. // Diagnostic: print first few input values
  59. std::cout << " Input[" << b << "," << c << "] first 5: ";
  60. for (int i = 0; i < 5; ++i) std::cout << in_channel[i] << " ";
  61. std::cout << std::endl;
  62. int n_frames_calc = 0;
  63. // Buffer for output.
  64. // C++ output is [n_freq, n_frames, 2]
  65. std::vector<float> out_stft(n_freq * (expected_n_frames + 10) * 2);
  66. stft::compute_stft(
  67. in_channel.data(), n_samples, n_fft, hop_length, win_length,
  68. window.data(), true, out_stft.data(), &n_frames_calc
  69. );
  70. if (n_frames_calc != expected_n_frames) {
  71. std::cerr << " [Batch " << b << " Ch " << c << "] Frame mismatch: calc=" << n_frames_calc << ", expected=" << expected_n_frames << std::endl;
  72. all_passed = false;
  73. continue;
  74. }
  75. // Compare
  76. size_t channel_stft_size = n_freq * expected_n_frames * 2;
  77. size_t offset = b * channels * channel_stft_size + c * channel_stft_size;
  78. std::string name = "STFT_B" + std::to_string(b) + "_Ch" + std::to_string(c);
  79. if (!CompareAndReport(name,
  80. expected_stft.data + offset, channel_stft_size,
  81. out_stft.data(), channel_stft_size, 1e-3f, 1e-2f)) {
  82. all_passed = false;
  83. }
  84. }
  85. }
  86. // 5. Test ISTFT
  87. std::cout << "\n=== Testing ISTFT ===" << std::endl;
  88. for (int b = 0; b < batch; ++b) {
  89. for (int c = 0; c < channels; ++c) {
  90. size_t channel_stft_size = n_freq * expected_n_frames * 2;
  91. size_t offset = b * channels * channel_stft_size + c * channel_stft_size;
  92. // Input: expected_stft.data + offset
  93. std::vector<float> out_audio(n_samples + n_fft); // Buffer slightly larger
  94. // We pass n_samples as expected length
  95. stft::compute_istft(
  96. expected_stft.data + offset,
  97. n_freq, expected_n_frames, n_fft, hop_length, win_length,
  98. window.data(), true, n_samples, out_audio.data()
  99. );
  100. // Verify against expected_istft
  101. size_t audio_offset = b * channels * n_samples + c * n_samples;
  102. std::string name = "ISTFT_B" + std::to_string(b) + "_Ch" + std::to_string(c);
  103. if (!CompareAndReport(name,
  104. expected_istft.data + audio_offset, n_samples,
  105. out_audio.data(), n_samples, 1e-4f, 1e-3f)) {
  106. all_passed = false;
  107. }
  108. }
  109. }
  110. if (all_passed) {
  111. LOG_PASS();
  112. return 0;
  113. } else {
  114. LOG_FAIL();
  115. return 1;
  116. }
  117. }