| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739 |
#include "model.h"

#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iostream>
#include <memory>
#include <optional>
#include <stdexcept>
#include <string>
#include <vector>

#include <ggml.h>
#include <ggml-alloc.h>
#include <ggml-backend.h>
#include <gguf.h>
- MelBandRoformer::MelBandRoformer() {
- }
- MelBandRoformer::~MelBandRoformer() {
- if (buffer_weights_) ggml_backend_buffer_free(buffer_weights_);
- if (backend_) ggml_backend_free(backend_);
- if (ctx_weights_) ggml_free(ctx_weights_);
- }
- void MelBandRoformer::Initialize(const std::string& model_path) {
- // Use best available backend, but allow forcing CPU
- if (std::getenv("MBR_FORCE_CPU")) {
- backend_ = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, NULL);
- } else {
- backend_ = ggml_backend_init_best();
- }
-
- if (!backend_) {
- throw std::runtime_error("Failed to initialize backend");
- }
- std::cout << "Using backend: " << ggml_backend_name(backend_) << std::endl;
- LoadWeights(model_path);
- }
- void MelBandRoformer::LoadWeights(const std::string& path) {
- std::cout << "Loading model from " << path << std::endl;
- struct gguf_init_params params = {
- /*.no_alloc = */ true,
- /*.ctx = */ &ctx_weights_,
- };
- struct gguf_context* ctx_gguf = gguf_init_from_file(path.c_str(), params);
- if (!ctx_gguf) {
- throw std::runtime_error("Failed to load GGUF file: " + path);
- }
- // 1. Read Hyperparameters
- int kv_idx;
-
- kv_idx = gguf_find_key(ctx_gguf, "mel_band_roformer.stft_n_fft");
- if (kv_idx >= 0) n_fft_ = (int)gguf_get_val_u32(ctx_gguf, kv_idx);
-
- kv_idx = gguf_find_key(ctx_gguf, "mel_band_roformer.stft_hop_length");
- if (kv_idx >= 0) hop_length_ = (int)gguf_get_val_u32(ctx_gguf, kv_idx);
-
- kv_idx = gguf_find_key(ctx_gguf, "mel_band_roformer.stft_win_length");
- if (kv_idx >= 0) win_length_ = (int)gguf_get_val_u32(ctx_gguf, kv_idx);
- kv_idx = gguf_find_key(ctx_gguf, "mel_band_roformer.dim");
- if (kv_idx >= 0) dim_ = (int)gguf_get_val_u32(ctx_gguf, kv_idx);
- kv_idx = gguf_find_key(ctx_gguf, "mel_band_roformer.num_bands");
- if (kv_idx >= 0) num_bands_ = (int)gguf_get_val_u32(ctx_gguf, kv_idx);
-
- kv_idx = gguf_find_key(ctx_gguf, "mel_band_roformer.depth");
- if (kv_idx >= 0) depth_ = (int)gguf_get_val_u32(ctx_gguf, kv_idx);
- // New Parameters
- kv_idx = gguf_find_key(ctx_gguf, "mel_band_roformer.num_stems");
- if (kv_idx >= 0) num_stems_ = (int)gguf_get_val_u32(ctx_gguf, kv_idx);
-
- kv_idx = gguf_find_key(ctx_gguf, "mel_band_roformer.skip_connection");
- if (kv_idx >= 0) skip_connection_ = gguf_get_val_bool(ctx_gguf, kv_idx);
- kv_idx = gguf_find_key(ctx_gguf, "mel_band_roformer.stft_normalized");
- if (kv_idx >= 0) stft_normalized_ = gguf_get_val_bool(ctx_gguf, kv_idx);
- kv_idx = gguf_find_key(ctx_gguf, "mel_band_roformer.zero_dc");
- if (kv_idx >= 0) zero_dc_ = gguf_get_val_bool(ctx_gguf, kv_idx);
- kv_idx = gguf_find_key(ctx_gguf, "mel_band_roformer.mask_estimator_depth");
- if (kv_idx >= 0) mask_estimator_depth_ = (int)gguf_get_val_u32(ctx_gguf, kv_idx);
- kv_idx = gguf_find_key(ctx_gguf, "mel_band_roformer.sample_rate");
- if (kv_idx >= 0) sample_rate_ = (int)gguf_get_val_u32(ctx_gguf, kv_idx);
-
- // Inference defaults (optional, fallback to hardcoded values)
- kv_idx = gguf_find_key(ctx_gguf, "mel_band_roformer.default_chunk_size");
- if (kv_idx >= 0) default_chunk_size_ = (int)gguf_get_val_u32(ctx_gguf, kv_idx);
-
- kv_idx = gguf_find_key(ctx_gguf, "mel_band_roformer.default_num_overlap");
- if (kv_idx >= 0) default_num_overlap_ = (int)gguf_get_val_u32(ctx_gguf, kv_idx);
-
- kv_idx = gguf_find_key(ctx_gguf, "mel_band_roformer.linear_transformer_depth");
- if (kv_idx >= 0) {
- int lin_depth = (int)gguf_get_val_u32(ctx_gguf, kv_idx);
- if (lin_depth > 0) {
- std::cerr << "\n[WARNING] Model uses Linear Attention (depth=" << lin_depth
- << "). This is NOT supported yet. Results will be incorrect.\n" << std::endl;
- }
- }
- std::cout << "Model Config: n_fft=" << n_fft_ << ", hop_length=" << hop_length_
- << ", num_bands=" << num_bands_ << ", dim=" << dim_ << ", depth=" << depth_
- << ", num_stems=" << num_stems_ << ", skip_conn=" << skip_connection_ << std::endl;
- std::cout << "Inference Defaults: chunk_size=" << default_chunk_size_
- << ", num_overlap=" << default_num_overlap_ << std::endl;
- // 2. Allocate backend buffer for ALL tensors
- buffer_weights_ = ggml_backend_alloc_ctx_tensors_from_buft(
- ctx_weights_,
- ggml_backend_get_default_buffer_type(backend_)
- );
- if (!buffer_weights_) {
- throw std::runtime_error("Failed to allocate weight buffer");
- }
- // 3. Read data from file and upload to backend
- FILE* file = fopen(path.c_str(), "rb");
- if (!file) throw std::runtime_error("Cannot open file");
-
- size_t data_offset = gguf_get_data_offset(ctx_gguf);
-
- struct ggml_tensor* t = ggml_get_first_tensor(ctx_weights_);
- std::vector<uint8_t> read_buf;
-
- while (t) {
- int tid = gguf_find_tensor(ctx_gguf, t->name);
- if (tid >= 0) {
- size_t offset = data_offset + gguf_get_tensor_offset(ctx_gguf, tid);
- size_t size = ggml_nbytes(t);
-
- if (read_buf.size() < size) read_buf.resize(size);
-
- fseek(file, (long)offset, SEEK_SET);
- fread(read_buf.data(), 1, size, file);
-
- // Upload to backend
- ggml_backend_tensor_set(t, read_buf.data(), 0, size);
-
- // Cache important buffers
- if (std::string(t->name) == "buffer_freq_indices") {
- freq_indices_.resize(ggml_nelements(t));
- if (t->type == GGML_TYPE_I32) {
- memcpy(freq_indices_.data(), read_buf.data(), size);
- }
- std::cout << " Loaded freq_indices: " << freq_indices_.size() << " indices" << std::endl;
- }
- if (std::string(t->name) == "buffer_num_bands_per_freq") {
- num_bands_per_freq_.resize(ggml_nelements(t));
- if (t->type == GGML_TYPE_I32) {
- memcpy(num_bands_per_freq_.data(), read_buf.data(), size);
- }
- }
- if (std::string(t->name) == "buffer_num_freqs_per_band") {
- num_freqs_per_band_.resize(ggml_nelements(t));
- if (t->type == GGML_TYPE_I32) {
- memcpy(num_freqs_per_band_.data(), read_buf.data(), size);
- }
- }
- }
-
- t = ggml_get_next_tensor(ctx_weights_, t);
- }
-
- fclose(file);
-
- int n_tensors = gguf_get_n_tensors(ctx_gguf);
- std::cout << "Loaded " << n_tensors << " tensors" << std::endl;
-
- gguf_free(ctx_gguf);
- }
- ggml_tensor* MelBandRoformer::GetWeight(const std::string& name) const {
- return ggml_get_tensor(ctx_weights_, name.c_str());
- }
- std::vector<int> MelBandRoformer::GetDimInputs() const {
- std::vector<int> dim_inputs(num_bands_);
- for (int i = 0; i < num_bands_; ++i) {
- int num_freqs = num_freqs_per_band_[i];
- dim_inputs[i] = num_freqs * 4; // stereo=2, complex=2
- }
- return dim_inputs;
- }
- int MelBandRoformer::GetTotalDimInput() const {
- int total = 0;
- for (int i = 0; i < num_bands_; ++i) {
- total += num_freqs_per_band_[i] * 4;
- }
- return total;
- }
- // ========== Graph Building Functions ==========
- ggml_tensor* MelBandRoformer::BuildBandSplitGraph(
- ggml_context* ctx,
- ggml_tensor* input,
- ggml_cgraph* gf,
- int n_frames,
- int batch
- ) {
- // Following test_10_full_model.cpp implementation
- // Input: [total_dim_input, n_frames, batch]
- // Output: [dim, num_bands, n_frames, batch]
-
- std::vector<int> dim_inputs = GetDimInputs();
-
- ggml_tensor* x = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, dim_, num_bands_, n_frames, batch);
-
- size_t offset_elements = 0;
- for (int i = 0; i < num_bands_; ++i) {
- int dim_in = dim_inputs[i];
-
- // View for this band's input
- ggml_tensor* band_input = ggml_view_3d(ctx, input,
- dim_in, n_frames, batch,
- input->nb[1], input->nb[2],
- offset_elements * sizeof(float));
-
- // Get RMSNorm gamma weight
- // band_split.{i}.norm.weight
- std::string gamma_name = "band_split." + std::to_string(i) + ".norm.weight";
- ggml_tensor* gamma = GetWeight(gamma_name);
- if (!gamma) {
- std::cerr << "Missing weight: " << gamma_name << std::endl;
- return nullptr;
- }
-
- // RMSNorm
- ggml_tensor* normed = ggml_rms_norm(ctx, band_input, 1e-12f);
- normed = ggml_mul(ctx, normed, gamma);
-
- // Get Linear weight and bias
- // band_split.{i}.linear.weight
- std::string w_name = "band_split." + std::to_string(i) + ".linear.weight";
- std::string b_name = "band_split." + std::to_string(i) + ".linear.bias";
- ggml_tensor* weight = GetWeight(w_name);
- ggml_tensor* bias = GetWeight(b_name);
-
- if (!weight || !bias) {
- std::cerr << "Missing weight: " << w_name << " or " << b_name << std::endl;
- return nullptr;
- }
-
- // Linear projection
- ggml_tensor* projected = ggml_mul_mat(ctx, weight, normed);
- projected = ggml_add(ctx, projected, bias);
-
- // Copy to output slice
- ggml_tensor* out_slice = ggml_view_3d(ctx, x,
- dim_, n_frames, batch,
- x->nb[2], x->nb[3],
- i * x->nb[1]);
-
- ggml_build_forward_expand(gf, ggml_cpy(ctx, projected, out_slice));
-
- offset_elements += dim_in;
- }
-
- return x;
- }
- ggml_tensor* MelBandRoformer::BuildTransformersGraph(
- ggml_context* ctx,
- ggml_tensor* input,
- ggml_cgraph* gf,
- ggml_tensor* pos_time_exp,
- ggml_tensor* pos_freq_exp,
- int n_frames,
- int batch
- ) {
- // Following test_10_full_model.cpp implementation
- // Input: [dim, num_bands, n_frames, batch]
-
- const int D = dim_;
- const int F = num_bands_;
- const int T = n_frames;
- const int B = batch;
- const int HEADS = heads_;
- const int DIM_HEAD = dim_head_;
- const int DIM_INNER = HEADS * DIM_HEAD;
-
- ggml_tensor* x = input;
- std::vector<ggml_tensor*> skip_outputs;
-
- for (int layer = 0; layer < depth_; ++layer) {
- if (skip_connection_) {
- for (ggml_tensor* s : skip_outputs) {
- x = ggml_add(ctx, x, s);
- }
- }
- // ========== TIME TRANSFORMER ==========
- // Permute: [D, F, T, B] -> [D, T, F, B]
- x = ggml_permute(ctx, x, 0, 2, 1, 3);
- x = ggml_cont(ctx, x);
-
- int fb = F * B;
- ggml_tensor* x_packed = ggml_reshape_3d(ctx, x, D, T, fb);
-
- std::string time_prefix = "blk." + std::to_string(layer) + ".time_attn";
- std::string time_ff_prefix = "blk." + std::to_string(layer) + ".time_ff";
-
- // Attention Block
- // blk.{l}.time_attn_norm.weight
- ggml_tensor* t_attn_norm_w = GetWeight(time_prefix + "_norm.weight");
- if (!t_attn_norm_w) { std::cerr << "Missing: " << time_prefix << "_norm.weight\n"; return nullptr; }
-
- ggml_tensor* x_norm = ggml_rms_norm(ctx, x_packed, 1e-12f);
- x_norm = ggml_mul(ctx, x_norm, t_attn_norm_w);
-
- // blk.{l}.time_attn_qkv.weight
- ggml_tensor* t_qkv_w = GetWeight(time_prefix + "_qkv.weight");
- if (!t_qkv_w) { std::cerr << "Missing: " << time_prefix << "_qkv.weight\n"; return nullptr; }
-
- ggml_tensor* qkv_out = ggml_mul_mat(ctx, t_qkv_w, x_norm);
-
- // Split QKV
- ggml_tensor* Q_view = ggml_view_4d(ctx, qkv_out, DIM_HEAD, T, HEADS, fb,
- qkv_out->nb[1], DIM_HEAD*sizeof(float), qkv_out->nb[2], 0);
- ggml_tensor* K_view = ggml_view_4d(ctx, qkv_out, DIM_HEAD, T, HEADS, fb,
- qkv_out->nb[1], DIM_HEAD*sizeof(float), qkv_out->nb[2], DIM_INNER*sizeof(float));
- ggml_tensor* V_view = ggml_view_4d(ctx, qkv_out, DIM_HEAD, T, HEADS, fb,
- qkv_out->nb[1], DIM_HEAD*sizeof(float), qkv_out->nb[2], 2*DIM_INNER*sizeof(float));
-
- ggml_tensor* Q = ggml_cont(ctx, Q_view);
- ggml_tensor* K = ggml_cont(ctx, K_view);
- ggml_tensor* V = ggml_cont(ctx, V_view);
-
- // RoPE with CUDA-compatible reshape
- // Original Q/K shape: [DIM_HEAD, T, HEADS, fb]
- // After permute: [DIM_HEAD, HEADS, T, fb]
- // For CUDA RoPE: reshape to [DIM_HEAD, HEADS, T*fb, 1] and use expanded pos
- ggml_tensor* Q_perm = ggml_permute(ctx, Q, 0, 2, 1, 3);
- ggml_tensor* K_perm = ggml_permute(ctx, K, 0, 2, 1, 3);
- ggml_tensor* Q_perm_cont = ggml_cont(ctx, Q_perm);
- ggml_tensor* K_perm_cont = ggml_cont(ctx, K_perm);
-
- // Reshape to merge batch(fb) into sequence for CUDA RoPE compatibility
- int T_fb = T * fb;
- ggml_tensor* Q_flat = ggml_reshape_4d(ctx, Q_perm_cont, DIM_HEAD, HEADS, T_fb, 1);
- ggml_tensor* K_flat = ggml_reshape_4d(ctx, K_perm_cont, DIM_HEAD, HEADS, T_fb, 1);
-
- // Use passed-in expanded position tensor (caller prepares [T*F*B] with repeating [0..T-1])
- ggml_tensor* Q_rope_flat = ggml_rope_ext(ctx, Q_flat, pos_time_exp, nullptr, DIM_HEAD,
- GGML_ROPE_TYPE_NORMAL, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f);
- ggml_tensor* K_rope_flat = ggml_rope_ext(ctx, K_flat, pos_time_exp, nullptr, DIM_HEAD,
- GGML_ROPE_TYPE_NORMAL, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f);
-
- // Reshape back to [DIM_HEAD, HEADS, T, fb]
- ggml_tensor* Q_rope_perm = ggml_reshape_4d(ctx, Q_rope_flat, DIM_HEAD, HEADS, T, fb);
- ggml_tensor* K_rope_perm = ggml_reshape_4d(ctx, K_rope_flat, DIM_HEAD, HEADS, T, fb);
-
- ggml_tensor* Q_rope = ggml_permute(ctx, Q_rope_perm, 0, 2, 1, 3);
- ggml_tensor* K_rope = ggml_permute(ctx, K_rope_perm, 0, 2, 1, 3);
- // Flash Attention
- // Inputs: [DIM_HEAD, T, HEADS, fb]
- // Output: [DIM_HEAD, HEADS, T, fb] (permuted)
- ggml_tensor* Q_fa = ggml_cont(ctx, Q_rope);
- ggml_tensor* K_fa = ggml_cont(ctx, K_rope);
- ggml_tensor* V_fa = V; // V is already contiguous [DIM_HEAD, T, HEADS, fb]
- float scale = 1.0f / sqrtf(static_cast<float>(DIM_HEAD));
- ggml_tensor* attn_out_fa = ggml_flash_attn_ext(ctx, Q_fa, K_fa, V_fa, nullptr, scale, 0.0f, 0.0f);
-
- // Permute back to [DIM_HEAD, T, HEADS, fb] to match original flow
- ggml_tensor* attn_out_perm = ggml_permute(ctx, attn_out_fa, 0, 2, 1, 3);
- ggml_tensor* attn_out_raw = ggml_cont(ctx, attn_out_perm);
-
- // Gates
- // blk.{l}.time_attn_gate.weight/bias
- ggml_tensor* t_gate_w = GetWeight(time_prefix + "_gate.weight");
- ggml_tensor* t_gate_b = GetWeight(time_prefix + "_gate.bias");
- if (!t_gate_w || !t_gate_b) { std::cerr << "Missing gates weights\n"; return nullptr; }
-
- ggml_tensor* gates = ggml_mul_mat(ctx, t_gate_w, x_norm);
- gates = ggml_add(ctx, gates, t_gate_b);
- gates = ggml_sigmoid(ctx, gates);
-
- ggml_tensor* gates_perm = ggml_permute(ctx, gates, 1, 0, 2, 3);
- ggml_tensor* gates_bcast = ggml_view_4d(ctx, gates_perm, 1, T, HEADS, fb,
- gates_perm->nb[0], gates_perm->nb[1], gates_perm->nb[2], 0);
-
- ggml_tensor* gated_out = ggml_mul(ctx, attn_out_raw, gates_bcast);
-
- ggml_tensor* out_perm = ggml_permute(ctx, gated_out, 0, 2, 1, 3);
- ggml_tensor* out_cont = ggml_cont(ctx, out_perm);
- ggml_tensor* out_flat = ggml_reshape_3d(ctx, out_cont, DIM_INNER, T, fb);
-
- // blk.{l}.time_attn_out.weight
- ggml_tensor* t_attn_out_w = GetWeight(time_prefix + "_out.weight");
- if (!t_attn_out_w) { std::cerr << "Missing to_out_weight\n"; return nullptr; }
-
- ggml_tensor* attn_block_out = ggml_mul_mat(ctx, t_attn_out_w, out_flat);
- ggml_tensor* x_resid1 = ggml_add(ctx, x_packed, attn_block_out);
-
- // FeedForward Block
- // blk.{l}.time_ff_norm.weight
- ggml_tensor* t_ff_norm_w = GetWeight(time_ff_prefix + "_norm.weight");
- if (!t_ff_norm_w) { std::cerr << "Missing ff norm\n"; return nullptr; }
-
- ggml_tensor* x_resid1_norm = ggml_rms_norm(ctx, x_resid1, 1e-12f);
- x_resid1_norm = ggml_mul(ctx, x_resid1_norm, t_ff_norm_w);
-
- // blk.{l}.time_ff_in.weight/bias
- ggml_tensor* t_ff_in_w = GetWeight(time_ff_prefix + "_in.weight");
- ggml_tensor* t_ff_in_b = GetWeight(time_ff_prefix + "_in.bias");
- if (!t_ff_in_w || !t_ff_in_b) { std::cerr << "Missing ff in weights\n"; return nullptr; }
-
- ggml_tensor* ff_proj_in = ggml_mul_mat(ctx, t_ff_in_w, x_resid1_norm);
- ff_proj_in = ggml_add(ctx, ff_proj_in, t_ff_in_b);
-
- ggml_tensor* gelu_out = ggml_gelu_erf(ctx, ff_proj_in);
-
-
- // blk.{l}.time_ff_out.weight/bias
- ggml_tensor* t_ff_out_w = GetWeight(time_ff_prefix + "_out.weight");
- ggml_tensor* t_ff_out_b = GetWeight(time_ff_prefix + "_out.bias");
- if (!t_ff_out_w || !t_ff_out_b) { std::cerr << "Missing ff out weights\n"; return nullptr; }
-
- ggml_tensor* ff_block_out = ggml_mul_mat(ctx, t_ff_out_w, gelu_out);
- ff_block_out = ggml_add(ctx, ff_block_out, t_ff_out_b);
-
- x_packed = ggml_add(ctx, x_resid1, ff_block_out);
-
-
- // Time Transformer Final Norm
- // blk.{l}.time_norm.weight
- std::string time_norm_name = "blk." + std::to_string(layer) + ".time_norm.weight";
- ggml_tensor* time_norm_w = GetWeight(time_norm_name);
- if (!time_norm_w) { std::cerr << "Missing: " << time_norm_name << "\n"; return nullptr; }
-
- x_packed = ggml_rms_norm(ctx, x_packed, 1e-12f);
- x_packed = ggml_mul(ctx, x_packed, time_norm_w);
-
- x = ggml_reshape_4d(ctx, x_packed, D, T, F, B);
- x = ggml_permute(ctx, x, 0, 2, 1, 3);
- x = ggml_cont(ctx, x);
-
- // ========== FREQ TRANSFORMER ==========
- int tb = T * B;
- ggml_tensor* x_freq_packed = ggml_reshape_3d(ctx, x, D, F, tb);
-
-
- std::string freq_prefix = "blk." + std::to_string(layer) + ".freq_attn";
- std::string freq_ff_prefix = "blk." + std::to_string(layer) + ".freq_ff";
-
- ggml_tensor* f_attn_norm_w = GetWeight(freq_prefix + "_norm.weight");
- if (!f_attn_norm_w) { std::cerr << "Missing freq norm\n"; return nullptr; }
-
- ggml_tensor* x_fnorm = ggml_rms_norm(ctx, x_freq_packed, 1e-12f);
- x_fnorm = ggml_mul(ctx, x_fnorm, f_attn_norm_w);
-
-
- ggml_tensor* f_qkv_w = GetWeight(freq_prefix + "_qkv.weight");
- if (!f_qkv_w) { std::cerr << "Missing freq qkv\n"; return nullptr; }
-
- ggml_tensor* f_qkv_out = ggml_mul_mat(ctx, f_qkv_w, x_fnorm);
-
- ggml_tensor* fQ_view = ggml_view_4d(ctx, f_qkv_out, DIM_HEAD, F, HEADS, tb,
- f_qkv_out->nb[1], DIM_HEAD*sizeof(float), f_qkv_out->nb[2], 0);
- ggml_tensor* fK_view = ggml_view_4d(ctx, f_qkv_out, DIM_HEAD, F, HEADS, tb,
- f_qkv_out->nb[1], DIM_HEAD*sizeof(float), f_qkv_out->nb[2], DIM_INNER*sizeof(float));
- ggml_tensor* fV_view = ggml_view_4d(ctx, f_qkv_out, DIM_HEAD, F, HEADS, tb,
- f_qkv_out->nb[1], DIM_HEAD*sizeof(float), f_qkv_out->nb[2], 2*DIM_INNER*sizeof(float));
-
- ggml_tensor* fQ = ggml_cont(ctx, fQ_view);
- ggml_tensor* fK = ggml_cont(ctx, fK_view);
- ggml_tensor* fV = ggml_cont(ctx, fV_view);
-
- // RoPE with CUDA-compatible reshape for Freq Transformer
- // fQ/fK shape after permute: [DIM_HEAD, HEADS, F, tb]
- ggml_tensor* fQ_perm = ggml_permute(ctx, fQ, 0, 2, 1, 3);
- ggml_tensor* fK_perm = ggml_permute(ctx, fK, 0, 2, 1, 3);
- ggml_tensor* fQ_perm_cont = ggml_cont(ctx, fQ_perm);
- ggml_tensor* fK_perm_cont = ggml_cont(ctx, fK_perm);
-
- // Reshape to merge batch(tb) into sequence for CUDA RoPE compatibility
- int F_tb = F * tb;
- ggml_tensor* fQ_flat = ggml_reshape_4d(ctx, fQ_perm_cont, DIM_HEAD, HEADS, F_tb, 1);
- ggml_tensor* fK_flat = ggml_reshape_4d(ctx, fK_perm_cont, DIM_HEAD, HEADS, F_tb, 1);
-
- // Use passed-in expanded position tensor (caller prepares [F*T*B] with repeating [0..F-1])
- ggml_tensor* fQ_rope_flat = ggml_rope_ext(ctx, fQ_flat, pos_freq_exp, nullptr, DIM_HEAD,
- GGML_ROPE_TYPE_NORMAL, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f);
- ggml_tensor* fK_rope_flat = ggml_rope_ext(ctx, fK_flat, pos_freq_exp, nullptr, DIM_HEAD,
- GGML_ROPE_TYPE_NORMAL, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f);
-
- // Reshape back to [DIM_HEAD, HEADS, F, tb]
- ggml_tensor* fQ_rope_perm = ggml_reshape_4d(ctx, fQ_rope_flat, DIM_HEAD, HEADS, F, tb);
- ggml_tensor* fK_rope_perm = ggml_reshape_4d(ctx, fK_rope_flat, DIM_HEAD, HEADS, F, tb);
-
- ggml_tensor* fQ_rope = ggml_permute(ctx, fQ_rope_perm, 0, 2, 1, 3);
- ggml_tensor* fK_rope = ggml_permute(ctx, fK_rope_perm, 0, 2, 1, 3);
-
- // Flash Attention (Freq)
- // Inputs: [DIM_HEAD, F, HEADS, tb]
- ggml_tensor* fQ_fa = ggml_cont(ctx, fQ_rope);
- ggml_tensor* fK_fa = ggml_cont(ctx, fK_rope);
- ggml_tensor* fV_fa = fV; // fV is contiguous [DIM_HEAD, F, HEADS, tb]
- // float scale is already defined in scope (Time Transformer block) or re-define if shadowed loop?
- // Actually 'scale' was defined inside the Time Transformer loop, so it persists?
- // No, Freq Transformer is in the same loop logic?
- // Let's check scope. It's in the same 'layer' loop.
- // But previously I removed the definition line in Time Transformer too? No, I added it back above.
- // Wait, best to redeclare or rely on scope?
- // Time Transformer code block vs Freq Transformer.
- // Let's just use the value.
- // Re-reading Freq Block:
- // Need to be safe. Redefine 'scale' if needed or ensuring it's available.
- // Previous search showed `float scale` was defined in Time Block.
- // If Time block is just sequential code, `scale` is available.
- // But I removed the line in Time block in the previous step (lines 307-319 replaced).
- // So I need to add it back in Time block (done in chunk 1).
- // For Freq block, if it's in same scope, it's fine.
- // However, standard good practice:
-
- // float scale = 1.0f / sqrtf(static_cast<float>(DIM_HEAD)); // Redefinition might error if same scope.
- // Let's check file content to see if Freq block is separately scoped.
- // It's in `for (int layer...) { ... Time ... Freq ... }`.
- // So `scale` defined in Time part is visible in Freq part.
- // So I don't need to define it again, just ensure it IS defined in Time part.
-
- ggml_tensor* f_attn_out_fa = ggml_flash_attn_ext(ctx, fQ_fa, fK_fa, fV_fa, nullptr, scale, 0.0f, 0.0f);
- // Permute output back to [DIM_HEAD, F, HEADS, tb]
- ggml_tensor* f_attn_out_perm = ggml_permute(ctx, f_attn_out_fa, 0, 2, 1, 3);
- ggml_tensor* f_attn_out_raw = ggml_cont(ctx, f_attn_out_perm);
-
-
- ggml_tensor* f_gate_w = GetWeight(freq_prefix + "_gate.weight");
- ggml_tensor* f_gate_b = GetWeight(freq_prefix + "_gate.bias");
- if (!f_gate_w || !f_gate_b) { std::cerr << "Missing freq gates\n"; return nullptr; }
-
- ggml_tensor* f_gates = ggml_mul_mat(ctx, f_gate_w, x_fnorm);
- f_gates = ggml_add(ctx, f_gates, f_gate_b);
- f_gates = ggml_sigmoid(ctx, f_gates);
-
- ggml_tensor* f_gates_perm = ggml_permute(ctx, f_gates, 1, 0, 2, 3);
- ggml_tensor* f_gates_bcast = ggml_view_4d(ctx, f_gates_perm, 1, F, HEADS, tb,
- f_gates_perm->nb[0], f_gates_perm->nb[1], f_gates_perm->nb[2], 0);
-
- ggml_tensor* f_gated_out = ggml_mul(ctx, f_attn_out_raw, f_gates_bcast);
-
- ggml_tensor* f_out_perm = ggml_permute(ctx, f_gated_out, 0, 2, 1, 3);
- ggml_tensor* f_out_cont = ggml_cont(ctx, f_out_perm);
- ggml_tensor* f_out_flat = ggml_reshape_3d(ctx, f_out_cont, DIM_INNER, F, tb);
-
-
- ggml_tensor* f_attn_out_w = GetWeight(freq_prefix + "_out.weight");
- if (!f_attn_out_w) { std::cerr << "Missing freq to_out\n"; return nullptr; }
-
- ggml_tensor* f_attn_block_out = ggml_mul_mat(ctx, f_attn_out_w, f_out_flat);
- ggml_tensor* f_x_resid1 = ggml_add(ctx, x_freq_packed, f_attn_block_out);
-
-
- // Freq FeedForward
- ggml_tensor* f_ff_norm_w = GetWeight(freq_ff_prefix + "_norm.weight");
- if (!f_ff_norm_w) { std::cerr << "Missing freq ff norm\n"; return nullptr; }
-
- ggml_tensor* f_x_resid1_norm = ggml_rms_norm(ctx, f_x_resid1, 1e-12f);
- f_x_resid1_norm = ggml_mul(ctx, f_x_resid1_norm, f_ff_norm_w);
-
- x_fnorm = ggml_mul(ctx, x_fnorm, f_attn_norm_w);
-
- ggml_tensor* f_ff_in_w = GetWeight(freq_ff_prefix + "_in.weight");
- ggml_tensor* f_ff_in_b = GetWeight(freq_ff_prefix + "_in.bias");
- if (!f_ff_in_w || !f_ff_in_b) { std::cerr << "Missing freq ff in\n"; return nullptr; }
-
- ggml_tensor* f_ff_proj_in = ggml_mul_mat(ctx, f_ff_in_w, f_x_resid1_norm);
- f_ff_proj_in = ggml_add(ctx, f_ff_proj_in, f_ff_in_b);
-
- ggml_tensor* f_gelu_out = ggml_gelu_erf(ctx, f_ff_proj_in);
-
-
- ggml_tensor* f_ff_out_w = GetWeight(freq_ff_prefix + "_out.weight");
- ggml_tensor* f_ff_out_b = GetWeight(freq_ff_prefix + "_out.bias");
- if (!f_ff_out_w || !f_ff_out_b) { std::cerr << "Missing freq ff out\n"; return nullptr; }
-
- ggml_tensor* f_ff_block_out = ggml_mul_mat(ctx, f_ff_out_w, f_gelu_out);
- f_ff_block_out = ggml_add(ctx, f_ff_block_out, f_ff_out_b);
-
- x_freq_packed = ggml_add(ctx, f_x_resid1, f_ff_block_out);
-
- // Freq Transformer Final Norm
- // blk.{l}.freq_norm.weight
- std::string freq_norm_name = "blk." + std::to_string(layer) + ".freq_norm.weight";
- ggml_tensor* freq_norm_w = GetWeight(freq_norm_name);
- if (!freq_norm_w) { std::cerr << "Missing: " << freq_norm_name << "\n"; return nullptr; }
-
- x_freq_packed = ggml_rms_norm(ctx, x_freq_packed, 1e-12f);
- x_freq_packed = ggml_mul(ctx, x_freq_packed, freq_norm_w);
-
- x = ggml_reshape_4d(ctx, x_freq_packed, D, F, T, B);
-
- if (skip_connection_) {
- skip_outputs.push_back(x);
- }
- }
-
- return x;
- }
- ggml_tensor* MelBandRoformer::BuildMaskEstimatorGraph(
- ggml_context* ctx,
- ggml_tensor* input,
- ggml_cgraph* gf,
- int n_frames,
- int batch
- ) {
- // Following test_10_full_model.cpp lines 532-618 EXACTLY
- // Input shape: [dim, num_bands, n_frames, batch]
- // Output: [total_out_dim, num_stems, n_frames, batch]
-
- const int DIM = dim_;
- const int NUM_BANDS = num_bands_;
- const int NUM_STEMS = num_stems_;
-
- // Calculate band_out_dims from mask_est.0.freq.{b}.mlp.4.weight shape
- std::vector<int> band_out_dims(NUM_BANDS);
- int total_out_dim = 0;
- for (int b = 0; b < NUM_BANDS; ++b) {
- // mask_est.0.freq.{b}.mlp.4.weight
- // Assuming all stems have same architecture, check stem 0
- std::string w4_name = "mask_est.0.freq." + std::to_string(b) + ".mlp.4.weight";
- ggml_tensor* w4 = GetWeight(w4_name);
- if (!w4) {
- std::cerr << "Missing weight for dim check: " << w4_name << std::endl;
- return nullptr;
- }
- band_out_dims[b] = static_cast<int>(w4->ne[1]) / 2; // GLU halves the dimension
- total_out_dim += band_out_dims[b];
- }
-
- ggml_tensor* x = input; // [D, F, T, B]
-
- // Create mask_output tensor: [total_out_dim, num_stems, n_frames, batch]
- ggml_tensor* mask_output = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, total_out_dim, NUM_STEMS, n_frames, batch);
- // No set_input needed if we cpy into it? Actually we construct it piecewise.
- // Making it zero-initialized or using views to write into it is safer.
- // ggml_set_zero(mask_output); // Not available easily in graph building usually, assumes overwritten.
-
- for (int s = 0; s < NUM_STEMS; ++s) {
- size_t mask_offset_elements = 0;
-
- for (int b = 0; b < NUM_BANDS; ++b) {
- // Extract band input: [DIM, n_frames, batch] for this band
- // Since input is same for all stems, we could cache this view?
- // GGML graph deduplication might handle it, but explicit view is fine.
- ggml_tensor* band_in = ggml_view_3d(ctx, x,
- DIM, n_frames, batch,
- x->nb[2], x->nb[3],
- b * x->nb[1]);
-
- // mask_est.{s}.freq.{b}.mlp...
- std::string prefix = "mask_est." + std::to_string(s) + ".freq." + std::to_string(b) + ".mlp.";
-
- // MLP Layer 0
- ggml_tensor* w0 = GetWeight(prefix + "0.weight");
- ggml_tensor* bias0 = GetWeight(prefix + "0.bias");
- if (!w0 || !bias0) { std::cerr << "Missing mask weights s=" << s << " b=" << b << "\n"; return nullptr; }
-
- ggml_tensor* layer0 = ggml_mul_mat(ctx, w0, band_in);
- layer0 = ggml_add(ctx, layer0, bias0);
- layer0 = ggml_tanh(ctx, layer0);
-
- // MLP Layer 2
- ggml_tensor* w2 = GetWeight(prefix + "2.weight");
- ggml_tensor* bias2 = GetWeight(prefix + "2.bias");
-
- ggml_tensor* layer2 = ggml_mul_mat(ctx, w2, layer0);
- layer2 = ggml_add(ctx, layer2, bias2);
- layer2 = ggml_tanh(ctx, layer2);
-
- // MLP Layer 4
- ggml_tensor* w4 = GetWeight(prefix + "4.weight");
- ggml_tensor* bias4 = GetWeight(prefix + "4.bias");
-
- ggml_tensor* mlp_out = ggml_mul_mat(ctx, w4, layer2);
- mlp_out = ggml_add(ctx, mlp_out, bias4);
-
- // GLU
- int dim_out = band_out_dims[b];
-
- ggml_tensor* glu_a = ggml_view_3d(ctx, mlp_out,
- dim_out, n_frames, batch,
- mlp_out->nb[1], mlp_out->nb[2], 0);
- ggml_tensor* glu_b = ggml_view_3d(ctx, mlp_out,
- dim_out, n_frames, batch,
- mlp_out->nb[1], mlp_out->nb[2],
- dim_out * sizeof(float));
-
- glu_a = ggml_cont(ctx, glu_a);
- glu_b = ggml_cont(ctx, glu_b);
-
- ggml_tensor* glu_b_sig = ggml_sigmoid(ctx, glu_b);
- ggml_tensor* band_out = ggml_mul(ctx, glu_a, glu_b_sig);
-
- // Copy to mask_output
- // Destination slice: mask_output[offset:offset+dim, s, :, :]
- // Use view_4d
- // offset in dimension 0 is mask_offset_elements
- // offset in dimension 1 is s
- size_t dest_offset_bytes = (mask_offset_elements * sizeof(float)) + (s * mask_output->nb[1]);
-
- ggml_tensor* dst_view = ggml_view_3d(ctx, mask_output,
- dim_out, n_frames, batch,
- mask_output->nb[2], // Time stride
- mask_output->nb[3], // Batch stride
- dest_offset_bytes); // Offset to correct freq-bin and stem
-
- ggml_build_forward_expand(gf, ggml_cpy(ctx, band_out, dst_view));
-
- mask_offset_elements += dim_out;
- }
- }
-
- // Ensure output
- ggml_tensor* mask_check = ggml_dup(ctx, mask_output);
- ggml_set_output(mask_check);
- ggml_build_forward_expand(gf, mask_check);
-
- return mask_check;
- }
|