test_chunking_logic.cpp 3.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. #include "test_common.h"
  2. #include "bs_roformer/inference.h"
  3. #include <cstring>
  // This test exercises the static helper Inference::ProcessOverlapAdd directly
  // rather than the full inference pipeline. The helper is declared in
  // inference.h and implemented in inference.cpp.
  7. int main(int argc, char* argv[]) {
  8. std::cout << "Test: Chunking Logic (Overlap-Add) Verification" << std::endl;
  9. std::string data_dir = GetTestDataDir();
  10. // Use files generated by export_chunking_debug.py if available, or skip
  11. // If we included them in package_test_data.py, they might be in activations/ ??
  12. // No, export_chunking_debug.py puts them in root or specified dir.
  13. // If packaged, we might have them?
  14. // Let's assume they are in data_dir (which might be "golden" root).
  15. // Note: chunk_in.npy and chunk_out.npy are NOT in 'activations/' subdir normally
  16. // but in tests/ or root.
  17. // Let's try loading from data_dir directly.
  18. // Fallback: If not found, try generating? Or just skip?
  19. // Better: Assume they are present.
  20. // We use load_npy directly as they might not be in activations/
  21. // We use load_npy directly as they are in tests/ directory
  22. // Use data_dir (from BSR_TEST_DATA_DIR or default)
  23. std::string in_path = data_dir + "/chunk_in.npy";
  24. std::string out_path = data_dir + "/chunk_out.npy";
  25. if (argc > 1) in_path = argv[1];
  26. if (argc > 2) out_path = argv[2];
  27. auto [in_ptr, in_shape] = utils::load_npy(in_path);
  28. if (!in_ptr) {
  29. // Try checking if it's in the 'activations' subdir (legacy/alternative structure)
  30. std::string alt_in = data_dir + "/activations/chunk_in.npy";
  31. auto res = utils::load_npy(alt_in);
  32. if (res.first) {
  33. in_ptr = res.first; in_shape = res.second;
  34. in_path = alt_in;
  35. }
  36. }
  37. if (!in_ptr) {
  38. // Just print absolute path hint for debugging
  39. std::cout << "[SKIP] chunk_in.npy not found in " << data_dir << " or " << in_path << std::endl;
  40. return 0;
  41. }
  42. auto [out_ptr, out_shape] = utils::load_npy(out_path);
  43. if (!out_ptr) {
  44. std::string alt_out = data_dir + "/activations/chunk_out.npy";
  45. auto res = utils::load_npy(alt_out);
  46. if (res.first) {
  47. out_ptr = res.first; out_shape = res.second;
  48. }
  49. }
  50. if (!out_ptr) {
  51. std::cout << "[SKIP] chunk_out.npy not found" << std::endl;
  52. utils::free_npy_data(in_ptr);
  53. return 0;
  54. }
  55. std::vector<float> input_vec(utils::shape_nelements(in_shape));
  56. std::memcpy(input_vec.data(), in_ptr, input_vec.size()*sizeof(float));
  57. utils::free_npy_data(in_ptr);
  58. // Expected
  59. std::vector<float> expected_vec(utils::shape_nelements(out_shape));
  60. std::memcpy(expected_vec.data(), out_ptr, expected_vec.size()*sizeof(float));
  61. utils::free_npy_data(out_ptr);
  62. // Run Logic
  63. int chunk_size = 352800;
  64. int num_overlap = 2;
  65. std::cout << " Input size: " << input_vec.size() << std::endl;
  66. // Identity Model
  67. auto identity = [](const std::vector<float>& chunk) { return std::vector<std::vector<float>>{chunk}; };
  68. // We test the STATIC legacy method because we can't easily mock the pipeline
  69. // inside Inference class without refactoring it to accept an abstract Model interface.
  70. auto actual_stems = Inference::ProcessOverlapAdd(input_vec, chunk_size, num_overlap, identity);
  71. std::vector<float> actual = actual_stems[0];
  72. bool pass = CompareAndReport("OverlapAdd Logic",
  73. expected_vec.data(), expected_vec.size(),
  74. actual.data(), actual.size(),
  75. 1e-4f, 1e-4f);
  76. if (pass) LOG_PASS(); else LOG_FAIL();
  77. return pass ? 0 : 1;
  78. }