test_component_bandsplit.cpp 3.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. #include "test_common.h"
  2. int main(int argc, char* argv[]) {
  3. std::cout << "Test: BandSplit Component Verification" << std::endl;
  4. // 1. 获取资源
  5. std::string model_path = GetModelPath();
  6. std::string data_dir = GetTestDataDir();
  7. if (argc > 1) model_path = argv[1];
  8. if (argc > 2) data_dir = argv[2];
  9. LOG_STEP(1, 4, "Loading model from " + model_path);
  10. MelBandRoformer model;
  11. model.Initialize(model_path);
  12. LOG_STEP(2, 4, "Loading golden tensors from " + data_dir);
  13. GoldenTensor input(data_dir, "band_split_in");
  14. GoldenTensor expected(data_dir, "after_band_split");
  15. TEST_ASSERT_LOAD(input, "band_split_in");
  16. TEST_ASSERT_LOAD(expected, "after_band_split");
  17. input.PrintShape("Input");
  18. expected.PrintShape("Expected");
  19. // PyTorch [batch, bands, time, dim] -> GGML [dim, time, bands, batch] ?
  20. // Wait, utils.cpp says: load_npy returns raw data and shape.
  21. // PyTorch input: [batch, bands, time, dim]
  22. // GGML expected Input: [dim, bands, time, batch] ? No.
  23. // Let's check original test...
  24. // Original: total_dim_input(idx=2), n_frames(idx=1), batch(idx=0).
  25. // Original input: [batch, frames, dim] ??
  26. // band_split_in.npy shape from original output: [1, 301, 384] (Batch, Time, Dim)?
  27. // No, let's look at export_debug.py line 219: `x = rearrange(x, 'b t (f c) -> b t f c')` ??
  28. // Wait, export_debug.py:
  29. // x = stft_repr[batch_arange, freq_indices] -> [b, f, t, c]
  30. // x = rearrange(x, 'b f t c -> b t (f c)') -> [b, t, features]
  31. // So 'band_split_in' is [Batch, Time, Features]
  32. // GGML Tensor likely: [Features, Time, Batch] (Transposed for column-major/GGML)
  33. int batch = input.shape[0];
  34. int n_frames = input.shape[1];
  35. int total_dim = input.shape[2];
  36. // 3. Build Graph
  37. LOG_STEP(3, 4, "Building computation graph");
  38. TestContext tc(&model);
  39. // GGML Tensor shape: [dim, n_frames, batch]
  40. ggml_tensor* in_tensor = ggml_new_tensor_3d(tc.ctx, GGML_TYPE_F32, total_dim, n_frames, batch);
  41. ggml_set_input(in_tensor);
  42. ggml_tensor* out = model.BuildBandSplitGraph(tc.ctx, in_tensor, tc.gf, n_frames, batch);
  43. TEST_ASSERT(out, "BuildBandSplitGraph returned nullptr");
  44. // Mark output for computation
  45. ggml_build_forward_expand(tc.gf, out);
  46. // 4. Exec
  47. LOG_STEP(4, 4, "Executing");
  48. if (!tc.AllocateGraph()) {
  49. std::cerr << "Graph allocation failed" << std::endl;
  50. return 1;
  51. }
  52. // Copy input (NumPy [B, T, D] -> GGML [D, T, B])
  53. // The memory layout of NumPy [B,T,D] (C-contiguous) is:
  54. // Batch 0 -> Time 0 -> Dim 0..D
  55. // GGML [D, T, B] (F-contiguous-ish, but tensor struct is different)
  56. // Actually GGML default tensor is [ne0, ne1, ne2, ne3]
  57. // ne0 is fastest moving dimension.
  58. // If we say tensor is [D, T, B], ne0=D, ne1=T, ne2=B.
  59. // So data layout is D contiguous, then T, then B.
  60. // This MATCHES NumPy [B, T, D] C-contiguous!
  61. // NumPy: fast index is last dim (D).
  62. // GGML: fast index is first dim (ne0=D).
  63. // So we can memcpy directly!
  64. ggml_backend_tensor_set(in_tensor, input.data, 0, ggml_nbytes(in_tensor));
  65. tc.Compute();
  66. // 5. Compare
  67. auto output = tc.ReadTensor(out);
  68. bool pass = CompareAndReport("BandSplit",
  69. expected.data, expected.nelements(),
  70. output.data(), output.size());
  71. if (pass) LOG_PASS(); else LOG_FAIL();
  72. return pass ? 0 : 1;
  73. }