// model.h (5.0 KB)
  1. #pragma once
  2. #include <string>
  3. #include <vector>
  4. #include <memory>
  5. #include <ggml.h>
  6. #include <ggml-backend.h>
  7. #include <ggml-alloc.h>
  8. // Forward declarations
  9. struct ggml_context;
  10. struct ggml_cgraph;
  11. struct gguf_context;
  12. /**
  13. * MelBandRoformer Model
  14. *
  15. * This class handles:
  16. * 1. Loading weights from GGUF file
  17. * 2. Providing access to weights and buffers
  18. * 3. Building GGML computation graphs for each component
  19. *
  20. * Execution is handled by test/inference code using these graphs.
  21. */
  22. class MelBandRoformer {
  23. public:
  24. MelBandRoformer();
  25. ~MelBandRoformer();
  26. // Initialize model from GGUF file
  27. void Initialize(const std::string& model_path);
  28. // ========== Accessors for weights and config ==========
  29. // Get weight tensor by name
  30. ggml_tensor* GetWeight(const std::string& name) const;
  31. // Get backend
  32. ggml_backend_t GetBackend() const { return backend_; }
  33. // Get weights context (for creating tensors from weights)
  34. ggml_context* GetWeightsContext() const { return ctx_weights_; }
  35. // ========== Model Config Accessors ==========
  36. int GetDim() const { return dim_; }
  37. int GetDepth() const { return depth_; }
  38. int GetNumBands() const { return num_bands_; }
  39. int GetNFFT() const { return n_fft_; }
  40. int GetHopLength() const { return hop_length_; }
  41. int GetWinLength() const { return win_length_; }
  42. int GetNumStems() const { return num_stems_; }
  43. bool GetSkipConnection() const { return skip_connection_; }
  44. bool GetSTFTNormalized() const { return stft_normalized_; }
  45. int GetZeroDC() const { return zero_dc_; }
  46. int GetSampleRate() const { return sample_rate_; }
  47. // Inference defaults (from GGUF, can be overridden at runtime)
  48. int GetDefaultChunkSize() const { return default_chunk_size_; }
  49. int GetDefaultNumOverlap() const { return default_num_overlap_; }
  50. // ========== Buffer Accessors ==========
  51. const std::vector<int>& GetFreqIndices() const { return freq_indices_; }
  52. const std::vector<int>& GetNumBandsPerFreq() const { return num_bands_per_freq_; }
  53. const std::vector<int>& GetNumFreqsPerBand() const { return num_freqs_per_band_; }
  54. // Calculate dim_inputs for each band (num_freqs * 4 for stereo complex)
  55. std::vector<int> GetDimInputs() const;
  56. int GetTotalDimInput() const;
  57. // ========== Graph Building Functions ==========
  58. // These functions build GGML computation graph nodes.
  59. // They don't execute - execution is done by caller with gallocr + backend_graph_compute.
  60. /**
  61. * Build BandSplit subgraph
  62. * @param ctx Computation context (must have no_alloc=true)
  63. * @param input Input tensor [total_dim_input, n_frames, batch]
  64. * @param gf Graph to add nodes to
  65. * @return Output tensor [dim, num_bands, n_frames, batch]
  66. */
  67. ggml_tensor* BuildBandSplitGraph(
  68. ggml_context* ctx,
  69. ggml_tensor* input,
  70. ggml_cgraph* gf,
  71. int n_frames,
  72. int batch = 1
  73. );
  74. /**
  75. * Build Transformer layers subgraph (Time + Freq transformers)
  76. * @param ctx Computation context
  77. * @param input Input tensor [dim, num_bands, n_frames, batch]
  78. * @param gf Graph to add nodes to
  79. * @param pos_time_exp Expanded position tensor for time RoPE [T * F * B], with repeating [0..T-1] * (F*B) times
  80. * @param pos_freq_exp Expanded position tensor for freq RoPE [F * T * B], with repeating [0..F-1] * (T*B) times
  81. * @return Output tensor [dim, num_bands, n_frames, batch]
  82. */
  83. ggml_tensor* BuildTransformersGraph(
  84. ggml_context* ctx,
  85. ggml_tensor* input,
  86. ggml_cgraph* gf,
  87. ggml_tensor* pos_time_exp,
  88. ggml_tensor* pos_freq_exp,
  89. int n_frames,
  90. int batch = 1
  91. );
  92. /**
  93. * Build MaskEstimator subgraph
  94. * @param ctx Computation context
  95. * @param input Input tensor [dim, num_bands, n_frames, batch]
  96. * @param gf Graph to add nodes to
  97. * @return Output tensor [total_mask_dim, n_frames, batch]
  98. */
  99. ggml_tensor* BuildMaskEstimatorGraph(
  100. ggml_context* ctx,
  101. ggml_tensor* input,
  102. ggml_cgraph* gf,
  103. int n_frames,
  104. int batch = 1
  105. );
  106. private:
  107. // GGML Contexts
  108. ggml_context* ctx_weights_ = nullptr;
  109. // Backend
  110. ggml_backend_t backend_ = nullptr;
  111. ggml_backend_buffer_t buffer_weights_ = nullptr;
  112. // Model Config
  113. int dim_ = 384;
  114. int depth_ = 6;
  115. int num_bands_ = 60;
  116. int heads_ = 8;
  117. int dim_head_ = 64;
  118. int n_fft_ = 2048;
  119. int hop_length_ = 441;
  120. int win_length_ = 2048;
  121. // New Params
  122. int num_stems_ = 1;
  123. bool skip_connection_ = false;
  124. bool stft_normalized_ = false;
  125. bool zero_dc_ = false;
  126. int mask_estimator_depth_ = 1;
  127. int sample_rate_ = 44100;
  128. // Inference defaults
  129. int default_chunk_size_ = 352800;
  130. int default_num_overlap_ = 2;
  131. // Buffers loaded from GGUF
  132. std::vector<int> freq_indices_;
  133. std::vector<int> num_bands_per_freq_;
  134. std::vector<int> num_freqs_per_band_;
  135. // Helper to load GGUF
  136. void LoadWeights(const std::string& path);
  137. };