// test_cancel_callback.cpp
// Checks how Inference::ProcessOverlapAdd reacts to its cancel callback.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <exception>
#include <iostream>
#include <string>
#include <vector>

#include "bs_roformer/inference.h"
  8. static bool IsCancelledError(const std::exception& e) {
  9. return std::string(e.what()) == "Inference cancelled";
  10. }
  11. int main() {
  12. std::cout << "Test: Cancel Callback Behavior" << std::endl;
  13. const int channels = 2;
  14. const int samples = 96;
  15. const int chunk_size = 32;
  16. const int num_overlap = 2;
  17. std::vector<float> input(samples * channels);
  18. for (int i = 0; i < samples; ++i) {
  19. input[i * channels + 0] = std::sin(0.1f * static_cast<float>(i));
  20. input[i * channels + 1] = std::cos(0.1f * static_cast<float>(i));
  21. }
  22. auto identity = [](const std::vector<float>& chunk) {
  23. return std::vector<std::vector<float>>{chunk};
  24. };
  25. // Case 1: immediate cancellation
  26. bool immediate_cancelled = false;
  27. try {
  28. (void)Inference::ProcessOverlapAdd(
  29. input,
  30. chunk_size,
  31. num_overlap,
  32. identity,
  33. nullptr,
  34. []() { return true; });
  35. } catch (const std::exception& e) {
  36. if (!IsCancelledError(e)) {
  37. std::cerr << "Unexpected exception for immediate cancel: " << e.what() << std::endl;
  38. return 1;
  39. }
  40. immediate_cancelled = true;
  41. }
  42. if (!immediate_cancelled) {
  43. std::cerr << "Immediate cancellation did not throw" << std::endl;
  44. return 1;
  45. }
  46. // Case 2: delayed cancellation
  47. int cancel_calls = 0;
  48. bool delayed_cancelled = false;
  49. try {
  50. (void)Inference::ProcessOverlapAdd(
  51. input,
  52. chunk_size,
  53. num_overlap,
  54. identity,
  55. nullptr,
  56. [&cancel_calls]() {
  57. ++cancel_calls;
  58. return cancel_calls >= 3;
  59. });
  60. } catch (const std::exception& e) {
  61. if (!IsCancelledError(e)) {
  62. std::cerr << "Unexpected exception for delayed cancel: " << e.what() << std::endl;
  63. return 1;
  64. }
  65. delayed_cancelled = true;
  66. }
  67. if (!delayed_cancelled) {
  68. std::cerr << "Delayed cancellation did not throw" << std::endl;
  69. return 1;
  70. }
  71. // Case 3: cancel callback always false should match baseline output.
  72. auto no_cancel = []() { return false; };
  73. auto baseline = Inference::ProcessOverlapAdd(input, chunk_size, num_overlap, identity);
  74. auto with_no_cancel = Inference::ProcessOverlapAdd(
  75. input,
  76. chunk_size,
  77. num_overlap,
  78. identity,
  79. nullptr,
  80. no_cancel);
  81. if (baseline.size() != with_no_cancel.size() || baseline.empty()) {
  82. std::cerr << "Output stem count mismatch in no-cancel path" << std::endl;
  83. return 1;
  84. }
  85. if (baseline[0].size() != with_no_cancel[0].size()) {
  86. std::cerr << "Output sample count mismatch in no-cancel path" << std::endl;
  87. return 1;
  88. }
  89. float max_diff = 0.0f;
  90. for (size_t i = 0; i < baseline[0].size(); ++i) {
  91. max_diff = std::max(max_diff, std::abs(baseline[0][i] - with_no_cancel[0][i]));
  92. }
  93. if (max_diff > 1e-6f) {
  94. std::cerr << "No-cancel output mismatch, max diff = " << max_diff << std::endl;
  95. return 1;
  96. }
  97. std::cout << "PASSED" << std::endl;
  98. return 0;
  99. }