|
@@ -15,8 +15,11 @@
|
|
|
#include <thread>
|
|
#include <thread>
|
|
|
#include <mutex>
|
|
#include <mutex>
|
|
|
#include <condition_variable>
|
|
#include <condition_variable>
|
|
|
|
|
+#include <atomic>
|
|
|
|
|
+#include <exception>
|
|
|
|
|
|
|
|
using Complex = std::complex<float>;
|
|
using Complex = std::complex<float>;
|
|
|
|
|
+static constexpr const char* kInferenceCancelledMessage = "Inference cancelled";
|
|
|
|
|
|
|
|
// Helper forward decl
|
|
// Helper forward decl
|
|
|
std::vector<float> GetWindow(int size, int fade_size);
|
|
std::vector<float> GetWindow(int size, int fade_size);
|
|
@@ -289,11 +292,13 @@ void Inference::PostProcessAndISTFT(const std::vector<float>& mask_output,
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-#include <future>
|
|
|
|
|
-
|
|
|
|
|
-std::vector<std::vector<float>> Inference::Process(const std::vector<float>& input_audio, int chunk_size, int num_overlap, std::function<void(float)> progress_callback) {
|
|
|
|
|
|
|
+std::vector<std::vector<float>> Inference::Process(const std::vector<float>& input_audio,
|
|
|
|
|
+ int chunk_size,
|
|
|
|
|
+ int num_overlap,
|
|
|
|
|
+ std::function<void(float)> progress_callback,
|
|
|
|
|
+ CancelCallback cancel_callback) {
|
|
|
if (input_audio.empty()) return {};
|
|
if (input_audio.empty()) return {};
|
|
|
- return ProcessOverlapAddPipelined(input_audio, chunk_size, num_overlap, progress_callback);
|
|
|
|
|
|
|
+ return ProcessOverlapAddPipelined(input_audio, chunk_size, num_overlap, progress_callback, cancel_callback);
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
// =================================================================================================
|
|
// =================================================================================================
|
|
@@ -450,7 +455,8 @@ private:
|
|
|
std::vector<std::vector<float>> Inference::ProcessOverlapAddPipelined(const std::vector<float>& input_audio,
|
|
std::vector<std::vector<float>> Inference::ProcessOverlapAddPipelined(const std::vector<float>& input_audio,
|
|
|
int chunk_size,
|
|
int chunk_size,
|
|
|
int num_overlap,
|
|
int num_overlap,
|
|
|
- std::function<void(float)> progress_callback) {
|
|
|
|
|
|
|
+ std::function<void(float)> progress_callback,
|
|
|
|
|
+ CancelCallback cancel_callback) {
|
|
|
if (input_audio.empty()) return {};
|
|
if (input_audio.empty()) return {};
|
|
|
if (input_audio.size() % 2 != 0) {
|
|
if (input_audio.size() % 2 != 0) {
|
|
|
throw std::runtime_error("Error: Input audio must be interleaved stereo (even number of samples).");
|
|
throw std::runtime_error("Error: Input audio must be interleaved stereo (even number of samples).");
|
|
@@ -505,6 +511,7 @@ std::vector<std::vector<float>> Inference::ProcessOverlapAddPipelined(const std:
|
|
|
std::vector<float> counter(n_padded_samples * channels, 0.0f);
|
|
std::vector<float> counter(n_padded_samples * channels, 0.0f);
|
|
|
std::vector<float> window_base = GetWindow(chunk_size, fade_size);
|
|
std::vector<float> window_base = GetWindow(chunk_size, fade_size);
|
|
|
std::mutex result_mutex; // Protects 'result' and 'counter'
|
|
std::mutex result_mutex; // Protects 'result' and 'counter'
|
|
|
|
|
+ std::atomic<bool> cancel_requested{false};
|
|
|
|
|
|
|
|
// lambda to extract chunk 'i'
|
|
// lambda to extract chunk 'i'
|
|
|
auto extract_chunk = [&](int i) -> std::vector<float> {
|
|
auto extract_chunk = [&](int i) -> std::vector<float> {
|
|
@@ -590,6 +597,20 @@ std::vector<std::vector<float>> Inference::ProcessOverlapAddPipelined(const std:
|
|
|
// 3 items buffer is enough to keep GPU busy
|
|
// 3 items buffer is enough to keep GPU busy
|
|
|
ThreadSafeQueue<std::shared_ptr<ChunkState>> input_queue(3);
|
|
ThreadSafeQueue<std::shared_ptr<ChunkState>> input_queue(3);
|
|
|
ThreadSafeQueue<std::shared_ptr<ChunkState>> output_queue(3);
|
|
ThreadSafeQueue<std::shared_ptr<ChunkState>> output_queue(3);
|
|
|
|
|
+ std::mutex exception_mutex;
|
|
|
|
|
+ std::exception_ptr pipeline_exception = nullptr;
|
|
|
|
|
+
|
|
|
|
|
+ auto set_pipeline_exception = [&](std::exception_ptr eptr) {
|
|
|
|
|
+ {
|
|
|
|
|
+ std::lock_guard<std::mutex> lock(exception_mutex);
|
|
|
|
|
+ if (!pipeline_exception) {
|
|
|
|
|
+ pipeline_exception = eptr;
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ cancel_requested.store(true, std::memory_order_release);
|
|
|
|
|
+ input_queue.Shutdown();
|
|
|
|
|
+ output_queue.Shutdown();
|
|
|
|
|
+ };
|
|
|
|
|
|
|
|
// Structure to hold chunk metadata together
|
|
// Structure to hold chunk metadata together
|
|
|
struct ChunkTask {
|
|
struct ChunkTask {
|
|
@@ -599,51 +620,109 @@ std::vector<std::vector<float>> Inference::ProcessOverlapAddPipelined(const std:
|
|
|
|
|
|
|
|
// 1. Preprocessor Thread
|
|
// 1. Preprocessor Thread
|
|
|
auto preproccessor = std::thread([&]() {
|
|
auto preproccessor = std::thread([&]() {
|
|
|
- int current_offset = 0;
|
|
|
|
|
- while (current_offset < n_padded_samples) {
|
|
|
|
|
- std::vector<float> chunk = extract_chunk(current_offset);
|
|
|
|
|
-
|
|
|
|
|
- auto state = PreProcessChunk(chunk, current_offset);
|
|
|
|
|
-
|
|
|
|
|
- input_queue.Push(state);
|
|
|
|
|
- current_offset += step;
|
|
|
|
|
|
|
+ try {
|
|
|
|
|
+ int current_offset = 0;
|
|
|
|
|
+ while (current_offset < n_padded_samples && !cancel_requested.load(std::memory_order_acquire)) {
|
|
|
|
|
+ std::vector<float> chunk = extract_chunk(current_offset);
|
|
|
|
|
+
|
|
|
|
|
+ auto state = PreProcessChunk(chunk, current_offset);
|
|
|
|
|
+
|
|
|
|
|
+ input_queue.Push(state);
|
|
|
|
|
+ if (cancel_requested.load(std::memory_order_acquire)) {
|
|
|
|
|
+ break;
|
|
|
|
|
+ }
|
|
|
|
|
+ current_offset += step;
|
|
|
|
|
+ }
|
|
|
|
|
+ } catch (...) {
|
|
|
|
|
+ set_pipeline_exception(std::current_exception());
|
|
|
}
|
|
}
|
|
|
input_queue.Shutdown();
|
|
input_queue.Shutdown();
|
|
|
});
|
|
});
|
|
|
|
|
|
|
|
// 3. Postprocessor Thread
|
|
// 3. Postprocessor Thread
|
|
|
auto postprocessor = std::thread([&]() {
|
|
auto postprocessor = std::thread([&]() {
|
|
|
- std::shared_ptr<ChunkState> state;
|
|
|
|
|
- while (output_queue.Pop(state)) {
|
|
|
|
|
- // This does ISTFT (CPU intensive)
|
|
|
|
|
- PostProcessChunk(state);
|
|
|
|
|
-
|
|
|
|
|
- // Accumulate (Memory bandwidth intensive + Mutex)
|
|
|
|
|
- accumulate_result(state, state->id); // state->id holds offset
|
|
|
|
|
-
|
|
|
|
|
- if (progress_callback) {
|
|
|
|
|
- float progress = (float)std::min(state->id + step, n_padded_samples) / n_padded_samples;
|
|
|
|
|
- progress_callback(progress);
|
|
|
|
|
|
|
+ try {
|
|
|
|
|
+ std::shared_ptr<ChunkState> state;
|
|
|
|
|
+ while (!cancel_requested.load(std::memory_order_acquire) && output_queue.Pop(state)) {
|
|
|
|
|
+ // This does ISTFT (CPU intensive)
|
|
|
|
|
+ PostProcessChunk(state);
|
|
|
|
|
+ if (cancel_requested.load(std::memory_order_acquire)) {
|
|
|
|
|
+ break;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // Accumulate (Memory bandwidth intensive + Mutex)
|
|
|
|
|
+ accumulate_result(state, state->id); // state->id holds offset
|
|
|
|
|
+
|
|
|
|
|
+ if (!cancel_requested.load(std::memory_order_acquire) && progress_callback) {
|
|
|
|
|
+ float progress = (float)std::min(state->id + step, n_padded_samples) / n_padded_samples;
|
|
|
|
|
+ progress_callback(progress);
|
|
|
|
|
+ }
|
|
|
}
|
|
}
|
|
|
|
|
+ } catch (...) {
|
|
|
|
|
+ set_pipeline_exception(std::current_exception());
|
|
|
}
|
|
}
|
|
|
});
|
|
});
|
|
|
|
|
|
|
|
|
|
+ auto poll_cancel_requested = [&]() -> bool {
|
|
|
|
|
+ if (cancel_requested.load(std::memory_order_acquire)) {
|
|
|
|
|
+ return true;
|
|
|
|
|
+ }
|
|
|
|
|
+ if (cancel_callback && cancel_callback()) {
|
|
|
|
|
+ cancel_requested.store(true, std::memory_order_release);
|
|
|
|
|
+ return true;
|
|
|
|
|
+ }
|
|
|
|
|
+ return false;
|
|
|
|
|
+ };
|
|
|
|
|
+
|
|
|
// 2. Main Thread (Inference Loop)
|
|
// 2. Main Thread (Inference Loop)
|
|
|
|
|
+ bool cancelled = false;
|
|
|
std::shared_ptr<ChunkState> state;
|
|
std::shared_ptr<ChunkState> state;
|
|
|
- while (true) {
|
|
|
|
|
- bool ok = input_queue.Pop(state);
|
|
|
|
|
- if (!ok) break; // Input queue shutdown and empty
|
|
|
|
|
-
|
|
|
|
|
- // This does GGML Inference (GPU intensive, Blocking)
|
|
|
|
|
- RunInference(state);
|
|
|
|
|
-
|
|
|
|
|
- output_queue.Push(state);
|
|
|
|
|
|
|
+ try {
|
|
|
|
|
+ while (true) {
|
|
|
|
|
+ if (poll_cancel_requested()) {
|
|
|
|
|
+ cancelled = true;
|
|
|
|
|
+ break;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ bool ok = input_queue.Pop(state);
|
|
|
|
|
+ if (!ok) break; // Input queue shutdown and empty
|
|
|
|
|
+
|
|
|
|
|
+ if (poll_cancel_requested()) {
|
|
|
|
|
+ cancelled = true;
|
|
|
|
|
+ break;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // This does GGML Inference (GPU intensive, Blocking)
|
|
|
|
|
+ RunInference(state);
|
|
|
|
|
+
|
|
|
|
|
+ if (poll_cancel_requested()) {
|
|
|
|
|
+ cancelled = true;
|
|
|
|
|
+ break;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ output_queue.Push(state);
|
|
|
|
|
+ }
|
|
|
|
|
+ } catch (...) {
|
|
|
|
|
+ set_pipeline_exception(std::current_exception());
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+ if (cancelled) {
|
|
|
|
|
+ cancel_requested.store(true, std::memory_order_release);
|
|
|
|
|
+ input_queue.Shutdown();
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
// Wait for threads
|
|
// Wait for threads
|
|
|
output_queue.Shutdown();
|
|
output_queue.Shutdown();
|
|
|
if (preproccessor.joinable()) preproccessor.join();
|
|
if (preproccessor.joinable()) preproccessor.join();
|
|
|
if (postprocessor.joinable()) postprocessor.join();
|
|
if (postprocessor.joinable()) postprocessor.join();
|
|
|
|
|
+
|
|
|
|
|
+ if (pipeline_exception) {
|
|
|
|
|
+ std::rethrow_exception(pipeline_exception);
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ if (cancel_requested.load(std::memory_order_acquire)) {
|
|
|
|
|
+ throw std::runtime_error(kInferenceCancelledMessage);
|
|
|
|
|
+ }
|
|
|
|
|
|
|
|
// Normalize and Crop
|
|
// Normalize and Crop
|
|
|
// result is [stems][samples]
|
|
// result is [stems][samples]
|
|
@@ -676,7 +755,8 @@ std::vector<std::vector<float>> Inference::ProcessOverlapAdd(const std::vector<f
|
|
|
int chunk_size,
|
|
int chunk_size,
|
|
|
int num_overlap,
|
|
int num_overlap,
|
|
|
ModelCallback model_func,
|
|
ModelCallback model_func,
|
|
|
- std::function<void(float)> progress_callback) {
|
|
|
|
|
|
|
+ std::function<void(float)> progress_callback,
|
|
|
|
|
+ CancelCallback cancel_callback) {
|
|
|
if (input_audio.empty()) return {};
|
|
if (input_audio.empty()) return {};
|
|
|
if (input_audio.size() % 2 != 0) {
|
|
if (input_audio.size() % 2 != 0) {
|
|
|
throw std::runtime_error("Error: Input audio must be interleaved stereo (even number of samples).");
|
|
throw std::runtime_error("Error: Input audio must be interleaved stereo (even number of samples).");
|
|
@@ -736,6 +816,10 @@ std::vector<std::vector<float>> Inference::ProcessOverlapAdd(const std::vector<f
|
|
|
int total_length = n_padded_samples;
|
|
int total_length = n_padded_samples;
|
|
|
|
|
|
|
|
while (i < total_length) {
|
|
while (i < total_length) {
|
|
|
|
|
+ if (cancel_callback && cancel_callback()) {
|
|
|
|
|
+ throw std::runtime_error(kInferenceCancelledMessage);
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
int remaining = total_length - i;
|
|
int remaining = total_length - i;
|
|
|
int part_len = std::min(C, remaining); // Logic matches Python slice [i:i+C]
|
|
int part_len = std::min(C, remaining); // Logic matches Python slice [i:i+C]
|
|
|
|
|
|