#pragma once

#include <functional>
#include <memory>
#include <string>
#include <vector>

// Forward declaration. The model is only held through std::unique_ptr below,
// so ~Inference() must be defined (out of line) where MelBandRoformer is a
// complete type.
class MelBandRoformer;

// Forward declarations of the ggml C types the members below point to. They
// are used unqualified (ggml_context, ggml_cgraph, ...), i.e. at global scope,
// so they are declared at global scope here.
struct ggml_context;
struct ggml_cgraph;
struct ggml_gallocr;
struct ggml_tensor;

/// Runs a MelBandRoformer model over a stereo audio track and returns the
/// separated stems.
///
/// Long inputs are handled by overlap-add chunking. The object keeps a
/// persistent ggml context / graph / allocator plus cached input tensors as
/// raw pointers, so it is non-copyable.
class Inference {
public:
    /// Per-chunk model function used by ProcessOverlapAdd.
    /// Input: one chunk of interleaved stereo samples.
    /// Output: [stems][samples], each stem interleaved stereo.
    using ModelCallback =
        std::function<std::vector<std::vector<float>>(const std::vector<float>&)>;

    /// Optional progress hook accepted by the processing entry points.
    // NOTE(review): the template argument of this std::function was lost in
    // the mangled source this header was reconstructed from; void(float) is
    // assumed here. Confirm against the definitions in the matching .cpp.
    using ProgressCallback = std::function<void(float)>;

    /// @param model_path Path of the model to load.
    explicit Inference(const std::string& model_path);

    /// Declared out of line so unique_ptr<MelBandRoformer> can be destroyed
    /// where MelBandRoformer is complete.
    ~Inference();

    // Holds raw ggml pointers (ctx_, gf_, allocr_, cached tensors); a copy
    // would alias them. The unique_ptr member already made these implicitly
    // deleted; this states it explicitly.
    Inference(const Inference&) = delete;
    Inference& operator=(const Inference&) = delete;

    /// Processes a full audio track, using overlap-add chunking to handle
    /// long files.
    ///
    /// @param input_audio       Interleaved stereo float32 samples.
    /// @param chunk_size        Length of each processing chunk
    ///                          (presumably in samples; confirm).
    /// @param num_overlap       Overlap factor between neighbouring chunks.
    /// @param progress_callback Optional progress hook; may be empty.
    /// @return One vector per stem, each an interleaved stereo float vector.
    std::vector<std::vector<float>> Process(const std::vector<float>& input_audio,
                                            int chunk_size = 352800,
                                            int num_overlap = 2,
                                            ProgressCallback progress_callback = nullptr);

    /// Low-level single-chunk processing (public for testing).
    /// @return [stems][samples] for the given chunk.
    std::vector<std::vector<float>> ProcessChunk(const std::vector<float>& chunk_audio);

    /// Model's recommended inference defaults.
    int GetDefaultChunkSize() const;
    int GetDefaultNumOverlap() const;

    /// Static overlap-add driver (intended to match the Python reference).
    ///
    /// @param input_audio       Interleaved stereo samples.
    /// @param chunk_size        Chunk length.
    /// @param num_overlap       Overlap factor.
    /// @param model_func        Called once per chunk; input [samples],
    ///                          output [stems][samples] (interleaved stereo).
    /// @param progress_callback Optional progress hook; may be empty.
    /// @return One vector per stem.
    static std::vector<std::vector<float>> ProcessOverlapAdd(
        const std::vector<float>& input_audio,
        int chunk_size,
        int num_overlap,
        ModelCallback model_func,
        ProgressCallback progress_callback = nullptr);

private:
    // Pipelined variant of the overlap-add driver.
    std::vector<std::vector<float>> ProcessOverlapAddPipelined(
        const std::vector<float>& input_audio,
        int chunk_size,
        int num_overlap,
        ProgressCallback progress_callback);

    std::unique_ptr<MelBandRoformer> model_;

    // Persistent graph state.
    ggml_context* ctx_ = nullptr;
    ggml_cgraph* gf_ = nullptr;
    ggml_gallocr* allocr_ = nullptr;

    // Cached input tensors (owned by ctx_).
    ggml_tensor* input_tensor_ = nullptr;
    ggml_tensor* pos_time_ = nullptr;
    ggml_tensor* pos_freq_ = nullptr;
    ggml_tensor* mask_out_tensor_ = nullptr;

    // Frame count the cached graph was built for; -1 means no graph yet.
    int cached_n_frames_ = -1;

    // Per-chunk data handed between the pipeline steps
    // (PreProcessChunk -> RunInference -> PostProcessChunk).
    // NOTE(review): element type float for stft_outputs is reconstructed;
    // confirm against the .cpp.
    struct ChunkState {
        int id = -1;
        std::vector<float> input_audio;                // Original chunk audio
        std::vector<float> stft_flattened;             // Prepared input for GPU
        std::vector<std::vector<float>> stft_outputs;  // Kept for reconstruction
        int n_frames = 0;
        std::vector<float> mask_output;                // Output from GPU
        std::vector<std::vector<float>> final_audio;   // Result after ISTFT [stems][samples]
    };

    // Ensures the cached graph is built for the given n_frames.
    bool EnsureGraph(int n_frames);

    void ComputeSTFT(const std::vector<float>& input_audio,
                     std::vector<std::vector<float>>& stft_outputs,
                     int& n_frames);
    void PrepareModelInput(const std::vector<std::vector<float>>& stft_outputs,
                           int n_frames,
                           std::vector<float>& model_input_rearranged);
    void PostProcessAndISTFT(const std::vector<float>& mask_output,
                             const std::vector<std::vector<float>>& stft_outputs,
                             int n_frames,
                             std::vector<std::vector<float>>& output_audio);

    // Pipeline steps.
    std::shared_ptr<ChunkState> PreProcessChunk(const std::vector<float>& chunk_audio, int id);
    void RunInference(std::shared_ptr<ChunkState> state);
    void PostProcessChunk(std::shared_ptr<ChunkState> state);
};