// test_common.h (3.5 KB)

  1. #pragma once
  2. #include <iostream>
  3. #include <vector>
  4. #include <cmath>
  5. #include <cstdlib>
  6. #include <string>
  7. #include <algorithm>
  8. #include <ggml.h>
  9. #include <ggml-alloc.h>
  10. #include <ggml-backend.h>
  11. #include "../src/model.h"
  12. #include "../src/utils.h"
  13. //======================================================
// Configuration lookup (environment-variable overrides with built-in defaults)
  15. //======================================================
  16. inline std::string GetTestDataDir() {
  17. const char* env = std::getenv("MBR_TEST_DATA_DIR");
  18. return env ? env : ".";
  19. }
  20. inline std::string GetModelPath() {
  21. const char* env = std::getenv("MBR_MODEL_PATH");
  22. return env ? env : "bs_roformer.gguf";
  23. }
  24. inline float GetToleranceAtol() {
  25. const char* env = std::getenv("MBR_TEST_ATOL");
  26. return env ? std::stof(env) : 1e-3f;
  27. }
  28. inline float GetToleranceRtol() {
  29. const char* env = std::getenv("MBR_TEST_RTOL");
  30. return env ? std::stof(env) : 1e-2f;
  31. }
  32. //======================================================
// RAII test context (TestContext)
  34. //======================================================
  35. struct TestContext {
  36. ggml_context* ctx = nullptr;
  37. ggml_cgraph* gf = nullptr;
  38. ggml_gallocr_t allocr = nullptr;
  39. BSRoformer* model = nullptr;
  40. // 初始化上下文和图
  41. TestContext(BSRoformer* m, size_t mem_size = 512 * 1024 * 1024);
  42. // 析构自动释放资源
  43. ~TestContext();
  44. // 分配图内存 (VRAM/RAM)
  45. bool AllocateGraph();
  46. // 执行计算
  47. void Compute();
  48. // 安全读取张量数据 (自动处理 GPU->CPU 拷贝)
  49. std::vector<float> ReadTensor(ggml_tensor* t);
  50. };
  51. //======================================================
// RAII golden-data loader (GoldenTensor)
  53. //======================================================
  54. struct GoldenTensor {
  55. float* data = nullptr;
  56. std::vector<size_t> shape;
  57. std::string name;
  58. GoldenTensor() = default;
  59. // 从 dir/activations/{name}.npy 加载
  60. GoldenTensor(const std::string& dir, const std::string& name);
  61. ~GoldenTensor();
  62. // 禁止拷贝
  63. GoldenTensor(const GoldenTensor&) = delete;
  64. GoldenTensor& operator=(const GoldenTensor&) = delete;
  65. // 允许移动
  66. GoldenTensor(GoldenTensor&& o) noexcept;
  67. GoldenTensor& operator=(GoldenTensor&& o) noexcept;
  68. bool valid() const { return data != nullptr; }
  69. size_t nelements() const;
  70. // 打印形状
  71. void PrintShape(const std::string& prefix = "") const;
  72. };
  73. //======================================================
// Assertion macros
  75. //======================================================
  76. #define TEST_ASSERT(cond, msg) \
  77. do { \
  78. if (!(cond)) { \
  79. std::cerr << "\n[ASSERT FAILED] " << msg << std::endl; \
  80. std::cerr << " File: " << __FILE__ << ":" << __LINE__ << std::endl; \
  81. return 1; \
  82. } \
  83. } while(0)
  84. #define TEST_ASSERT_LOAD(tensor, name) \
  85. TEST_ASSERT((tensor).valid(), "Failed to load " name ".npy from " + GetTestDataDir())
  86. //======================================================
// Helper functions
  88. //======================================================
  89. // 比较结果并打印报告
  90. bool CompareAndReport(
  91. const std::string& name,
  92. const float* expected, size_t n_expected,
  93. const float* actual, size_t n_actual,
  94. float atol = -1.0f, // < 0 means use default/env
  95. float rtol = -1.0f
  96. );
  97. // 日志宏
  98. #define LOG_STEP(step, total, msg) \
  99. std::cout << "\n[" << step << "/" << total << "] " << msg << std::endl
  100. #define LOG_PASS() std::cout << "\n✓ PASSED" << std::endl
  101. #define LOG_FAIL() std::cout << "\n✗ FAILED" << std::endl