test_inference.cpp 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. #include <iostream>
  2. #include <vector>
  3. #include <cmath>
  4. #include <string>
  5. #include <cstdlib>
  6. #include <algorithm>
  7. #include "bs_roformer/inference.h"
  8. #include "../src/utils.h"
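/**
 * test_inference.cpp
 *
 * Verifies the Inference class against golden tensors exported by
 * export_debug.py (input_audio.npy is fed to the engine, and the result is
 * compared with output_audio.npy).
 * Copied from tests_old/test_inference.cpp, with environment-variable support:
 *   model path: argv[1], else BSR_MODEL_PATH, else "bs_roformer.gguf"
 *   data dir:   argv[2], else BSR_TEST_DATA_DIR, else "."
 */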
  15. std::string GetModelPath() {
  16. const char* env = std::getenv("BSR_MODEL_PATH");
  17. return env ? env : "bs_roformer.gguf";
  18. }
  19. std::string GetTestDataDir() {
  20. const char* env = std::getenv("BSR_TEST_DATA_DIR");
  21. return env ? env : ".";
  22. }
  23. int main(int argc, char* argv[]) {
  24. std::cout << "========================================" << std::endl;
  25. std::cout << "Test: Inference Class Verification" << std::endl;
  26. std::cout << "========================================" << std::endl;
  27. std::string model_path = GetModelPath();
  28. std::string debug_dir = GetTestDataDir();
  29. if (argc > 1) model_path = argv[1];
  30. if (argc > 2) debug_dir = argv[2];
  31. try {
  32. // 1. Initialize Inference
  33. std::cout << "\n[1/3] Initializing Inference Engine..." << std::endl;
  34. Inference engine(model_path);
  35. // 2. Load Input Audio
  36. std::cout << "\n[2/3] Loading Input Audio..." << std::endl;
  37. auto [input_audio_ptr, input_audio_shape] = utils::load_activation(debug_dir, "input_audio");
  38. if (!input_audio_ptr) return 1;
  39. // Convert to vector (input_audio.npy is [batch, channels, samples] interleaved)
  40. // input_audio_shape: [1, 2, 132300]
  41. size_t total_samples = input_audio_shape[0] * input_audio_shape[1] * input_audio_shape[2];
  42. std::vector<float> input_audio(input_audio_ptr, input_audio_ptr + total_samples);
  43. // 3. Process
  44. std::cout << "\n[3/3] Processing Audio..." << std::endl;
  45. // Use ProcessChunk to verify raw model output without Overlap-Add windowing/padding
  46. // This matches the generation of output_audio.npy
  47. std::vector<std::vector<float>> output_stems = engine.ProcessChunk(input_audio);
  48. std::vector<float> output_audio = output_stems[0];
  49. // Smoke test new cancel callback path in Process()
  50. size_t smoke_samples = std::min<size_t>(input_audio.size(), static_cast<size_t>(16384 * 2));
  51. if (smoke_samples % 2 != 0) {
  52. smoke_samples -= 1;
  53. }
  54. if (smoke_samples >= 2) {
  55. std::vector<float> smoke_input(input_audio.begin(), input_audio.begin() + smoke_samples);
  56. auto cancel_false = []() { return false; };
  57. auto smoke_stems = engine.Process(smoke_input, 16384, 2, nullptr, cancel_false);
  58. if (smoke_stems.empty() || smoke_stems[0].empty()) {
  59. std::cerr << "Process() smoke test returned empty output" << std::endl;
  60. return 1;
  61. }
  62. }
  63. std::cout << " Input size: " << input_audio.size() << std::endl;
  64. std::cout << " Output size: " << output_audio.size() << std::endl;
  65. // Verify against output_audio.npy
  66. std::cout << "\n[Verification] Comparing against golden output..." << std::endl;
  67. auto [expected_output, expected_shape] = utils::load_activation(debug_dir, "output_audio");
  68. if (!expected_output) {
  69. std::cerr << "Golden output not found" << std::endl;
  70. return 1;
  71. }
  72. // expected_output: [batch=1, channels=2, samples=132300] (Planar/C-contiguous)
  73. // output_audio: interleaved [ch0, ch1, ch0, ch1...]
  74. int channels = 2;
  75. int samples = input_audio_shape[2]; // 132300
  76. float max_diff = 0.0f;
  77. float sum_diff = 0.0f;
  78. int valid_samples = 0;
  79. for (int i = 0; i < samples; ++i) {
  80. for (int ch = 0; ch < channels; ++ch) {
  81. // Expected: ch * samples + i
  82. float expected = expected_output[ch * samples + i];
  83. // Actual: i * channels + ch
  84. if (i * channels + ch >= output_audio.size()) continue;
  85. float actual = output_audio[i * channels + ch];
  86. float diff = std::abs(expected - actual);
  87. max_diff = std::max(max_diff, diff);
  88. sum_diff += diff;
  89. valid_samples++;
  90. }
  91. }
  92. if (valid_samples == 0) valid_samples = 1;
  93. std::cout << " Max abs diff: " << max_diff << std::endl;
  94. std::cout << " Mean abs diff: " << (sum_diff / valid_samples) << std::endl;
  95. bool pass = (sum_diff / valid_samples) < 0.1f;
  96. if (pass) std::cout << "PASSED" << std::endl;
  97. else std::cout << "FAILED" << std::endl;
  98. utils::free_npy_data(input_audio_ptr);
  99. utils::free_npy_data(expected_output);
  100. return pass ? 0 : 1;
  101. } catch (const std::exception& e) {
  102. std::cerr << "Error: " << e.what() << std::endl;
  103. return 1;
  104. }
  105. }