inference.h 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. #pragma once
  2. #include <vector>
  3. #include <string>
  4. #include <memory>
  5. #include <functional>
  // Forward declarations. This header only stores pointers to these types,
// so incomplete declarations are sufficient here.
class MelBandRoformer;

// ggml handle types. Declared at global scope because the Inference members
// below refer to them as global types (ggml_context, ggml_cgraph, ...).
struct ggml_context;
struct ggml_cgraph;
struct ggml_gallocr;
struct ggml_tensor;

// Inference engine wrapping a MelBandRoformer model. It splits interleaved
// stereo float32 audio into stems, chunking long inputs with overlap-add.
//
// Neither copyable nor movable: an instance holds raw ggml handles (ctx_, gf_,
// allocr_ and the tensors cached inside ctx_) that must never be shared by two
// objects.
class Inference {
public:
    // Creates an engine for the model stored at `model_path`. Explicit so that
    // a std::string can never silently convert into an Inference.
    explicit Inference(const std::string& model_path);

    // Defined out of line: destroying std::unique_ptr<MelBandRoformer> needs
    // the complete MelBandRoformer type, which is only forward-declared here.
    // Presumably also releases the ggml handles below (confirm in the .cpp).
    ~Inference();

    Inference(const Inference&) = delete;
    Inference& operator=(const Inference&) = delete;
    Inference(Inference&&) = delete;
    Inference& operator=(Inference&&) = delete;

    // Process a full audio track (interleaved stereo float32), using
    // overlap-add chunking to handle long files.
    //   input_audio       - interleaved stereo float32 samples.
    //   chunk_size        - chunk length for the overlap-add pass
    //                       (units defined by the implementation; TODO confirm).
    //   num_overlap       - overlap setting for the overlap-add pass.
    //   progress_callback - optional (may be empty); receives a float progress value.
    // Returns one vector per stem, each an interleaved stereo float vector.
    // These defaults are fixed in this header; GetDefaultChunkSize() and
    // GetDefaultNumOverlap() report the model's recommended values instead.
    std::vector<std::vector<float>> Process(const std::vector<float>& input_audio,
                                            int chunk_size = 352800,
                                            int num_overlap = 2,
                                            std::function<void(float)> progress_callback = nullptr);

    // Low-level processing of a single chunk, without overlap-add.
    // Public for testing. Returns one vector per stem.
    std::vector<std::vector<float>> ProcessChunk(const std::vector<float>& chunk_audio);

    // The model's recommended inference defaults.
    [[nodiscard]] int GetDefaultChunkSize() const;
    [[nodiscard]] int GetDefaultNumOverlap() const;

    // Model invocation used by ProcessOverlapAdd.
    // Input: one chunk [samples]; output: [stems][samples] (interleaved stereo).
    using ModelCallback = std::function<std::vector<std::vector<float>>(const std::vector<float>&)>;

    // Static overlap-add driver, independent of any Inference instance and
    // intended to match the reference Python implementation exactly.
    // `model_func` is presumably invoked once per chunk; `progress_callback`
    // is optional and receives a float progress value.
    static std::vector<std::vector<float>> ProcessOverlapAdd(const std::vector<float>& input_audio,
                                                             int chunk_size,
                                                             int num_overlap,
                                                             ModelCallback model_func,
                                                             std::function<void(float)> progress_callback = nullptr);

private:
    // Pipelined overlap-add. Presumably drives the pipeline steps declared
    // below (PreProcessChunk -> RunInference -> PostProcessChunk).
    std::vector<std::vector<float>> ProcessOverlapAddPipelined(const std::vector<float>& input_audio,
                                                               int chunk_size,
                                                               int num_overlap,
                                                               std::function<void(float)> progress_callback);

    std::unique_ptr<MelBandRoformer> model_;

    // Persistent graph state, kept across calls. Presumably rebuilt only when
    // the frame count changes (see EnsureGraph / cached_n_frames_).
    ggml_context* ctx_ = nullptr;
    ggml_cgraph* gf_ = nullptr;
    ggml_gallocr* allocr_ = nullptr;

    // Cached graph tensors (owned by ctx_, so never freed individually).
    ggml_tensor* input_tensor_ = nullptr;
    ggml_tensor* pos_time_ = nullptr;
    ggml_tensor* pos_freq_ = nullptr;
    ggml_tensor* mask_out_tensor_ = nullptr;

    // Frame count the cached graph was built for; starts at -1 (no graph yet).
    int cached_n_frames_ = -1;

    // Per-chunk data carried through the pipeline stages.
    struct ChunkState {
        int id = -1;
        std::vector<float> input_audio;               // Original chunk audio
        std::vector<float> stft_flattened;            // Prepared model input for the GPU
        std::vector<std::vector<float>> stft_outputs; // Kept for reconstruction
        int n_frames = 0;
        std::vector<float> mask_output;               // Output from the GPU
        std::vector<std::vector<float>> final_audio;  // Result after ISTFT [stems][samples]
    };

    // Ensures the graph is built for `n_frames`.
    // The bool result presumably signals success (confirm in the .cpp).
    bool EnsureGraph(int n_frames);

    // STFT of `input_audio`; writes the spectra to `stft_outputs` and the
    // frame count to `n_frames`.
    void ComputeSTFT(const std::vector<float>& input_audio,
                     std::vector<std::vector<float>>& stft_outputs,
                     int& n_frames);

    // Rearranges the STFT outputs into the model's input layout.
    void PrepareModelInput(const std::vector<std::vector<float>>& stft_outputs,
                           int n_frames,
                           std::vector<float>& model_input_rearranged);

    // Applies the model output to the stored STFT and inverts it, writing
    // [stems][samples] audio into `output_audio`.
    void PostProcessAndISTFT(const std::vector<float>& mask_output,
                             const std::vector<std::vector<float>>& stft_outputs,
                             int n_frames,
                             std::vector<std::vector<float>>& output_audio);

    // Pipeline steps.
    std::shared_ptr<ChunkState> PreProcessChunk(const std::vector<float>& chunk_audio, int id);
    void RunInference(std::shared_ptr<ChunkState> state);
    void PostProcessChunk(std::shared_ptr<ChunkState> state);
};