// MelBandRoformer implementation: GGUF weight loading and ggml graph
// construction for the band-split, axial (time / frequency) transformer and
// mask-estimator stages.
#include "model.h"

#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

// NOTE(review): the ggml / gguf declarations used below are assumed to reach
// this file through model.h. These guarded includes are only a fallback for
// builds where they do not, and they are harmless otherwise.
#if defined(__has_include)
#if __has_include(<ggml-alloc.h>)
#include <ggml-alloc.h>
#endif
#if __has_include(<gguf.h>)
#include <gguf.h>
#endif
#endif

namespace {

// Epsilon used by every RMSNorm in the network.
constexpr float kRmsNormEps = 1e-12f;
// RoPE base frequency shared by the time and frequency transformers.
constexpr float kRopeFreqBase = 10000.0f;

// Reads an integer-valued GGUF key into `out`.
// An absent key leaves `out` unchanged. Any 32/64-bit signed or unsigned
// storage is accepted; other types are reported and ignored instead of
// tripping the type assertion inside gguf_get_val_u32().
template <typename T>
void ReadIntKey(const gguf_context* ctx, const char* key, T& out) {
    const auto id = gguf_find_key(ctx, key);
    if (id < 0) return;
    switch (gguf_get_kv_type(ctx, id)) {
        case GGUF_TYPE_UINT32: out = static_cast<T>(gguf_get_val_u32(ctx, id)); return;
        case GGUF_TYPE_INT32:  out = static_cast<T>(gguf_get_val_i32(ctx, id)); return;
        case GGUF_TYPE_UINT64: out = static_cast<T>(gguf_get_val_u64(ctx, id)); return;
        case GGUF_TYPE_INT64:  out = static_cast<T>(gguf_get_val_i64(ctx, id)); return;
        default:
            std::cerr << "[WARNING] GGUF key " << key << " is not an integer; ignoring" << std::endl;
            return;
    }
}

// Reads a boolean GGUF key into `out`. An absent key leaves `out` unchanged.
template <typename T>
void ReadBoolKey(const gguf_context* ctx, const char* key, T& out) {
    const auto id = gguf_find_key(ctx, key);
    if (id < 0) return;
    if (gguf_get_kv_type(ctx, id) == GGUF_TYPE_BOOL) {
        out = gguf_get_val_bool(ctx, id);
    } else {
        std::cerr << "[WARNING] GGUF key " << key << " is not a boolean; ignoring" << std::endl;
    }
}

// Converts dst.size() packed values of type Src from `raw` into `dst`.
// memcpy is used so that unaligned input is handled safely.
template <typename Src, typename Vec>
void ConvertInto(const char* raw, Vec& dst) {
    for (size_t i = 0; i < dst.size(); ++i) {
        Src v;
        std::memcpy(&v, raw + i * sizeof(Src), sizeof(Src));
        dst[i] = static_cast<typename Vec::value_type>(v);
    }
}

// Fills `dst` with the values of the index tensor `t`.
// `raw` holds the tensor's bytes as read from the file. I32, I64 and F32
// storage are accepted. Any other type is reported and leaves `dst`
// zero-filled with ggml_nelements(t) entries.
template <typename Vec>
void CopyIndexBuffer(const ggml_tensor* t, const char* raw, Vec& dst) {
    dst.assign(static_cast<size_t>(ggml_nelements(t)), typename Vec::value_type{});
    switch (t->type) {
        case GGML_TYPE_I32: ConvertInto<int32_t>(raw, dst); break;
        case GGML_TYPE_I64: ConvertInto<int64_t>(raw, dst); break;
        case GGML_TYPE_F32: ConvertInto<float>(raw, dst); break;
        default:
            std::cerr << "[WARNING] Tensor " << t->name << " has unsupported type "
                      << ggml_type_name(t->type) << "; values left at zero" << std::endl;
            break;
    }
}

}  // namespace

MelBandRoformer::MelBandRoformer() = default;

// Releases the weight buffer, backend and weight context.
// Each handle may still be null if Initialize()/LoadWeights() did not run or
// failed part-way.
MelBandRoformer::~MelBandRoformer() {
    if (buffer_weights_) ggml_backend_buffer_free(buffer_weights_);
    if (backend_) ggml_backend_free(backend_);
    if (ctx_weights_) ggml_free(ctx_weights_);
}

// Chooses a compute backend and loads the GGUF model at `model_path`.
// The best available backend is used unless the MBR_FORCE_CPU environment
// variable is set, in which case the CPU backend is forced.
// Throws std::runtime_error if the backend cannot be created or the model
// cannot be loaded.
void MelBandRoformer::Initialize(const std::string& model_path) {
    if (std::getenv("MBR_FORCE_CPU")) {
        backend_ = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
    } else {
        backend_ = ggml_backend_init_best();
    }
    if (!backend_) {
        throw std::runtime_error("Failed to initialize backend");
    }
    std::cout << "Using backend: " << ggml_backend_name(backend_) << std::endl;

    LoadWeights(model_path);
}

// Loads hyperparameters and every tensor from the GGUF file at `path`.
// Tensors are uploaded into a single backend buffer. The three index buffers
// (freq_indices, num_bands_per_freq, num_freqs_per_band) are also cached on
// the host.
// Throws std::runtime_error on any open, allocation or read failure.
void MelBandRoformer::LoadWeights(const std::string& path) {
    std::cout << "Loading model from " << path << std::endl;

    // Metadata-only parse: tensor descriptors are created in ctx_weights_ and
    // their data is read separately below.
    gguf_init_params params = {
        /*.no_alloc = */ true,
        /*.ctx      = */ &ctx_weights_,
    };
    // Owned by a unique_ptr so the gguf context is freed on every exit path,
    // including the throws below.
    std::unique_ptr<gguf_context, decltype(&gguf_free)> gguf(
        gguf_init_from_file(path.c_str(), params), &gguf_free);
    if (!gguf) {
        throw std::runtime_error("Failed to load GGUF file: " + path);
    }
    const gguf_context* ctx_gguf = gguf.get();

    // 1. Read hyperparameters. Every key is optional; an absent key keeps the
    //    member's current value.
    ReadIntKey(ctx_gguf, "mel_band_roformer.stft_n_fft", n_fft_);
    ReadIntKey(ctx_gguf, "mel_band_roformer.stft_hop_length", hop_length_);
    ReadIntKey(ctx_gguf, "mel_band_roformer.stft_win_length", win_length_);
    ReadIntKey(ctx_gguf, "mel_band_roformer.dim", dim_);
    ReadIntKey(ctx_gguf, "mel_band_roformer.num_bands", num_bands_);
    ReadIntKey(ctx_gguf, "mel_band_roformer.depth", depth_);
    ReadIntKey(ctx_gguf, "mel_band_roformer.num_stems", num_stems_);
    ReadBoolKey(ctx_gguf, "mel_band_roformer.skip_connection", skip_connection_);
    ReadBoolKey(ctx_gguf, "mel_band_roformer.stft_normalized", stft_normalized_);
    ReadBoolKey(ctx_gguf, "mel_band_roformer.zero_dc", zero_dc_);
    ReadIntKey(ctx_gguf, "mel_band_roformer.mask_estimator_depth", mask_estimator_depth_);
    ReadIntKey(ctx_gguf, "mel_band_roformer.sample_rate", sample_rate_);
    // Inference defaults (optional; otherwise the existing values are kept).
    ReadIntKey(ctx_gguf, "mel_band_roformer.default_chunk_size", default_chunk_size_);
    ReadIntKey(ctx_gguf, "mel_band_roformer.default_num_overlap", default_num_overlap_);

    int linear_depth = 0;
    ReadIntKey(ctx_gguf, "mel_band_roformer.linear_transformer_depth", linear_depth);
    if (linear_depth > 0) {
        std::cerr << "\n[WARNING] Model uses Linear Attention (depth=" << linear_depth
                  << "). This is NOT supported yet. Results will be incorrect.\n" << std::endl;
    }

    std::cout << "Model Config: n_fft=" << n_fft_ << ", hop_length=" << hop_length_
              << ", num_bands=" << num_bands_ << ", dim=" << dim_ << ", depth=" << depth_
              << ", num_stems=" << num_stems_ << ", skip_conn=" << skip_connection_ << std::endl;
    std::cout << "Inference Defaults: chunk_size=" << default_chunk_size_
              << ", num_overlap=" << default_num_overlap_ << std::endl;

    // 2. Allocate one backend buffer holding every tensor in ctx_weights_.
    buffer_weights_ = ggml_backend_alloc_ctx_tensors_from_buft(
        ctx_weights_, ggml_backend_get_default_buffer_type(backend_));
    if (!buffer_weights_) {
        throw std::runtime_error("Failed to allocate weight buffer");
    }

    // 3. Read each tensor's bytes from the file and upload them to the backend.
    std::ifstream file(path, std::ios::binary);
    if (!file) throw std::runtime_error("Cannot open file");

    const size_t data_offset = gguf_get_data_offset(ctx_gguf);
    std::vector<char> read_buf;  // reused across tensors; grows to the largest one
    for (ggml_tensor* t = ggml_get_first_tensor(ctx_weights_); t != nullptr;
         t = ggml_get_next_tensor(ctx_weights_, t)) {
        const auto tid = gguf_find_tensor(ctx_gguf, t->name);
        if (tid < 0) continue;

        const size_t offset = data_offset + gguf_get_tensor_offset(ctx_gguf, tid);
        const size_t size = ggml_nbytes(t);
        if (read_buf.size() < size) read_buf.resize(size);

        // std::streamoff is 64-bit on mainstream platforms, so unlike a cast
        // to `long` (32-bit on Windows) offsets past 2 GiB are not truncated.
        file.seekg(static_cast<std::streamoff>(offset), std::ios::beg);
        file.read(read_buf.data(), static_cast<std::streamsize>(size));
        if (!file) {
            throw std::runtime_error("Failed to read data of tensor " + std::string(t->name) +
                                     " from " + path);
        }

        ggml_backend_tensor_set(t, read_buf.data(), 0, size);

        // Keep host copies of the index buffers used to build the band layout.
        const std::string name = t->name;
        if (name == "buffer_freq_indices") {
            CopyIndexBuffer(t, read_buf.data(), freq_indices_);
            std::cout << " Loaded freq_indices: " << freq_indices_.size() << " indices" << std::endl;
        } else if (name == "buffer_num_bands_per_freq") {
            CopyIndexBuffer(t, read_buf.data(), num_bands_per_freq_);
        } else if (name == "buffer_num_freqs_per_band") {
            CopyIndexBuffer(t, read_buf.data(), num_freqs_per_band_);
        }
    }

    std::cout << "Loaded " << gguf_get_n_tensors(ctx_gguf) << " tensors" << std::endl;
}

// Returns the weight tensor named `name`, or nullptr if the model has none.
ggml_tensor* MelBandRoformer::GetWeight(const std::string& name) const {
    return ggml_get_tensor(ctx_weights_, name.c_str());
}

// Per-band input width for the band-split stage: each frequency bin in a band
// contributes stereo (2) x complex (2) = 4 features.
// Throws std::out_of_range if buffer_num_freqs_per_band has fewer than
// num_bands_ entries.
std::vector<int> MelBandRoformer::GetDimInputs() const {
    std::vector<int> dim_inputs(num_bands_);
    for (int i = 0; i < num_bands_; ++i) {
        dim_inputs[i] = static_cast<int>(num_freqs_per_band_.at(i)) * 4;
    }
    return dim_inputs;
}

// Sum of GetDimInputs(): total feature width of the band-split input.
int MelBandRoformer::GetTotalDimInput() const {
    int total = 0;
    for (const int d : GetDimInputs()) total += d;
    return total;
}

// ========== Graph Building Functions ==========

// Band split: for each band, take that band's slice of the input, apply
// RMSNorm and a Linear projection to `dim_`, and copy the result into its
// slot of the output.
// Input:  [total_dim_input, n_frames, batch], read as F32-contiguous slices
//         along dim 0.
// Output: [dim, num_bands, n_frames, batch]. The per-band copies are expanded
//         into `gf` here.
// Returns nullptr (after logging) if the band layout or any weight is missing.
ggml_tensor* MelBandRoformer::BuildBandSplitGraph(
    ggml_context* ctx, ggml_tensor* input, ggml_cgraph* gf, int n_frames, int batch) {
    if (num_freqs_per_band_.size() < static_cast<size_t>(num_bands_)) {
        std::cerr << "buffer_num_freqs_per_band has " << num_freqs_per_band_.size()
                  << " entries, expected " << num_bands_ << std::endl;
        return nullptr;
    }
    const auto dim_inputs = GetDimInputs();

    ggml_tensor* x = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, dim_, num_bands_, n_frames, batch);

    size_t in_offset_bytes = 0;
    for (int i = 0; i < num_bands_; ++i) {
        const int dim_in = dim_inputs[i];
        const std::string band = "band_split." + std::to_string(i);

        // This band's contiguous feature slice.
        ggml_tensor* band_input = ggml_view_3d(ctx, input, dim_in, n_frames, batch,
                                               input->nb[1], input->nb[2], in_offset_bytes);

        const std::string gamma_name = band + ".norm.weight";
        ggml_tensor* gamma = GetWeight(gamma_name);
        if (!gamma) {
            std::cerr << "Missing weight: " << gamma_name << std::endl;
            return nullptr;
        }
        const std::string w_name = band + ".linear.weight";
        const std::string b_name = band + ".linear.bias";
        ggml_tensor* weight = GetWeight(w_name);
        ggml_tensor* bias = GetWeight(b_name);
        if (!weight || !bias) {
            std::cerr << "Missing weight: " << w_name << " or " << b_name << std::endl;
            return nullptr;
        }

        ggml_tensor* normed = ggml_mul(ctx, ggml_rms_norm(ctx, band_input, kRmsNormEps), gamma);
        ggml_tensor* projected = ggml_add(ctx, ggml_mul_mat(ctx, weight, normed), bias);

        // Write into x[:, i, :, :].
        ggml_tensor* out_slice = ggml_view_3d(ctx, x, dim_, n_frames, batch,
                                              x->nb[2], x->nb[3], i * x->nb[1]);
        ggml_build_forward_expand(gf, ggml_cpy(ctx, projected, out_slice));

        in_offset_bytes += static_cast<size_t>(dim_in) * ggml_element_size(input);
    }
    return x;
}

// Stack of `depth_` axial layers. Each layer runs a time transformer (attention
// over frames, independently per band and batch item), then a frequency
// transformer (attention over bands, independently per frame and batch item).
// With skip_connection_, the outputs of all previous layers are added to the
// layer input.
// Input/Output: [dim, num_bands, n_frames, batch].
// pos_time_exp / pos_freq_exp: RoPE positions prepared by the caller, with
// length T*F*B repeating 0..T-1 and length F*T*B repeating 0..F-1.
// Returns nullptr (after logging) if any weight is missing.
ggml_tensor* MelBandRoformer::BuildTransformersGraph(
    ggml_context* ctx, ggml_tensor* input, ggml_cgraph* gf,
    ggml_tensor* pos_time_exp, ggml_tensor* pos_freq_exp, int n_frames, int batch) {
    // `gf` is unused: every node built here is reachable from the returned tensor.
    (void)gf;

    const int D = dim_;
    const int F = num_bands_;
    const int T = n_frames;
    const int B = batch;
    const int HEADS = heads_;
    const int DIM_HEAD = dim_head_;
    const int DIM_INNER = HEADS * DIM_HEAD;
    const float scale = 1.0f / std::sqrt(static_cast<float>(DIM_HEAD));

    // Looks up a weight; logs its name if absent.
    auto require = [this](const std::string& name) -> ggml_tensor* {
        ggml_tensor* w = GetWeight(name);
        if (!w) std::cerr << "Missing: " << name << "\n";
        return w;
    };

    // One pre-norm transformer applied along an axis of length `seq` for
    // `n_seq` independent sequences:
    //   gated multi-head attention with RoPE,
    //   then a GELU feed-forward,
    //   then a final RMSNorm.
    // x_packed: [D, seq, n_seq]; returns the same shape.
    // Weight names follow "<blk>.<kind>_attn_*", "<blk>.<kind>_ff_*" and
    // "<blk>.<kind>_norm.weight".
    auto axis_transformer = [&](ggml_tensor* x_packed, int seq, int n_seq, const std::string& blk,
                                const std::string& kind, ggml_tensor* pos) -> ggml_tensor* {
        const std::string attn = blk + "." + kind + "_attn";
        const std::string ff = blk + "." + kind + "_ff";

        ggml_tensor* w_attn_norm = require(attn + "_norm.weight");
        ggml_tensor* w_qkv = require(attn + "_qkv.weight");
        ggml_tensor* w_gate = require(attn + "_gate.weight");
        ggml_tensor* b_gate = require(attn + "_gate.bias");
        ggml_tensor* w_attn_out = require(attn + "_out.weight");
        ggml_tensor* w_ff_norm = require(ff + "_norm.weight");
        ggml_tensor* w_ff_in = require(ff + "_in.weight");
        ggml_tensor* b_ff_in = require(ff + "_in.bias");
        ggml_tensor* w_ff_out = require(ff + "_out.weight");
        ggml_tensor* b_ff_out = require(ff + "_out.bias");
        ggml_tensor* w_final_norm = require(blk + "." + kind + "_norm.weight");
        if (!w_attn_norm || !w_qkv || !w_gate || !b_gate || !w_attn_out || !w_ff_norm ||
            !w_ff_in || !b_ff_in || !w_ff_out || !b_ff_out || !w_final_norm) {
            return nullptr;
        }

        // ---- Attention block ----
        ggml_tensor* x_norm = ggml_mul(ctx, ggml_rms_norm(ctx, x_packed, kRmsNormEps), w_attn_norm);
        ggml_tensor* qkv = ggml_mul_mat(ctx, w_qkv, x_norm);  // [3*DIM_INNER, seq, n_seq]

        // Slice Q (which=0), K (1) or V (2) out of the fused projection as
        // [DIM_HEAD, seq, HEADS, n_seq].
        auto split_qkv = [&](int which) {
            ggml_tensor* view = ggml_view_4d(ctx, qkv, DIM_HEAD, seq, HEADS, n_seq,
                                             qkv->nb[1], DIM_HEAD * sizeof(float), qkv->nb[2],
                                             which * DIM_INNER * sizeof(float));
            return ggml_cont(ctx, view);
        };
        ggml_tensor* Q = split_qkv(0);
        ggml_tensor* K = split_qkv(1);
        ggml_tensor* V = split_qkv(2);

        // Apply RoPE. All sequences are folded into one
        // [DIM_HEAD, HEADS, seq*n_seq, 1] batch (the layout the CUDA rope path
        // accepts), with `pos` holding each row's position within its own
        // sequence. The result is returned as [DIM_HEAD, seq, HEADS, n_seq].
        auto rope = [&](ggml_tensor* t) {
            ggml_tensor* perm = ggml_cont(ctx, ggml_permute(ctx, t, 0, 2, 1, 3));
            ggml_tensor* flat = ggml_reshape_4d(ctx, perm, DIM_HEAD, HEADS, seq * n_seq, 1);
            ggml_tensor* roped = ggml_rope_ext(ctx, flat, pos, nullptr, DIM_HEAD,
                                               GGML_ROPE_TYPE_NORMAL, 0, kRopeFreqBase,
                                               1.0f, 0.0f, 1.0f, 0.0f, 0.0f);
            ggml_tensor* unflat = ggml_reshape_4d(ctx, roped, DIM_HEAD, HEADS, seq, n_seq);
            return ggml_cont(ctx, ggml_permute(ctx, unflat, 0, 2, 1, 3));
        };
        ggml_tensor* Q_rope = rope(Q);
        ggml_tensor* K_rope = rope(K);

        // Flash attention output is [DIM_HEAD, HEADS, seq, n_seq]; permute it
        // back to [DIM_HEAD, seq, HEADS, n_seq].
        ggml_tensor* fa = ggml_flash_attn_ext(ctx, Q_rope, K_rope, V, nullptr, scale, 0.0f, 0.0f);
        ggml_tensor* attn_out = ggml_cont(ctx, ggml_permute(ctx, fa, 0, 2, 1, 3));

        // Per-head sigmoid gates [HEADS, seq, n_seq], viewed as
        // [1, seq, HEADS, n_seq] so they broadcast over DIM_HEAD.
        ggml_tensor* gates = ggml_sigmoid(ctx, ggml_add(ctx, ggml_mul_mat(ctx, w_gate, x_norm), b_gate));
        ggml_tensor* gates_perm = ggml_permute(ctx, gates, 1, 0, 2, 3);
        ggml_tensor* gates_bcast = ggml_view_4d(ctx, gates_perm, 1, seq, HEADS, n_seq,
                                                gates_perm->nb[0], gates_perm->nb[1],
                                                gates_perm->nb[2], 0);
        ggml_tensor* gated = ggml_mul(ctx, attn_out, gates_bcast);

        // Merge heads back to [DIM_INNER, seq, n_seq], project, add the residual.
        ggml_tensor* merged = ggml_reshape_3d(
            ctx, ggml_cont(ctx, ggml_permute(ctx, gated, 0, 2, 1, 3)), DIM_INNER, seq, n_seq);
        ggml_tensor* x_resid = ggml_add(ctx, x_packed, ggml_mul_mat(ctx, w_attn_out, merged));

        // ---- Feed-forward block ----
        ggml_tensor* ff_in = ggml_mul(ctx, ggml_rms_norm(ctx, x_resid, kRmsNormEps), w_ff_norm);
        ggml_tensor* hidden = ggml_gelu_erf(ctx, ggml_add(ctx, ggml_mul_mat(ctx, w_ff_in, ff_in), b_ff_in));
        ggml_tensor* ff_out = ggml_add(ctx, ggml_mul_mat(ctx, w_ff_out, hidden), b_ff_out);
        ggml_tensor* y = ggml_add(ctx, x_resid, ff_out);

        // ---- Final norm ----
        return ggml_mul(ctx, ggml_rms_norm(ctx, y, kRmsNormEps), w_final_norm);
    };

    ggml_tensor* x = input;  // [D, F, T, B]
    std::vector<ggml_tensor*> skip_outputs;

    for (int layer = 0; layer < depth_; ++layer) {
        if (skip_connection_) {
            for (ggml_tensor* s : skip_outputs) x = ggml_add(ctx, x, s);
        }
        const std::string blk = "blk." + std::to_string(layer);

        // Time transformer over T, for each of the F*B (band, batch) pairs.
        x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3));  // [D, T, F, B]
        ggml_tensor* xt = axis_transformer(ggml_reshape_3d(ctx, x, D, T, F * B),
                                           T, F * B, blk, "time", pos_time_exp);
        if (!xt) return nullptr;
        x = ggml_cont(ctx, ggml_permute(ctx, ggml_reshape_4d(ctx, xt, D, T, F, B),
                                        0, 2, 1, 3));  // [D, F, T, B]

        // Frequency transformer over F, for each of the T*B (frame, batch) pairs.
        ggml_tensor* xf = axis_transformer(ggml_reshape_3d(ctx, x, D, F, T * B),
                                           F, T * B, blk, "freq", pos_freq_exp);
        if (!xf) return nullptr;
        x = ggml_reshape_4d(ctx, xf, D, F, T, B);

        if (skip_connection_) skip_outputs.push_back(x);
    }
    return x;
}

// Mask estimator: for every stem and band, an MLP followed by a GLU maps that
// band's [dim] feature to its band output. Results are packed along dim 0 of
// the output.
// Input:  [dim, num_bands, n_frames, batch].
// Output: [total_out_dim, num_stems, n_frames, batch], marked as a graph output.
// The MLP's Linear layers sit at even indices (mlp.0, mlp.2, ...), with tanh
// between them. Their count is probed from the weights of stem 0, band 0
// rather than hard-coded.
// Returns nullptr (after logging) on missing or inconsistent weights.
ggml_tensor* MelBandRoformer::BuildMaskEstimatorGraph(
    ggml_context* ctx, ggml_tensor* input, ggml_cgraph* gf, int n_frames, int batch) {
    const int DIM = dim_;
    const int NUM_BANDS = num_bands_;
    const int NUM_STEMS = num_stems_;

    auto mlp_prefix = [](int s, int b) {
        return "mask_est." + std::to_string(s) + ".freq." + std::to_string(b) + ".mlp.";
    };

    // Count Linear layers at mlp.0, mlp.2, mlp.4, ... for stem 0 / band 0.
    int num_linear = 0;
    while (GetWeight(mlp_prefix(0, 0) + std::to_string(2 * num_linear) + ".weight") != nullptr) {
        ++num_linear;
    }
    if (num_linear == 0) {
        std::cerr << "Missing weight for dim check: " << mlp_prefix(0, 0) << "0.weight" << std::endl;
        return nullptr;
    }
    const std::string last_layer = std::to_string(2 * (num_linear - 1));

    // Each band's output width is half of its last Linear layer's output,
    // because the GLU halves it.
    std::vector<int> band_out_dims(NUM_BANDS);
    int total_out_dim = 0;
    for (int b = 0; b < NUM_BANDS; ++b) {
        const std::string w_name = mlp_prefix(0, b) + last_layer + ".weight";
        ggml_tensor* w_last = GetWeight(w_name);
        if (!w_last) {
            std::cerr << "Missing weight for dim check: " << w_name << std::endl;
            return nullptr;
        }
        band_out_dims[b] = static_cast<int>(w_last->ne[1]) / 2;
        total_out_dim += band_out_dims[b];
    }

    // Filled piecewise through views below; every element is written by
    // exactly one band copy.
    ggml_tensor* mask_output =
        ggml_new_tensor_4d(ctx, GGML_TYPE_F32, total_out_dim, NUM_STEMS, n_frames, batch);

    for (int s = 0; s < NUM_STEMS; ++s) {
        size_t mask_offset_elements = 0;
        for (int b = 0; b < NUM_BANDS; ++b) {
            const std::string prefix = mlp_prefix(s, b);

            // Band input [DIM, n_frames, batch]: the b-th band of the input.
            ggml_tensor* h = ggml_view_3d(ctx, input, DIM, n_frames, batch,
                                          input->nb[2], input->nb[3], b * input->nb[1]);

            for (int k = 0; k < num_linear; ++k) {
                const std::string layer = std::to_string(2 * k);
                ggml_tensor* w = GetWeight(prefix + layer + ".weight");
                ggml_tensor* bias = GetWeight(prefix + layer + ".bias");
                if (!w || !bias) {
                    std::cerr << "Missing mask weights s=" << s << " b=" << b
                              << " (" << prefix << layer << ")\n";
                    return nullptr;
                }
                h = ggml_add(ctx, ggml_mul_mat(ctx, w, h), bias);
                if (k + 1 < num_linear) h = ggml_tanh(ctx, h);  // no activation after the last layer
            }

            // GLU: first half * sigmoid(second half).
            const int dim_out = band_out_dims[b];
            if (h->ne[0] != 2 * static_cast<int64_t>(dim_out)) {
                std::cerr << "Mask MLP output width mismatch s=" << s << " b=" << b
                          << ": got " << h->ne[0] << ", expected " << 2 * dim_out << "\n";
                return nullptr;
            }
            ggml_tensor* glu_a = ggml_cont(ctx, ggml_view_3d(ctx, h, dim_out, n_frames, batch,
                                                             h->nb[1], h->nb[2], 0));
            ggml_tensor* glu_b = ggml_cont(ctx, ggml_view_3d(ctx, h, dim_out, n_frames, batch,
                                                             h->nb[1], h->nb[2],
                                                             dim_out * sizeof(float)));
            ggml_tensor* band_out = ggml_mul(ctx, glu_a, ggml_sigmoid(ctx, glu_b));

            // Destination: mask_output[offset : offset + dim_out, s, :, :].
            const size_t dest_offset_bytes =
                mask_offset_elements * sizeof(float) + s * mask_output->nb[1];
            ggml_tensor* dst_view = ggml_view_3d(ctx, mask_output, dim_out, n_frames, batch,
                                                 mask_output->nb[2],  // time stride
                                                 mask_output->nb[3],  // batch stride
                                                 dest_offset_bytes);
            ggml_build_forward_expand(gf, ggml_cpy(ctx, band_out, dst_view));

            mask_offset_elements += dim_out;
        }
    }

    // Publish the assembled mask as a graph output. The copies above were
    // added to `gf` before this node.
    ggml_tensor* mask_check = ggml_dup(ctx, mask_output);
    ggml_set_output(mask_check);
    ggml_build_forward_expand(gf, mask_check);
    return mask_check;
}