test_inference.cpp 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdlib>
#include <iostream>
#include <string>
#include <vector>

#include "mel_band_roformer/inference.h"
#include "../src/utils.h"
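/**
 * test_inference.cpp
 *
 * Verifies the Inference class output against golden tensors exported by
 * export_debug.py. The model path and golden-data directory are taken from
 * MBR_MODEL_PATH / MBR_TEST_DATA_DIR, and can be overridden by argv[1] / argv[2].
 * Copied from tests_old/test_inference.cpp, with environment-variable support added.
 */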
  14. std::string GetModelPath() {
  15. const char* env = std::getenv("MBR_MODEL_PATH");
  16. return env ? env : "mel_band_roformer.gguf";
  17. }
  18. std::string GetTestDataDir() {
  19. const char* env = std::getenv("MBR_TEST_DATA_DIR");
  20. return env ? env : ".";
  21. }
  22. int main(int argc, char* argv[]) {
  23. std::cout << "========================================" << std::endl;
  24. std::cout << "Test: Inference Class Verification" << std::endl;
  25. std::cout << "========================================" << std::endl;
  26. std::string model_path = GetModelPath();
  27. std::string debug_dir = GetTestDataDir();
  28. if (argc > 1) model_path = argv[1];
  29. if (argc > 2) debug_dir = argv[2];
  30. try {
  31. // 1. Initialize Inference
  32. std::cout << "\n[1/3] Initializing Inference Engine..." << std::endl;
  33. Inference engine(model_path);
  34. // 2. Load Input Audio
  35. std::cout << "\n[2/3] Loading Input Audio..." << std::endl;
  36. auto [input_audio_ptr, input_audio_shape] = utils::load_activation(debug_dir, "input_audio");
  37. if (!input_audio_ptr) return 1;
  38. // Convert to vector (input_audio.npy is [batch, channels, samples] interleaved)
  39. // input_audio_shape: [1, 2, 132300]
  40. size_t total_samples = input_audio_shape[0] * input_audio_shape[1] * input_audio_shape[2];
  41. std::vector<float> input_audio(input_audio_ptr, input_audio_ptr + total_samples);
  42. // 3. Process
  43. std::cout << "\n[3/3] Processing Audio..." << std::endl;
  44. // Use ProcessChunk to verify raw model output without Overlap-Add windowing/padding
  45. // This matches the generation of output_audio.npy
  46. std::vector<std::vector<float>> output_stems = engine.ProcessChunk(input_audio);
  47. std::vector<float> output_audio = output_stems[0];
  48. std::cout << " Input size: " << input_audio.size() << std::endl;
  49. std::cout << " Output size: " << output_audio.size() << std::endl;
  50. // Verify against output_audio.npy
  51. std::cout << "\n[Verification] Comparing against golden output..." << std::endl;
  52. auto [expected_output, expected_shape] = utils::load_activation(debug_dir, "output_audio");
  53. if (!expected_output) {
  54. std::cerr << "Golden output not found" << std::endl;
  55. return 1;
  56. }
  57. // expected_output: [batch=1, channels=2, samples=132300] (Planar/C-contiguous)
  58. // output_audio: interleaved [ch0, ch1, ch0, ch1...]
  59. int channels = 2;
  60. int samples = input_audio_shape[2]; // 132300
  61. float max_diff = 0.0f;
  62. float sum_diff = 0.0f;
  63. int valid_samples = 0;
  64. for (int i = 0; i < samples; ++i) {
  65. for (int ch = 0; ch < channels; ++ch) {
  66. // Expected: ch * samples + i
  67. float expected = expected_output[ch * samples + i];
  68. // Actual: i * channels + ch
  69. if (i * channels + ch >= output_audio.size()) continue;
  70. float actual = output_audio[i * channels + ch];
  71. float diff = std::abs(expected - actual);
  72. max_diff = std::max(max_diff, diff);
  73. sum_diff += diff;
  74. valid_samples++;
  75. }
  76. }
  77. if (valid_samples == 0) valid_samples = 1;
  78. std::cout << " Max abs diff: " << max_diff << std::endl;
  79. std::cout << " Mean abs diff: " << (sum_diff / valid_samples) << std::endl;
  80. bool pass = (sum_diff / valid_samples) < 0.1f;
  81. if (pass) std::cout << "PASSED" << std::endl;
  82. else std::cout << "FAILED" << std::endl;
  83. utils::free_npy_data(input_audio_ptr);
  84. utils::free_npy_data(expected_output);
  85. return pass ? 0 : 1;
  86. } catch (const std::exception& e) {
  87. std::cerr << "Error: " << e.what() << std::endl;
  88. return 1;
  89. }
  90. }