model.h 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157
  1. #pragma once
  2. #include <string>
  3. #include <vector>
  4. #include <memory>
  5. #include <ggml.h>
  6. #include <ggml-backend.h>
  7. #include <ggml-alloc.h>
  8. // Forward declarations
  9. struct ggml_context;
  10. struct ggml_cgraph;
  11. struct gguf_context;
  12. /**
  13. * MelBandRoformer Model
  14. *
  15. * This class handles:
  16. * 1. Loading weights from GGUF file
  17. * 2. Providing access to weights and buffers
  18. * 3. Building GGML computation graphs for each component
  19. *
  20. * Execution is handled by test/inference code using these graphs.
  21. */
  22. class MelBandRoformer {
  23. public:
  24. MelBandRoformer();
  25. ~MelBandRoformer();
  26. // Initialize model from GGUF file
  27. void Initialize(const std::string& model_path);
  28. // ========== Accessors for weights and config ==========
  29. // Get weight tensor by name
  30. ggml_tensor* GetWeight(const std::string& name) const;
  31. // Get backend
  32. ggml_backend_t GetBackend() const { return backend_; }
  33. // Get weights context (for creating tensors from weights)
  34. ggml_context* GetWeightsContext() const { return ctx_weights_; }
  35. // ========== Model Config Accessors ==========
  36. int GetDim() const { return dim_; }
  37. int GetDepth() const { return depth_; }
  38. int GetNumBands() const { return num_bands_; }
  39. int GetNFFT() const { return n_fft_; }
  40. int GetHopLength() const { return hop_length_; }
  41. int GetWinLength() const { return win_length_; }
  42. int GetNumStems() const { return num_stems_; }
  43. bool GetSkipConnection() const { return skip_connection_; }
  44. bool GetSTFTNormalized() const { return stft_normalized_; }
  45. bool GetZeroDC() const { return zero_dc_; }
  46. // Inference defaults (from GGUF, can be overridden at runtime)
  47. int GetDefaultChunkSize() const { return default_chunk_size_; }
  48. int GetDefaultNumOverlap() const { return default_num_overlap_; }
  49. // ========== Buffer Accessors ==========
  50. const std::vector<int>& GetFreqIndices() const { return freq_indices_; }
  51. const std::vector<int>& GetNumBandsPerFreq() const { return num_bands_per_freq_; }
  52. const std::vector<int>& GetNumFreqsPerBand() const { return num_freqs_per_band_; }
  53. // Calculate dim_inputs for each band (num_freqs * 4 for stereo complex)
  54. std::vector<int> GetDimInputs() const;
  55. int GetTotalDimInput() const;
  56. // ========== Graph Building Functions ==========
  57. // These functions build GGML computation graph nodes.
  58. // They don't execute - execution is done by caller with gallocr + backend_graph_compute.
  59. /**
  60. * Build BandSplit subgraph
  61. * @param ctx Computation context (must have no_alloc=true)
  62. * @param input Input tensor [total_dim_input, n_frames, batch]
  63. * @param gf Graph to add nodes to
  64. * @return Output tensor [dim, num_bands, n_frames, batch]
  65. */
  66. ggml_tensor* BuildBandSplitGraph(
  67. ggml_context* ctx,
  68. ggml_tensor* input,
  69. ggml_cgraph* gf,
  70. int n_frames,
  71. int batch = 1
  72. );
  73. /**
  74. * Build Transformer layers subgraph (Time + Freq transformers)
  75. * @param ctx Computation context
  76. * @param input Input tensor [dim, num_bands, n_frames, batch]
  77. * @param gf Graph to add nodes to
  78. * @param pos_time_exp Expanded position tensor for time RoPE [T * F * B], with repeating [0..T-1] * (F*B) times
  79. * @param pos_freq_exp Expanded position tensor for freq RoPE [F * T * B], with repeating [0..F-1] * (T*B) times
  80. * @return Output tensor [dim, num_bands, n_frames, batch]
  81. */
  82. ggml_tensor* BuildTransformersGraph(
  83. ggml_context* ctx,
  84. ggml_tensor* input,
  85. ggml_cgraph* gf,
  86. ggml_tensor* pos_time_exp,
  87. ggml_tensor* pos_freq_exp,
  88. int n_frames,
  89. int batch = 1
  90. );
  91. /**
  92. * Build MaskEstimator subgraph
  93. * @param ctx Computation context
  94. * @param input Input tensor [dim, num_bands, n_frames, batch]
  95. * @param gf Graph to add nodes to
  96. * @return Output tensor [total_mask_dim, n_frames, batch]
  97. */
  98. ggml_tensor* BuildMaskEstimatorGraph(
  99. ggml_context* ctx,
  100. ggml_tensor* input,
  101. ggml_cgraph* gf,
  102. int n_frames,
  103. int batch = 1
  104. );
  105. private:
  106. // GGML Contexts
  107. ggml_context* ctx_weights_ = nullptr;
  108. // Backend
  109. ggml_backend_t backend_ = nullptr;
  110. ggml_backend_buffer_t buffer_weights_ = nullptr;
  111. // Model Config
  112. int dim_ = 384;
  113. int depth_ = 6;
  114. int num_bands_ = 60;
  115. int heads_ = 8;
  116. int dim_head_ = 64;
  117. int n_fft_ = 2048;
  118. int hop_length_ = 441;
  119. int win_length_ = 2048;
  120. // New Params
  121. int num_stems_ = 1;
  122. bool skip_connection_ = false;
  123. bool stft_normalized_ = false;
  124. bool zero_dc_ = false;
  125. int mask_estimator_depth_ = 1;
  126. // Inference defaults
  127. int default_chunk_size_ = 352800;
  128. int default_num_overlap_ = 2;
  129. // Buffers loaded from GGUF
  130. std::vector<int> freq_indices_;
  131. std::vector<int> num_bands_per_freq_;
  132. std::vector<int> num_freqs_per_band_;
  133. // Helper to load GGUF
  134. void LoadWeights(const std::string& path);
  135. };