test_component_layers.cpp 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <memory>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

#include <ggml.h>
#include <ggml-alloc.h>
#include <ggml-backend.h>

#include "../src/model.h"
#include "../src/utils.h"
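/**
 * test_component_layers.cpp
 *
 * Verifies the Transformer layers against golden tensors exported by export_debug.py.
 * Copied from tests_old/test_component_layers.cpp, with added environment-variable support:
 * MBR_MODEL_PATH and MBR_TEST_DATA_DIR select the model and golden-tensor directory,
 * and argv[1] / argv[2] override them when given.
 */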
  18. std::string GetModelPath() {
  19. const char* env = std::getenv("MBR_MODEL_PATH");
  20. return env ? env : "mel_band_roformer.gguf";
  21. }
  22. std::string GetTestDataDir() {
  23. const char* env = std::getenv("MBR_TEST_DATA_DIR");
  24. return env ? env : ".";
  25. }
  26. int main(int argc, char* argv[]) {
  27. std::cout << "========================================" << std::endl;
  28. std::cout << "Test: Transformer Layers Verification" << std::endl;
  29. std::cout << "========================================" << std::endl;
  30. std::string model_path = GetModelPath();
  31. std::string debug_dir = GetTestDataDir();
  32. if (argc > 1) model_path = argv[1];
  33. if (argc > 2) debug_dir = argv[2];
  34. try {
  35. // 1. Load Model
  36. std::cout << "\n[1/6] Loading model..." << std::endl;
  37. MelBandRoformer model;
  38. model.Initialize(model_path);
  39. // 2. Load golden tensors
  40. std::cout << "\n[2/6] Loading golden tensors..." << std::endl;
  41. // Load after_band_split (input to Transformers)
  42. auto [input_data, input_shape] = utils::load_activation(debug_dir, "after_band_split");
  43. if (!input_data) {
  44. std::cerr << "Failed to load after_band_split.npy" << std::endl;
  45. return 1;
  46. }
  47. std::cout << " Input (after_band_split) shape: [";
  48. for (size_t i = 0; i < input_shape.size(); ++i) {
  49. std::cout << input_shape[i];
  50. if (i < input_shape.size() - 1) std::cout << ", ";
  51. }
  52. std::cout << "]" << std::endl;
  53. // Load before_mask_est (expected output after all 6 layers)
  54. auto [expected_data, expected_shape] = utils::load_activation(debug_dir, "before_mask_est");
  55. if (!expected_data) {
  56. std::cerr << "Failed to load before_mask_est.npy" << std::endl;
  57. utils::free_npy_data(input_data);
  58. return 1;
  59. }
  60. std::cout << " Expected (before_mask_est) shape: [";
  61. for (size_t i = 0; i < expected_shape.size(); ++i) {
  62. std::cout << expected_shape[i];
  63. if (i < expected_shape.size() - 1) std::cout << ", ";
  64. }
  65. std::cout << "]" << std::endl;
  66. // Extract dimensions from shapes
  67. // PyTorch: [batch, time, bands, dim]
  68. int batch = static_cast<int>(input_shape[0]);
  69. int n_frames = static_cast<int>(input_shape[1]);
  70. int n_bands = static_cast<int>(input_shape[2]);
  71. int dim = static_cast<int>(input_shape[3]);
  72. std::cout << " batch=" << batch << ", n_frames=" << n_frames
  73. << ", n_bands=" << n_bands << ", dim=" << dim << std::endl;
  74. // 3. Build computation graph
  75. std::cout << "\n[3/6] Building computation graph..." << std::endl;
  76. size_t mem_size = 1024 * 1024 * 1024; // 1GB for Transformers
  77. struct ggml_init_params ctx_params = {
  78. /*.mem_size = */ mem_size,
  79. /*.mem_buffer = */ nullptr,
  80. /*.no_alloc = */ true,
  81. };
  82. ggml_context* ctx = ggml_init(ctx_params);
  83. // Expanded position tensors for CUDA RoPE compatibility:
  84. // pos_time_exp: size [T * F * B], repeating [0..T-1] for each F*B batch
  85. // pos_freq_exp: size [F * T * B], repeating [0..F-1] for each T*B batch
  86. int time_exp_size = n_frames * n_bands * batch; // T * F * B
  87. int freq_exp_size = n_bands * n_frames * batch; // F * T * B
  88. std::vector<int32_t> pos_time_exp_data(time_exp_size);
  89. for (int i = 0; i < time_exp_size; ++i) {
  90. pos_time_exp_data[i] = i % n_frames; // Repeat [0..T-1]
  91. }
  92. std::vector<int32_t> pos_freq_exp_data(freq_exp_size);
  93. for (int i = 0; i < freq_exp_size; ++i) {
  94. pos_freq_exp_data[i] = i % n_bands; // Repeat [0..F-1]
  95. }
  96. ggml_cgraph* gf = ggml_new_graph_custom(ctx, 32768, false);
  97. // Create input tensor: [dim, bands, time, batch] (GGML order)
  98. ggml_tensor* input = ggml_new_tensor_4d(ctx, GGML_TYPE_F32,
  99. dim, n_bands, n_frames, batch);
  100. ggml_set_input(input);
  101. // Create expanded position tensors for RoPE
  102. ggml_tensor* pos_time_exp = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, time_exp_size);
  103. ggml_set_input(pos_time_exp);
  104. ggml_tensor* pos_freq_exp = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, freq_exp_size);
  105. ggml_set_input(pos_freq_exp);
  106. // Build Transformers graph
  107. ggml_tensor* x = model.BuildTransformersGraph(ctx, input, gf, pos_time_exp, pos_freq_exp, n_frames, batch);
  108. if (!x) {
  109. std::cerr << "FAILED: BuildTransformersGraph returned nullptr" << std::endl;
  110. utils::free_npy_data(input_data);
  111. utils::free_npy_data(expected_data);
  112. ggml_free(ctx);
  113. return 1;
  114. }
  115. // Mark output
  116. ggml_tensor* output = ggml_dup(ctx, x);
  117. ggml_set_output(output);
  118. ggml_build_forward_expand(gf, output);
  119. std::cout << " Graph built with " << ggml_graph_n_nodes(gf) << " nodes" << std::endl;
  120. // 4. Allocate and execute
  121. std::cout << "\n[4/6] Allocating graph..." << std::endl;
  122. ggml_gallocr_t allocr = ggml_gallocr_new(
  123. ggml_backend_get_default_buffer_type(model.GetBackend())
  124. );
  125. if (!ggml_gallocr_alloc_graph(allocr, gf)) {
  126. std::cerr << "FAILED: Failed to allocate graph" << std::endl;
  127. utils::free_npy_data(input_data);
  128. utils::free_npy_data(expected_data);
  129. ggml_gallocr_free(allocr);
  130. ggml_free(ctx);
  131. return 1;
  132. }
  133. std::cout << "\n[5/6] Executing graph..." << std::endl;
  134. // Copy input data
  135. ggml_backend_tensor_set(input, input_data, 0, ggml_nbytes(input));
  136. // Copy expanded position tensors
  137. ggml_backend_tensor_set(pos_time_exp, pos_time_exp_data.data(), 0, ggml_nbytes(pos_time_exp));
  138. ggml_backend_tensor_set(pos_freq_exp, pos_freq_exp_data.data(), 0, ggml_nbytes(pos_freq_exp));
  139. // Compute
  140. ggml_backend_graph_compute(model.GetBackend(), gf);
  141. // 5. Compare results
  142. std::cout << "\n[6/6] Comparing results..." << std::endl;
  143. // Copy output from GPU to CPU for comparison
  144. std::vector<float> output_data(ggml_nelements(output));
  145. ggml_backend_tensor_get(output, output_data.data(), 0, ggml_nbytes(output));
  146. // Compare element counts
  147. size_t expected_nelements = utils::shape_nelements(expected_shape);
  148. std::cout << " Output elements: " << output_data.size() << std::endl;
  149. std::cout << " Expected elements: " << expected_nelements << std::endl;
  150. // Compute comparison statistics directly
  151. float max_abs = 0.0f;
  152. float sum_abs = 0.0f;
  153. for (size_t i = 0; i < output_data.size() && i < expected_nelements; ++i) {
  154. float diff = std::abs(expected_data[i] - output_data[i]);
  155. max_abs = std::max(max_abs, diff);
  156. sum_abs += diff;
  157. }
  158. float mean_abs = sum_abs / output_data.size();
  159. std::cout << "\n[Comparison] Transformers Output" << std::endl;
  160. std::cout << " Max abs diff: " << max_abs << std::endl;
  161. std::cout << " Mean abs diff: " << mean_abs << std::endl;
  162. bool match = max_abs <= 3e-2f || mean_abs <= 3e-3f;
  163. // Cleanup
  164. utils::free_npy_data(input_data);
  165. utils::free_npy_data(expected_data);
  166. ggml_gallocr_free(allocr);
  167. ggml_free(ctx);
  168. if (match) {
  169. std::cout << "\nPASSED: Transformers match PyTorch output" << std::endl;
  170. return 0;
  171. } else {
  172. std::cout << "\nFAILED: Transformers do not match PyTorch output" << std::endl;
  173. return 1;
  174. }
  175. } catch (const std::exception& e) {
  176. std::cerr << "Error: " << e.what() << std::endl;
  177. return 1;
  178. }
  179. }