ソースを参照

feat(inference): add cancellation callback support to Process methods

Add CancelCallback type and integrate cancellation checks throughout
the inference pipeline. This allows users to cancel long-running
inference operations by providing a callback that returns true when
cancellation is requested.

Changes include:
- Add CancelCallback typedef and cancel_callback parameter to
  Process(), ProcessOverlapAdd(), and ProcessOverlapAddPipelined()
- Implement atomic cancellation state and exception handling in
  pipeline threads with graceful shutdown
- Add comprehensive unit tests for immediate, delayed, and
  non-cancellation scenarios
- Update README documentation with cancel callback usage examples

The new cancel_callback parameter is optional (defaults to nullptr)
maintaining backward compatibility with existing code.
沉默の金 4 ヶ月 前
コミット
30d648c2e3

+ 8 - 1
README.md

@@ -157,6 +157,7 @@ python scripts/convert_to_gguf.py ... --arch bs
 ## 💻 C++ API
 
 ```cpp
+#include <atomic>
 #include <bs_roformer/inference.h>
 #include <bs_roformer/audio.h>
 
@@ -170,10 +171,14 @@ Inference engine("model.gguf");
 int chunk_size = engine.GetDefaultChunkSize();   // e.g., 352800
 int num_overlap = engine.GetDefaultNumOverlap(); // e.g., 2
 
-// 4. Run inference (with progress callback)
+// 4. Run inference (with progress + cancel callback)
+std::atomic<bool> should_cancel{false};
 auto stems = engine.Process(input.data, chunk_size, num_overlap,
     [](float progress) {
         std::cout << "Progress: " << int(progress * 100) << "%" << std::endl;
+    },
+    [&should_cancel]() {
+        return should_cancel.load();
     });
 
 // 5. Save result
@@ -181,6 +186,8 @@ AudioBuffer output{stems[0], 2, 44100, stems[0].size()};
 AudioFile::Save("vocals.wav", output);
 ```
 
+If `cancel_callback` returns `true`, `Process()` throws `std::runtime_error("Inference cancelled")`.
+
 ---
 
 ## 🏗️ Project Architecture

+ 8 - 1
README.zh.md

@@ -155,6 +155,7 @@ python scripts/convert_to_gguf.py ... --arch bs
 ## 💻 C++ API
 
 ```cpp
+#include <atomic>
 #include <bs_roformer/inference.h>
 #include <bs_roformer/audio.h>
 
@@ -168,10 +169,14 @@ Inference engine("model.gguf");
 int chunk_size = engine.GetDefaultChunkSize();   // 如 352800
 int num_overlap = engine.GetDefaultNumOverlap(); // 如 2
 
-// 4. 执行推理(带进度回调)
+// 4. 执行推理(带进度回调 + 取消回调)
+std::atomic<bool> should_cancel{false};
 auto stems = engine.Process(input.data, chunk_size, num_overlap,
     [](float progress) {
         std::cout << "Progress: " << int(progress * 100) << "%" << std::endl;
+    },
+    [&should_cancel]() {
+        return should_cancel.load();
     });
 
 // 5. 保存结果
@@ -179,6 +184,8 @@ AudioBuffer output{stems[0], 2, 44100, stems[0].size()};
 AudioFile::Save("vocals.wav", output);
 ```
 
+当 `cancel_callback` 返回 `true` 时,`Process()` 会抛出 `std::runtime_error("Inference cancelled")`。
+
 ---
 
 ## 🏗️ 项目架构

+ 8 - 3
include/bs_roformer/inference.h

@@ -12,6 +12,8 @@ namespace ggml { struct context; struct cgraph; }
 
 class Inference {
 public:
+    using CancelCallback = std::function<bool()>;
+
     Inference(const std::string& model_path);
     ~Inference();
 
@@ -22,7 +24,8 @@ public:
     std::vector<std::vector<float>> Process(const std::vector<float>& input_audio, 
                                int chunk_size = 352800, 
                                int num_overlap = 2,
-                               std::function<void(float)> progress_callback = nullptr);
+                               std::function<void(float)> progress_callback = nullptr,
+                               CancelCallback cancel_callback = nullptr);
 
     // Low-level chunk processing (public for testing)
     std::vector<std::vector<float>> ProcessChunk(const std::vector<float>& chunk_audio);
@@ -39,14 +42,16 @@ public:
                                                 int chunk_size, 
                                                 int num_overlap,
                                                 ModelCallback model_func,
-                                                std::function<void(float)> progress_callback = nullptr); // Added callback
+                                                std::function<void(float)> progress_callback = nullptr,
+                                                CancelCallback cancel_callback = nullptr);
 
 private:
     // Pipelined Overlap-Add
     std::vector<std::vector<float>> ProcessOverlapAddPipelined(const std::vector<float>& input_audio, 
                                                   int chunk_size, 
                                                   int num_overlap,
-                                                  std::function<void(float)> progress_callback);
+                                                  std::function<void(float)> progress_callback,
+                                                  CancelCallback cancel_callback);
 
 private:
     std::unique_ptr<BSRoformer> model_;

+ 117 - 33
src/inference.cpp

@@ -15,8 +15,11 @@
 #include <thread>
 #include <mutex>
 #include <condition_variable>
+#include <atomic>
+#include <exception>
 
 using Complex = std::complex<float>;
+static constexpr const char* kInferenceCancelledMessage = "Inference cancelled";
 
 // Helper forward decl
 std::vector<float> GetWindow(int size, int fade_size);
@@ -289,11 +292,13 @@ void Inference::PostProcessAndISTFT(const std::vector<float>& mask_output,
     }
 }
 
-#include <future>
-
-std::vector<std::vector<float>> Inference::Process(const std::vector<float>& input_audio, int chunk_size, int num_overlap, std::function<void(float)> progress_callback) {
+std::vector<std::vector<float>> Inference::Process(const std::vector<float>& input_audio,
+                                                   int chunk_size,
+                                                   int num_overlap,
+                                                   std::function<void(float)> progress_callback,
+                                                   CancelCallback cancel_callback) {
     if (input_audio.empty()) return {};
-    return ProcessOverlapAddPipelined(input_audio, chunk_size, num_overlap, progress_callback);
+    return ProcessOverlapAddPipelined(input_audio, chunk_size, num_overlap, progress_callback, cancel_callback);
 }
 
 // =================================================================================================
@@ -450,7 +455,8 @@ private:
 std::vector<std::vector<float>> Inference::ProcessOverlapAddPipelined(const std::vector<float>& input_audio, 
                                                          int chunk_size, 
                                                          int num_overlap,
-                                                         std::function<void(float)> progress_callback) {
+                                                         std::function<void(float)> progress_callback,
+                                                         CancelCallback cancel_callback) {
     if (input_audio.empty()) return {};
     if (input_audio.size() % 2 != 0) {
         throw std::runtime_error("Error: Input audio must be interleaved stereo (even number of samples).");
@@ -505,6 +511,7 @@ std::vector<std::vector<float>> Inference::ProcessOverlapAddPipelined(const std:
     std::vector<float> counter(n_padded_samples * channels, 0.0f);
     std::vector<float> window_base = GetWindow(chunk_size, fade_size);
     std::mutex result_mutex; // Protects 'result' and 'counter'
+    std::atomic<bool> cancel_requested{false};
     
     // lambda to extract chunk 'i'
     auto extract_chunk = [&](int i) -> std::vector<float> {
@@ -590,6 +597,20 @@ std::vector<std::vector<float>> Inference::ProcessOverlapAddPipelined(const std:
     // 3 items buffer is enough to keep GPU busy
     ThreadSafeQueue<std::shared_ptr<ChunkState>> input_queue(3);
     ThreadSafeQueue<std::shared_ptr<ChunkState>> output_queue(3);
+    std::mutex exception_mutex;
+    std::exception_ptr pipeline_exception = nullptr;
+
+    auto set_pipeline_exception = [&](std::exception_ptr eptr) {
+        {
+            std::lock_guard<std::mutex> lock(exception_mutex);
+            if (!pipeline_exception) {
+                pipeline_exception = eptr;
+            }
+        }
+        cancel_requested.store(true, std::memory_order_release);
+        input_queue.Shutdown();
+        output_queue.Shutdown();
+    };
     
     // Structure to hold chunk metadata together
     struct ChunkTask {
@@ -599,51 +620,109 @@ std::vector<std::vector<float>> Inference::ProcessOverlapAddPipelined(const std:
     
     // 1. Preprocessor Thread
     auto preproccessor = std::thread([&]() {
-        int current_offset = 0;
-        while (current_offset < n_padded_samples) {
-            std::vector<float> chunk = extract_chunk(current_offset);
-            
-            auto state = PreProcessChunk(chunk, current_offset); 
-            
-            input_queue.Push(state);
-            current_offset += step;
+        try {
+            int current_offset = 0;
+            while (current_offset < n_padded_samples && !cancel_requested.load(std::memory_order_acquire)) {
+                std::vector<float> chunk = extract_chunk(current_offset);
+                
+                auto state = PreProcessChunk(chunk, current_offset); 
+                
+                input_queue.Push(state);
+                if (cancel_requested.load(std::memory_order_acquire)) {
+                    break;
+                }
+                current_offset += step;
+            }
+        } catch (...) {
+            set_pipeline_exception(std::current_exception());
         }
         input_queue.Shutdown();
     });
     
     // 3. Postprocessor Thread
     auto postprocessor = std::thread([&]() {
-        std::shared_ptr<ChunkState> state;
-        while (output_queue.Pop(state)) {
-            // This does ISTFT (CPU intensive)
-            PostProcessChunk(state);
-            
-            // Accumulate (Memory bandwidth intensive + Mutex)
-            accumulate_result(state, state->id); // state->id holds offset
-            
-            if (progress_callback) {
-                float progress = (float)std::min(state->id + step, n_padded_samples) / n_padded_samples;
-                progress_callback(progress);
+        try {
+            std::shared_ptr<ChunkState> state;
+            while (!cancel_requested.load(std::memory_order_acquire) && output_queue.Pop(state)) {
+                // This does ISTFT (CPU intensive)
+                PostProcessChunk(state);
+                if (cancel_requested.load(std::memory_order_acquire)) {
+                    break;
+                }
+                
+                // Accumulate (Memory bandwidth intensive + Mutex)
+                accumulate_result(state, state->id); // state->id holds offset
+                
+                if (!cancel_requested.load(std::memory_order_acquire) && progress_callback) {
+                    float progress = (float)std::min(state->id + step, n_padded_samples) / n_padded_samples;
+                    progress_callback(progress);
+                }
             }
+        } catch (...) {
+            set_pipeline_exception(std::current_exception());
         }
     });
     
+    auto poll_cancel_requested = [&]() -> bool {
+        if (cancel_requested.load(std::memory_order_acquire)) {
+            return true;
+        }
+        if (cancel_callback && cancel_callback()) {
+            cancel_requested.store(true, std::memory_order_release);
+            return true;
+        }
+        return false;
+    };
+
     // 2. Main Thread (Inference Loop)
+    bool cancelled = false;
     std::shared_ptr<ChunkState> state;
-    while (true) {
-        bool ok = input_queue.Pop(state);
-        if (!ok) break; // Input queue shutdown and empty
-        
-        // This does GGML Inference (GPU intensive, Blocking)
-        RunInference(state);
-        
-        output_queue.Push(state);
+    try {
+        while (true) {
+            if (poll_cancel_requested()) {
+                cancelled = true;
+                break;
+            }
+
+            bool ok = input_queue.Pop(state);
+            if (!ok) break; // Input queue shutdown and empty
+
+            if (poll_cancel_requested()) {
+                cancelled = true;
+                break;
+            }
+            
+            // This does GGML Inference (GPU intensive, Blocking)
+            RunInference(state);
+
+            if (poll_cancel_requested()) {
+                cancelled = true;
+                break;
+            }
+            
+            output_queue.Push(state);
+        }
+    } catch (...) {
+        set_pipeline_exception(std::current_exception());
     }
     
+    if (cancelled) {
+        cancel_requested.store(true, std::memory_order_release);
+        input_queue.Shutdown();
+    }
+
     // Wait for threads
     output_queue.Shutdown();
     if (preproccessor.joinable()) preproccessor.join();
     if (postprocessor.joinable()) postprocessor.join();
+
+    if (pipeline_exception) {
+        std::rethrow_exception(pipeline_exception);
+    }
+
+    if (cancel_requested.load(std::memory_order_acquire)) {
+        throw std::runtime_error(kInferenceCancelledMessage);
+    }
     
     // Normalize and Crop
     // result is [stems][samples]
@@ -676,7 +755,8 @@ std::vector<std::vector<float>> Inference::ProcessOverlapAdd(const std::vector<f
                                                 int chunk_size, 
                                                 int num_overlap,
                                                 ModelCallback model_func,
-                                                std::function<void(float)> progress_callback) {
+                                                std::function<void(float)> progress_callback,
+                                                CancelCallback cancel_callback) {
     if (input_audio.empty()) return {};
     if (input_audio.size() % 2 != 0) {
         throw std::runtime_error("Error: Input audio must be interleaved stereo (even number of samples).");
@@ -736,6 +816,10 @@ std::vector<std::vector<float>> Inference::ProcessOverlapAdd(const std::vector<f
     int total_length = n_padded_samples;
     
     while (i < total_length) {
+        if (cancel_callback && cancel_callback()) {
+            throw std::runtime_error(kInferenceCancelledMessage);
+        }
+
         int remaining = total_length - i;
         int part_len = std::min(C, remaining); // Logic matches Python slice [i:i+C]
         

+ 1 - 0
tests/CMakeLists.txt

@@ -45,3 +45,4 @@ bsr_add_test(test_component_mask)
 bsr_add_test(test_inference)
 bsr_add_test(test_chunking_logic)
 bsr_add_test(test_stft_consistency)
+bsr_add_test(test_cancel_callback)

+ 115 - 0
tests/test_cancel_callback.cpp

@@ -0,0 +1,115 @@
+#include <algorithm>
+#include <cmath>
+#include <exception>
+#include <iostream>
+#include <string>
+#include <vector>
+
+#include "bs_roformer/inference.h"
+
+static bool IsCancelledError(const std::exception& e) {
+    return std::string(e.what()) == "Inference cancelled";
+}
+
+int main() {
+    std::cout << "Test: Cancel Callback Behavior" << std::endl;
+
+    const int channels = 2;
+    const int samples = 96;
+    const int chunk_size = 32;
+    const int num_overlap = 2;
+
+    std::vector<float> input(samples * channels);
+    for (int i = 0; i < samples; ++i) {
+        input[i * channels + 0] = std::sin(0.1f * static_cast<float>(i));
+        input[i * channels + 1] = std::cos(0.1f * static_cast<float>(i));
+    }
+
+    auto identity = [](const std::vector<float>& chunk) {
+        return std::vector<std::vector<float>>{chunk};
+    };
+
+    // Case 1: immediate cancellation
+    bool immediate_cancelled = false;
+    try {
+        (void)Inference::ProcessOverlapAdd(
+            input,
+            chunk_size,
+            num_overlap,
+            identity,
+            nullptr,
+            []() { return true; });
+    } catch (const std::exception& e) {
+        if (!IsCancelledError(e)) {
+            std::cerr << "Unexpected exception for immediate cancel: " << e.what() << std::endl;
+            return 1;
+        }
+        immediate_cancelled = true;
+    }
+
+    if (!immediate_cancelled) {
+        std::cerr << "Immediate cancellation did not throw" << std::endl;
+        return 1;
+    }
+
+    // Case 2: delayed cancellation
+    int cancel_calls = 0;
+    bool delayed_cancelled = false;
+    try {
+        (void)Inference::ProcessOverlapAdd(
+            input,
+            chunk_size,
+            num_overlap,
+            identity,
+            nullptr,
+            [&cancel_calls]() {
+                ++cancel_calls;
+                return cancel_calls >= 3;
+            });
+    } catch (const std::exception& e) {
+        if (!IsCancelledError(e)) {
+            std::cerr << "Unexpected exception for delayed cancel: " << e.what() << std::endl;
+            return 1;
+        }
+        delayed_cancelled = true;
+    }
+
+    if (!delayed_cancelled) {
+        std::cerr << "Delayed cancellation did not throw" << std::endl;
+        return 1;
+    }
+
+    // Case 3: cancel callback always false should match baseline output.
+    auto no_cancel = []() { return false; };
+    auto baseline = Inference::ProcessOverlapAdd(input, chunk_size, num_overlap, identity);
+    auto with_no_cancel = Inference::ProcessOverlapAdd(
+        input,
+        chunk_size,
+        num_overlap,
+        identity,
+        nullptr,
+        no_cancel);
+
+    if (baseline.size() != with_no_cancel.size() || baseline.empty()) {
+        std::cerr << "Output stem count mismatch in no-cancel path" << std::endl;
+        return 1;
+    }
+
+    if (baseline[0].size() != with_no_cancel[0].size()) {
+        std::cerr << "Output sample count mismatch in no-cancel path" << std::endl;
+        return 1;
+    }
+
+    float max_diff = 0.0f;
+    for (size_t i = 0; i < baseline[0].size(); ++i) {
+        max_diff = std::max(max_diff, std::abs(baseline[0][i] - with_no_cancel[0][i]));
+    }
+
+    if (max_diff > 1e-6f) {
+        std::cerr << "No-cancel output mismatch, max diff = " << max_diff << std::endl;
+        return 1;
+    }
+
+    std::cout << "PASSED" << std::endl;
+    return 0;
+}

+ 16 - 0
tests/test_inference.cpp

@@ -3,6 +3,7 @@
 #include <cmath>
 #include <string>
 #include <cstdlib>
+#include <algorithm>
 #include "bs_roformer/inference.h"
 #include "../src/utils.h"
 
@@ -55,6 +56,21 @@ int main(int argc, char* argv[]) {
         // This matches the generation of output_audio.npy
         std::vector<std::vector<float>> output_stems = engine.ProcessChunk(input_audio);
         std::vector<float> output_audio = output_stems[0];
+
+        // Smoke test new cancel callback path in Process()
+        size_t smoke_samples = std::min<size_t>(input_audio.size(), static_cast<size_t>(16384 * 2));
+        if (smoke_samples % 2 != 0) {
+            smoke_samples -= 1;
+        }
+        if (smoke_samples >= 2) {
+            std::vector<float> smoke_input(input_audio.begin(), input_audio.begin() + smoke_samples);
+            auto cancel_false = []() { return false; };
+            auto smoke_stems = engine.Process(smoke_input, 16384, 2, nullptr, cancel_false);
+            if (smoke_stems.empty() || smoke_stems[0].empty()) {
+                std::cerr << "Process() smoke test returned empty output" << std::endl;
+                return 1;
+            }
+        }
         
         std::cout << "  Input size: " << input_audio.size() << std::endl;
         std::cout << "  Output size: " << output_audio.size() << std::endl;