|
|
@@ -66,6 +66,9 @@ void BSRoformer::LoadWeights(const std::string& path) {
|
|
|
if (architecture_ == "bs_roformer") {
|
|
|
has_final_norm_ = true;
|
|
|
transformer_norm_output_ = false;
|
|
|
+ } else if (architecture_ == "bs_roformer_v2") {
|
|
|
+ is_v2_model_ = true;
|
|
|
+ // V2-specific logic can be added here if needed
|
|
|
} else {
|
|
|
// mel_band_roformer
|
|
|
has_final_norm_ = false;
|
|
|
@@ -113,6 +116,29 @@ void BSRoformer::LoadWeights(const std::string& path) {
|
|
|
|
|
|
kv_idx = gguf_find_key(ctx_gguf, (kp + "sample_rate").c_str());
|
|
|
if (kv_idx >= 0) sample_rate_ = (int)gguf_get_val_u32(ctx_gguf, kv_idx);
|
|
|
+
|
|
|
+ if (is_v2_model_) {
|
|
|
+ kv_idx = gguf_find_key(ctx_gguf, (kp + "time_transformer_depth").c_str());
|
|
|
+ if (kv_idx >= 0) time_transformer_depth_ = (int)gguf_get_val_u32(ctx_gguf, kv_idx);
|
|
|
+
|
|
|
+ kv_idx = gguf_find_key(ctx_gguf, (kp + "freq_transformer_depth").c_str());
|
|
|
+ if (kv_idx >= 0) freq_transformer_depth_ = (int)gguf_get_val_u32(ctx_gguf, kv_idx);
|
|
|
+
|
|
|
+ kv_idx = gguf_find_key(ctx_gguf, (kp + "num_key_value_heads").c_str());
|
|
|
+ if (kv_idx >= 0) num_key_value_heads_ = (int)gguf_get_val_u32(ctx_gguf, kv_idx);
|
|
|
+
|
|
|
+ kv_idx = gguf_find_key(ctx_gguf, (kp + "intermediate_size").c_str());
|
|
|
+ if (kv_idx >= 0) intermediate_size_ = (int)gguf_get_val_u32(ctx_gguf, kv_idx);
|
|
|
+
|
|
|
+ kv_idx = gguf_find_key(ctx_gguf, (kp + "num_input_channels").c_str());
|
|
|
+ if (kv_idx >= 0) num_input_channels_ = (int)gguf_get_val_u32(ctx_gguf, kv_idx);
|
|
|
+
|
|
|
+ kv_idx = gguf_find_key(ctx_gguf, (kp + "band_proj_size").c_str());
|
|
|
+ if (kv_idx >= 0) band_proj_size_ = (int)gguf_get_val_u32(ctx_gguf, kv_idx);
|
|
|
+
|
|
|
+ kv_idx = gguf_find_key(ctx_gguf, (kp + "register_token_num").c_str());
|
|
|
+ if (kv_idx >= 0) register_token_num_ = (int)gguf_get_val_u32(ctx_gguf, kv_idx);
|
|
|
+ }
|
|
|
|
|
|
// Inference defaults (optional, fallback to hardcoded values)
|
|
|
kv_idx = gguf_find_key(ctx_gguf, (kp + "default_chunk_size").c_str());
|
|
|
@@ -229,7 +255,13 @@ std::vector<int> BSRoformer::GetDimInputs() const {
|
|
|
}
|
|
|
|
|
|
int BSRoformer::GetTotalDimInput() const {
|
|
|
- if (architecture_ == "bs") {
|
|
|
+ if (is_v2_model_) {
|
|
|
+ int total = 0;
|
|
|
+ for (int i = 0; i < num_bands_; ++i) {
|
|
|
+ total += num_freqs_per_band_[i] * 2;
|
|
|
+ }
|
|
|
+ return total;
|
|
|
+ } else if (architecture_ == "bs_roformer") {
|
|
|
// BS: All frequencies * stereo * complex
|
|
|
int n_freq = n_fft_ / 2 + 1;
|
|
|
return n_freq * 2 * 2; // freq * stereo * complex
|
|
|
@@ -242,7 +274,7 @@ int BSRoformer::GetTotalDimInput() const {
|
|
|
return total;
|
|
|
}
|
|
|
|
|
|
-// ========== Graph Building Functions ==========
|
|
|
+// ========== Graph Building Functions ==========
|
|
|
|
|
|
ggml_tensor* BSRoformer::BuildBandSplitGraph(
|
|
|
ggml_context* ctx,
|
|
|
@@ -251,6 +283,46 @@ ggml_tensor* BSRoformer::BuildBandSplitGraph(
|
|
|
int n_frames,
|
|
|
int batch
|
|
|
) {
|
|
|
+ if (is_v2_model_) {
|
|
|
+ // V2 model band split
|
|
|
+ ggml_tensor* x = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, band_proj_size_, num_bands_, n_frames, batch);
|
|
|
+
|
|
|
+ size_t offset_elements = 0;
|
|
|
+ for (int i = 0; i < num_bands_; ++i) {
|
|
|
+ int dim_in = num_freqs_per_band_[i] * 2; // V2 uses 2 instead of 4
|
|
|
+
|
|
|
+ ggml_tensor* band_input = ggml_view_3d(ctx, input,
|
|
|
+ dim_in, n_frames, batch,
|
|
|
+ input->nb[1], input->nb[2],
|
|
|
+ offset_elements * sizeof(float));
|
|
|
+
|
|
|
+ std::string norm_name = "band_split." + std::to_string(i) + ".norm.weight";
|
|
|
+ ggml_tensor* norm_w = GetWeight(norm_name);
|
|
|
+ if (!norm_w) { std::cerr << "Missing weight: " << norm_name << std::endl; return nullptr; }
|
|
|
+
|
|
|
+ ggml_tensor* normed = ggml_rms_norm(ctx, band_input, 1e-6f);
|
|
|
+ normed = ggml_mul(ctx, normed, norm_w);
|
|
|
+
|
|
|
+ std::string linear_w_name = "band_split." + std::to_string(i) + ".linear.weight";
|
|
|
+ std::string linear_b_name = "band_split." + std::to_string(i) + ".linear.bias";
|
|
|
+ ggml_tensor* linear_w = GetWeight(linear_w_name);
|
|
|
+ ggml_tensor* linear_b = GetWeight(linear_b_name);
|
|
|
+ if (!linear_w || !linear_b) { std::cerr << "Missing weights for band " << i << std::endl; return nullptr; }
|
|
|
+
|
|
|
+ ggml_tensor* projected = ggml_mul_mat(ctx, linear_w, normed);
|
|
|
+ projected = ggml_add(ctx, projected, linear_b);
|
|
|
+
|
|
|
+ ggml_tensor* out_slice = ggml_view_3d(ctx, x,
|
|
|
+ band_proj_size_, n_frames, batch,
|
|
|
+ x->nb[2], x->nb[3],
|
|
|
+ i * x->nb[1]);
|
|
|
+
|
|
|
+ ggml_build_forward_expand(gf, ggml_cpy(ctx, projected, out_slice));
|
|
|
+ offset_elements += dim_in;
|
|
|
+ }
|
|
|
+ return x;
|
|
|
+ }
|
|
|
+
|
|
|
// Following test_10_full_model.cpp implementation
|
|
|
// Input: [total_dim_input, n_frames, batch]
|
|
|
// Output: [dim, num_bands, n_frames, batch]
|
|
|
@@ -321,6 +393,9 @@ ggml_tensor* BSRoformer::BuildTransformersGraph(
|
|
|
int n_frames,
|
|
|
int batch
|
|
|
) {
|
|
|
+ if (is_v2_model_) {
|
|
|
+ return BuildTransformersGraphV2(ctx, input, gf, pos_time_exp, pos_freq_exp, n_frames, batch);
|
|
|
+ }
|
|
|
// Following test_10_full_model.cpp implementation
|
|
|
// Input: [dim, num_bands, n_frames, batch]
|
|
|
|
|
|
@@ -341,7 +416,7 @@ ggml_tensor* BSRoformer::BuildTransformersGraph(
|
|
|
x = ggml_add(ctx, x, s);
|
|
|
}
|
|
|
}
|
|
|
- // ========== TIME TRANSFORMER ==========
|
|
|
+ // ========== TIME TRANSFORMER ==========
|
|
|
// Permute: [D, F, T, B] -> [D, T, F, B]
|
|
|
x = ggml_permute(ctx, x, 0, 2, 1, 3);
|
|
|
x = ggml_cont(ctx, x);
|
|
|
@@ -493,7 +568,7 @@ ggml_tensor* BSRoformer::BuildTransformersGraph(
|
|
|
x = ggml_permute(ctx, x, 0, 2, 1, 3);
|
|
|
x = ggml_cont(ctx, x);
|
|
|
|
|
|
- // ========== FREQ TRANSFORMER ==========
|
|
|
+ // ========== FREQ TRANSFORMER ==========
|
|
|
int tb = T * B;
|
|
|
ggml_tensor* x_freq_packed = ggml_reshape_3d(ctx, x, D, F, tb);
|
|
|
|
|
|
@@ -558,12 +633,8 @@ ggml_tensor* BSRoformer::BuildTransformersGraph(
|
|
|
ggml_tensor* fV_fa = fV; // fV is contiguous [DIM_HEAD, F, HEADS, tb]
|
|
|
|
|
|
// float scale is already defined in scope (Time Transformer block) or re-define if shadowed loop?
|
|
|
- // Actually 'scale' was defined inside the Time Transformer loop, so it persists?
|
|
|
- // No, Freq Transformer is in the same loop logic?
|
|
|
- // Let's check scope. It's in the same 'layer' loop.
|
|
|
- // But previously I removed the definition line in Time Transformer too? No, I added it back above.
|
|
|
- // Wait, best to redeclare or rely on scope?
|
|
|
- // Time Transformer code block vs Freq Transformer.
|
|
|
+ // Actually 'scale' was defined inside the Time Transformer loop, so it persists?
|
|
|
+ // No, Freq Transformer is in the same loop logic?
|
|
|
// Let's just use the value.
|
|
|
// Re-reading Freq Block:
|
|
|
// Need to be safe. Redefine 'scale' if needed or ensuring it's available.
|
|
|
@@ -707,7 +778,7 @@ ggml_tensor* BSRoformer::BuildMaskEstimatorGraph(
|
|
|
total_out_dim += band_out_dims[b];
|
|
|
}
|
|
|
|
|
|
- ggml_tensor* x = input; // [D, F, T, B]
|
|
|
+ ggml_tensor* x = input; // [D, F, T, B]
|
|
|
|
|
|
// Create mask_output tensor: [total_out_dim, num_stems, n_frames, batch]
|
|
|
ggml_tensor* mask_output = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, total_out_dim, NUM_STEMS, n_frames, batch);
|
|
|
@@ -800,3 +871,105 @@ ggml_tensor* BSRoformer::BuildMaskEstimatorGraph(
|
|
|
|
|
|
return mask_check;
|
|
|
}
|
|
|
+
|
|
|
+ggml_tensor* BSRoformer::BuildTransformersGraphV2(
|
|
|
+ ggml_context* ctx,
|
|
|
+ ggml_tensor* input,
|
|
|
+ ggml_cgraph* gf,
|
|
|
+ ggml_tensor* pos_time_exp,
|
|
|
+ ggml_tensor* pos_freq_exp,
|
|
|
+ int n_frames,
|
|
|
+ int batch
|
|
|
+) {
|
|
|
+ ggml_tensor* x = input;
|
|
|
+
|
|
|
+ for (int layer = 0; layer < depth_; ++layer) {
|
|
|
+ // Time Transformer
|
|
|
+ for (int time_layer = 0; time_layer < time_transformer_depth_; ++time_layer) {
|
|
|
+ std::string time_prefix = "blk." + std::to_string(layer) + ".time_attn." + std::to_string(time_layer);
|
|
|
+ ggml_tensor* x_norm = ggml_rms_norm(ctx, x, 1e-6f);
|
|
|
+ ggml_tensor* t_attn_norm_w = GetWeight(time_prefix + ".norm.weight");
|
|
|
+ x_norm = ggml_mul(ctx, x_norm, t_attn_norm_w);
|
|
|
+
|
|
|
+ ggml_tensor* t_qkv_w = GetWeight(time_prefix + ".qkv.weight");
|
|
|
+ ggml_tensor* qkv_out = ggml_mul_mat(ctx, t_qkv_w, x_norm);
|
|
|
+
|
|
|
+ // Split Q, K, V
|
|
|
+ ggml_tensor* Q = ggml_view_2d(ctx, qkv_out, dim_, n_frames, qkv_out->nb[1], 0);
|
|
|
+ ggml_tensor* K = ggml_view_2d(ctx, qkv_out, dim_, n_frames, qkv_out->nb[1], dim_ * sizeof(float));
|
|
|
+ ggml_tensor* V = ggml_view_2d(ctx, qkv_out, dim_, n_frames, qkv_out->nb[1], dim_ * 2 * sizeof(float));
|
|
|
+
|
|
|
+ // RoPE
|
|
|
+ Q = ggml_rope_ext(ctx, Q, pos_time_exp, nullptr, dim_head_, GGML_ROPE_TYPE_NORMAL, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f);
|
|
|
+ K = ggml_rope_ext(ctx, K, pos_time_exp, nullptr, dim_head_, GGML_ROPE_TYPE_NORMAL, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f);
|
|
|
+
|
|
|
+ // Attention
|
|
|
+ ggml_tensor* attn = ggml_flash_attn_ext(ctx, Q, K, V, nullptr, 1.0f / sqrtf(dim_head_), 0.0f, 0.0f);
|
|
|
+
|
|
|
+ // Output projection
|
|
|
+ ggml_tensor* t_out_w = GetWeight(time_prefix + ".out.weight");
|
|
|
+ attn = ggml_mul_mat(ctx, t_out_w, attn);
|
|
|
+
|
|
|
+ x = ggml_add(ctx, x, attn);
|
|
|
+
|
|
|
+ // MLP
|
|
|
+ ggml_tensor* x_mlp = ggml_rms_norm(ctx, x, 1e-6f);
|
|
|
+ ggml_tensor* mlp_norm_w = GetWeight("blk." + std::to_string(layer) + ".time_ff." + std::to_string(time_layer) + ".norm.weight");
|
|
|
+ x_mlp = ggml_mul(ctx, x_mlp, mlp_norm_w);
|
|
|
+
|
|
|
+ ggml_tensor* mlp_in_w = GetWeight("blk." + std::to_string(layer) + ".time_ff." + std::to_string(time_layer) + ".in.weight");
|
|
|
+ x_mlp = ggml_mul_mat(ctx, mlp_in_w, x_mlp);
|
|
|
+ x_mlp = ggml_gelu(ctx, x_mlp);
|
|
|
+
|
|
|
+ ggml_tensor* mlp_out_w = GetWeight("blk." + std::to_string(layer) + ".time_ff." + std::to_string(time_layer) + ".out.weight");
|
|
|
+ x_mlp = ggml_mul_mat(ctx, mlp_out_w, x_mlp);
|
|
|
+
|
|
|
+ x = ggml_add(ctx, x, x_mlp);
|
|
|
+ }
|
|
|
+
|
|
|
+ // Freq Transformer
|
|
|
+ for (int freq_layer = 0; freq_layer < freq_transformer_depth_; ++freq_layer) {
|
|
|
+ std::string freq_prefix = "blk." + std::to_string(layer) + ".freq_attn." + std::to_string(freq_layer);
|
|
|
+ ggml_tensor* x_norm = ggml_rms_norm(ctx, x, 1e-6f);
|
|
|
+ ggml_tensor* f_attn_norm_w = GetWeight(freq_prefix + ".norm.weight");
|
|
|
+ x_norm = ggml_mul(ctx, x_norm, f_attn_norm_w);
|
|
|
+
|
|
|
+ ggml_tensor* f_qkv_w = GetWeight(freq_prefix + ".qkv.weight");
|
|
|
+ ggml_tensor* qkv_out = ggml_mul_mat(ctx, f_qkv_w, x_norm);
|
|
|
+
|
|
|
+ // Split Q, K, V
|
|
|
+ ggml_tensor* Q = ggml_view_2d(ctx, qkv_out, dim_, num_bands_, qkv_out->nb[1], 0);
|
|
|
+ ggml_tensor* K = ggml_view_2d(ctx, qkv_out, dim_, num_bands_, qkv_out->nb[1], dim_ * sizeof(float));
|
|
|
+ ggml_tensor* V = ggml_view_2d(ctx, qkv_out, dim_, num_bands_, qkv_out->nb[1], dim_ * 2 * sizeof(float));
|
|
|
+
|
|
|
+ // RoPE
|
|
|
+ Q = ggml_rope_ext(ctx, Q, pos_freq_exp, nullptr, dim_head_, GGML_ROPE_TYPE_NORMAL, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f);
|
|
|
+ K = ggml_rope_ext(ctx, K, pos_freq_exp, nullptr, dim_head_, GGML_ROPE_TYPE_NORMAL, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f);
|
|
|
+
|
|
|
+ // Attention
|
|
|
+ ggml_tensor* attn = ggml_flash_attn_ext(ctx, Q, K, V, nullptr, 1.0f / sqrtf(dim_head_), 0.0f, 0.0f);
|
|
|
+
|
|
|
+ // Output projection
|
|
|
+ ggml_tensor* f_out_w = GetWeight(freq_prefix + ".out.weight");
|
|
|
+ attn = ggml_mul_mat(ctx, f_out_w, attn);
|
|
|
+
|
|
|
+ x = ggml_add(ctx, x, attn);
|
|
|
+
|
|
|
+ // MLP
|
|
|
+ ggml_tensor* x_mlp = ggml_rms_norm(ctx, x, 1e-6f);
|
|
|
+ ggml_tensor* mlp_norm_w = GetWeight("blk." + std::to_string(layer) + ".freq_ff." + std::to_string(freq_layer) + ".norm.weight");
|
|
|
+ x_mlp = ggml_mul(ctx, x_mlp, mlp_norm_w);
|
|
|
+
|
|
|
+ ggml_tensor* mlp_in_w = GetWeight("blk." + std::to_string(layer) + ".freq_ff." + std::to_string(freq_layer) + ".in.weight");
|
|
|
+ x_mlp = ggml_mul_mat(ctx, mlp_in_w, x_mlp);
|
|
|
+ x_mlp = ggml_gelu(ctx, x_mlp);
|
|
|
+
|
|
|
+ ggml_tensor* mlp_out_w = GetWeight("blk." + std::to_string(layer) + ".freq_ff." + std::to_string(freq_layer) + ".out.weight");
|
|
|
+ x_mlp = ggml_mul_mat(ctx, mlp_out_w, x_mlp);
|
|
|
+
|
|
|
+ x = ggml_add(ctx, x, x_mlp);
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ return x;
|
|
|
+}
|