// inference.h
  1. #pragma once
  2. #include <cstdint>
  3. #include <functional>
  4. #include <memory>
  5. #include <string>
  6. #include <vector>
  // Forward declarations keep the model and ggml headers out of this header.
class BSRoformer;

// ggml is a C library: its types are global-namespace structs (see ggml.h and
// ggml-alloc.h). The pointer members below always referred to these global
// types; the former `namespace ggml { struct context; struct cgraph; }` declared
// unrelated types that nothing in this header used.
struct ggml_context;
struct ggml_cgraph;
struct ggml_gallocr;
struct ggml_tensor;

/// Runs a BSRoformer model over audio and returns one output signal per stem.
///
/// Holds the model plus a persistent ggml graph and cached tensors that are
/// reused across chunks. EnsureGraph() builds the graph for a given frame
/// count, and cached_n_frames_ records the count it was last built for.
///
/// Not copyable: an instance owns the model and raw ggml handles.
/// NOTE(review): the processing methods are non-const and share the cached
/// graph state, so one instance presumably must not be used from several
/// threads at once. TODO: confirm against the implementation.
class Inference {
public:
    /// Polled to request cancellation; presumably returning true means
    /// "stop". TODO: confirm.
    using CancelCallback = std::function<bool()>;
    /// Progress sink; the argument is presumably a fraction in [0, 1].
    /// TODO: confirm.
    using ProgressCallback = std::function<void(float)>;
    /// Model hook for ProcessOverlapAdd:
    /// input [samples] -> output [stems][samples] (interleaved stereo).
    using ModelCallback = std::function<std::vector<std::vector<float>>(const std::vector<float>&)>;

    /// Creates the inference engine from the model file at `model_path`.
    /// `explicit` so a path string never silently converts into an engine.
    explicit Inference(const std::string& model_path);
    ~Inference();

    // Copying would share the model and the raw ggml handles. Copying was
    // already implicitly deleted via the unique_ptr member; this spells it out.
    Inference(const Inference&) = delete;
    Inference& operator=(const Inference&) = delete;

    /// Processes a full audio track, using overlap-add chunking so that
    /// long inputs are handled.
    ///
    /// @param input_audio       Interleaved stereo float32 samples.
    /// @param chunk_size        Chunk length for the overlap-add split. The
    ///                          default 352800 equals 8 s at 44.1 kHz if
    ///                          counted in frames. TODO: confirm the units.
    ///                          GetDefaultChunkSize() reports the model's
    ///                          recommended value.
    /// @param num_overlap       Overlap count used when chunking (see
    ///                          GetDefaultNumOverlap()).
    /// @param progress_callback Optional progress sink; may be empty.
    /// @param cancel_callback   Optional cancellation poll; may be empty.
    /// @return One interleaved stereo float vector per stem.
    std::vector<std::vector<float>> Process(const std::vector<float>& input_audio,
                                            int chunk_size = 352800,
                                            int num_overlap = 2,
                                            ProgressCallback progress_callback = nullptr,
                                            CancelCallback cancel_callback = nullptr);

    /// Runs the model on a single chunk and returns [stems][samples].
    /// Low-level entry point, public for testing.
    std::vector<std::vector<float>> ProcessChunk(const std::vector<float>& chunk_audio);

    // Model-recommended inference defaults and model properties.
    int GetDefaultChunkSize() const;
    int GetDefaultNumOverlap() const;
    int GetSampleRate() const;
    int GetNumStems() const;

    /// Static overlap-add driver: splits `input_audio` into overlapping chunks,
    /// runs `model_func` on each one and recombines the per-stem results.
    /// The original authors note it matches the Python reference exactly.
    ///
    /// @param model_func Maps input [samples] to output [stems][samples]
    ///                   (interleaved stereo).
    static std::vector<std::vector<float>> ProcessOverlapAdd(const std::vector<float>& input_audio,
                                                             int chunk_size,
                                                             int num_overlap,
                                                             ModelCallback model_func,
                                                             ProgressCallback progress_callback = nullptr,
                                                             CancelCallback cancel_callback = nullptr);

private:
    /// Pipelined overlap-add. Presumably drives the
    /// PreProcessChunk -> RunInference -> PostProcessChunk steps below.
    /// TODO: confirm.
    std::vector<std::vector<float>> ProcessOverlapAddPipelined(const std::vector<float>& input_audio,
                                                               int chunk_size,
                                                               int num_overlap,
                                                               ProgressCallback progress_callback,
                                                               CancelCallback cancel_callback);

    std::unique_ptr<BSRoformer> model_;

    // Persistent graph state, reused across chunks. These are raw ggml
    // handles; presumably created and released in the .cpp (EnsureGraph and
    // the destructor). TODO: confirm ownership.
    ggml_context* ctx_ = nullptr;
    ggml_cgraph* gf_ = nullptr;
    ggml_gallocr* allocr_ = nullptr;

    // Cached input/output tensors (owned by ctx_).
    ggml_tensor* input_tensor_ = nullptr;
    ggml_tensor* pos_time_ = nullptr;
    ggml_tensor* pos_freq_ = nullptr;
    ggml_tensor* mask_out_tensor_ = nullptr;

    // Host-side position data, kept as members to avoid per-chunk reallocation.
    std::vector<int32_t> pos_time_data_;
    std::vector<int32_t> pos_freq_data_;

    // Frame count the current graph was built for (-1: presumably none yet).
    int cached_n_frames_ = -1;

    // Per-chunk state carried through the pipeline steps.
    struct ChunkState {
        int id = -1;
        std::vector<float> input_audio;               // Original chunk audio
        std::vector<float> stft_flattened;            // Prepared model input for the GPU
        std::vector<std::vector<float>> stft_outputs; // Kept for reconstruction
        int n_frames = 0;
        std::vector<float> mask_output;               // Output from the GPU
        std::vector<std::vector<float>> final_audio;  // Result after ISTFT [stems][samples]
    };

    // Ensures the graph is built for the given n_frames.
    bool EnsureGraph(int n_frames);

    // STFT of `input_audio` into `stft_outputs`; `n_frames` receives the frame count.
    void ComputeSTFT(const std::vector<float>& input_audio,
                     std::vector<std::vector<float>>& stft_outputs,
                     int& n_frames);

    // Rearranges the STFT outputs into the model's input layout.
    void PrepareModelInput(const std::vector<std::vector<float>>& stft_outputs,
                           int n_frames,
                           std::vector<float>& model_input_rearranged);

    // Applies the model's mask output and inverts the STFT into `output_audio`.
    void PostProcessAndISTFT(const std::vector<float>& mask_output,
                             const std::vector<std::vector<float>>& stft_outputs,
                             int n_frames,
                             std::vector<std::vector<float>>& output_audio);

    // Pipeline steps.
    std::shared_ptr<ChunkState> PreProcessChunk(const std::vector<float>& chunk_audio, int id);
    void RunInference(std::shared_ptr<ChunkState> state);
    void PostProcessChunk(std::shared_ptr<ChunkState> state);
};