test_component_mask.cpp 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. #include "test_common.h"
  2. int main(int argc, char* argv[]) {
  3. std::cout << "Test: MaskEstimator Component Verification" << std::endl;
  4. std::string model_path = GetModelPath();
  5. std::string data_dir = GetTestDataDir();
  6. if (argc > 1) model_path = argv[1];
  7. if (argc > 2) data_dir = argv[2];
  8. LOG_STEP(1, 4, "Loading model from " + model_path);
  9. BSRoformer model;
  10. model.Initialize(model_path);
  11. LOG_STEP(2, 4, "Loading golden tensors");
  12. GoldenTensor input(data_dir, "before_mask_est");
  13. GoldenTensor expected(data_dir, "mask_est0");
  14. TEST_ASSERT_LOAD(input, "before_mask_est");
  15. TEST_ASSERT_LOAD(expected, "mask_est0");
  16. input.PrintShape("Input");
  17. expected.PrintShape("Expected");
  18. // Input PyTorch: [1, T, Bands, Dim] -> [1, 301, 60, 64] ?
  19. // Let's check export_debug.py line 246
  20. // x (before_mask_est) comes from freq_transformer.
  21. // x shape is [batch, time, bands, dim] (rearranged in line 229: b t f d)
  22. // Wait, line 229 says: x = rearrange(x, 'b f t d -> b t f d')
  23. // So input is [B, T, Bands, Dim]
  24. int batch = input.shape[0];
  25. int n_frames = input.shape[1];
  26. int n_bands = input.shape[2];
  27. int dim = input.shape[3];
  28. // 3. Build Graph
  29. LOG_STEP(3, 4, "Building computation graph");
  30. TestContext tc(&model);
  31. // GGML Input: [Dim, Bands, Frames, Batch] (ne0=Dim)
  32. // Matches NumPy [B, T, Bands, Dim] layout directly
  33. ggml_tensor* in_tensor = ggml_new_tensor_4d(tc.ctx, GGML_TYPE_F32, dim, n_bands, n_frames, batch);
  34. ggml_set_input(in_tensor);
  35. ggml_tensor* out = model.BuildMaskEstimatorGraph(tc.ctx, in_tensor, tc.gf, n_frames, batch);
  36. TEST_ASSERT(out, "BuildMaskEstimatorGraph returned nullptr");
  37. ggml_build_forward_expand(tc.gf, out);
  38. // 4. Exec
  39. LOG_STEP(4, 4, "Executing");
  40. if (!tc.AllocateGraph()) return 1;
  41. ggml_backend_tensor_set(in_tensor, input.data, 0, ggml_nbytes(in_tensor));
  42. tc.Compute();
  43. // 5. Compare
  44. auto output = tc.ReadTensor(out);
  45. // For multi-stem models (like Deux with 2 stems), the output will contain all stems.
  46. // mask_est0.npy likely only contains the first stem (or the target stem).
  47. // If output size > expected size, we should compare only the matching portion (first stem).
  48. size_t expected_size = expected.nelements();
  49. size_t actual_size = output.size();
  50. bool pass = false;
  51. if (actual_size > expected_size && actual_size % expected_size == 0) {
  52. // De-interleave Stem 0
  53. // Data layout: [Freqs, Stems, Frames, Batch] (ne0, ne1, ne2, ne3)
  54. // Stride per frame = Freqs * Stems
  55. // We want Stem 0 for each frame.
  56. std::vector<float> stem0_output;
  57. stem0_output.reserve(expected_size);
  58. int num_stems = (int)(actual_size / expected_size);
  59. int n_frames = (int)input.shape[1]; // Known from input
  60. int n_freqs = (int)(expected_size / n_frames); // Inferred Freqs per frame
  61. std::cout << "Detected multi-stem output (" << num_stems << " stems). Verifying Stem 0..." << std::endl;
  62. // Verify assumption
  63. if ((size_t)(num_stems * n_freqs * n_frames) != actual_size) {
  64. std::cerr << "Warning: Shape mismatch calculation in verification logic." << std::endl;
  65. }
  66. for (int t = 0; t < n_frames; ++t) {
  67. size_t frame_start = t * (n_freqs * num_stems);
  68. size_t stem0_start = frame_start; // Stem 0 is at offset 0 in the stride
  69. // Copy n_freqs elements
  70. for (int f = 0; f < n_freqs; ++f) {
  71. if (stem0_start + f < output.size()) {
  72. stem0_output.push_back(output[stem0_start + f]);
  73. }
  74. }
  75. }
  76. pass = CompareAndReport("MaskEstimator (Stem 0)",
  77. expected.data, expected_size,
  78. stem0_output.data(), stem0_output.size());
  79. } else {
  80. pass = CompareAndReport("MaskEstimator",
  81. expected.data, expected.nelements(),
  82. output.data(), output.size());
  83. }
  84. if (pass) LOG_PASS(); else LOG_FAIL();
  85. return pass ? 0 : 1;
  86. }