test_common.cpp 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  #include "test_common.h"

  #include <cmath>
  2. //======================================================
  3. // TestContext
  4. //======================================================
  5. TestContext::TestContext(MelBandRoformer* m, size_t mem_size) : model(m) {
  6. if (!model) {
  7. std::cerr << "FATAL: Model is null in TestContext" << std::endl;
  8. exit(1);
  9. }
  10. struct ggml_init_params ctx_params = {
  11. /*.mem_size = */ mem_size,
  12. /*.mem_buffer = */ nullptr,
  13. /*.no_alloc = */ true,
  14. };
  15. ctx = ggml_init(ctx_params);
  16. gf = ggml_new_graph_custom(ctx, 16384, false); // Sufficiently large graph
  17. }
  18. TestContext::~TestContext() {
  19. if (allocr) ggml_gallocr_free(allocr);
  20. if (ctx) ggml_free(ctx);
  21. }
  22. bool TestContext::AllocateGraph() {
  23. if (!allocr) {
  24. allocr = ggml_gallocr_new(
  25. ggml_backend_get_default_buffer_type(model->GetBackend())
  26. );
  27. }
  28. return ggml_gallocr_alloc_graph(allocr, gf);
  29. }
  30. void TestContext::Compute() {
  31. ggml_backend_graph_compute(model->GetBackend(), gf);
  32. }
  33. std::vector<float> TestContext::ReadTensor(ggml_tensor* t) {
  34. size_t nelements = ggml_nelements(t);
  35. std::vector<float> buffer(nelements);
  36. ggml_backend_tensor_get(t, buffer.data(), 0, ggml_nbytes(t));
  37. return buffer;
  38. }
  39. //======================================================
  40. // GoldenTensor
  41. //======================================================
  42. GoldenTensor::GoldenTensor(const std::string& dir, const std::string& n) : name(n) {
  43. std::pair<float*, std::vector<size_t>> res = utils::load_activation(dir, name);
  44. data = res.first;
  45. shape = res.second;
  46. }
  47. GoldenTensor::~GoldenTensor() {
  48. if (data) {
  49. utils::free_npy_data(data);
  50. data = nullptr;
  51. }
  52. }
  53. GoldenTensor::GoldenTensor(GoldenTensor&& o) noexcept
  54. : data(o.data), shape(std::move(o.shape)), name(std::move(o.name)) {
  55. o.data = nullptr;
  56. }
  57. GoldenTensor& GoldenTensor::operator=(GoldenTensor&& o) noexcept {
  58. if (this != &o) {
  59. if (data) utils::free_npy_data(data);
  60. data = o.data;
  61. shape = std::move(o.shape);
  62. name = std::move(o.name);
  63. o.data = nullptr;
  64. }
  65. return *this;
  66. }
  67. size_t GoldenTensor::nelements() const {
  68. if (shape.empty()) return 0;
  69. size_t n = 1;
  70. for (size_t dim : shape) n *= dim;
  71. return n;
  72. }
  73. void GoldenTensor::PrintShape(const std::string& prefix) const {
  74. std::cout << prefix << name << " shape: [";
  75. for (size_t i = 0; i < shape.size(); ++i) {
  76. std::cout << shape[i];
  77. if (i < shape.size() - 1) std::cout << ", ";
  78. }
  79. std::cout << "]" << std::endl;
  80. }
  81. //======================================================
  82. // Helper
  83. //======================================================
  84. bool CompareAndReport(
  85. const std::string& name,
  86. const float* expected, size_t n_expected,
  87. const float* actual, size_t n_actual,
  88. float atol,
  89. float rtol
  90. ) {
  91. std::cout << "[Compare] " << name << std::endl;
  92. if (n_expected != n_actual) {
  93. std::cerr << " SIZE MISMATCH: Expected " << n_expected << ", Actual " << n_actual << std::endl;
  94. return false;
  95. }
  96. // Resolve tolerances
  97. if (atol < 0) atol = GetToleranceAtol();
  98. if (rtol < 0) rtol = GetToleranceRtol();
  99. float max_diff = 0.0f;
  100. float sum_diff = 0.0f;
  101. float max_rel_diff = 0.0f;
  102. for (size_t i = 0; i < n_expected; ++i) {
  103. float diff = std::abs(expected[i] - actual[i]);
  104. max_diff = std::max(max_diff, diff);
  105. sum_diff += diff;
  106. if (std::abs(expected[i]) > 1e-8f) {
  107. float rel = diff / std::abs(expected[i]);
  108. max_rel_diff = std::max(max_rel_diff, rel);
  109. }
  110. }
  111. float mean_diff = sum_diff / n_expected;
  112. std::cout << " max_diff: " << max_diff << " (limit " << atol << ")" << std::endl;
  113. std::cout << " mean_diff: " << mean_diff << std::endl;
  114. std::cout << " max_rel_diff: " << max_rel_diff << " (limit " << rtol << ")" << std::endl;
  115. bool match = (max_diff <= atol) || (max_rel_diff <= rtol);
  116. if (match) {
  117. std::cout << " ✓ OK" << std::endl;
  118. } else {
  119. std::cout << " ✗ MISMATCH" << std::endl;
  120. }
  121. return match;
  122. }