Browse Source

feat(cli): enhance argument parsing with proper error handling
fix(cli): use dynamic sample rate from model instead of hardcoded value
fix(audio): correct stereo samples calculation for mono-to-stereo conversion
refactor(cli): optimize memory usage in output buffer creation
add sample rate support from GGUF metadata

沉默の金 5 tháng trước cách đây
mục cha
commit
5e006bccbc
5 tập tin đã thay đổi với 39 bổ sung và 14 xóa
  1. 27 12
      cli/main.cpp
  2. 1 0
      include/mel_band_roformer/inference.h
  3. 4 0
      src/inference.cpp
  4. 3 0
      src/model.cpp
  5. 4 2
      src/model.h

+ 27 - 12
cli/main.cpp

@@ -43,19 +43,29 @@ int main(int argc, char* argv[]) {
     for (int i = 4; i < argc; ++i) {
         std::string arg = argv[i];
         if (arg == "--chunk-size" && i + 1 < argc) {
-            chunk_size = std::atoi(argv[++i]);
-            if (chunk_size <= 0) {
-                std::cerr << "Error: chunk-size must be a positive integer" << std::endl;
+            try {
+                chunk_size = std::stoi(argv[++i]);
+                if (chunk_size <= 0) {
+                     std::cerr << "Error: chunk-size must be a positive integer" << std::endl;
+                     return 1;
+                }
+                chunk_size_set = true;
+            } catch (...) {
+                std::cerr << "Error: invalid chunk-size" << std::endl;
                 return 1;
             }
-            chunk_size_set = true;
         } else if (arg == "--overlap" && i + 1 < argc) {
-            num_overlap = std::atoi(argv[++i]);
-            if (num_overlap < 1) {
-                std::cerr << "Error: overlap must be at least 1" << std::endl;
+            try {
+                num_overlap = std::stoi(argv[++i]);
+                if (num_overlap < 1) {
+                    std::cerr << "Error: overlap must be at least 1" << std::endl;
+                    return 1;
+                }
+                num_overlap_set = true;
+             } catch (...) {
+                std::cerr << "Error: invalid overlap" << std::endl;
                 return 1;
             }
-            num_overlap_set = true;
         } else {
             std::cerr << "Unknown option: " << arg << std::endl;
             print_usage(argv[0]);
@@ -85,8 +95,12 @@ int main(int argc, char* argv[]) {
                   << input_audio.sampleRate << " Hz" << std::endl;
 
         // 1. Check Sample Rate
-        if (input_audio.sampleRate != 44100) {
-            throw std::runtime_error("Input audio sample rate must be 44100 Hz. Current: " + std::to_string(input_audio.sampleRate));
+        int required_sr = engine.GetSampleRate();
+        std::cout << "Model expects sample rate: " << required_sr << " Hz" << std::endl;
+
+        if (input_audio.sampleRate != required_sr) {
+            throw std::runtime_error("Input audio sample rate must be " + std::to_string(required_sr) + 
+                                     " Hz. Current: " + std::to_string(input_audio.sampleRate));
         }
 
         // 2. Check Channels & Auto-Expand Mono
@@ -99,6 +113,7 @@ int main(int argc, char* argv[]) {
              }
              input_audio.data = std::move(stereo_data);
              input_audio.channels = 2;
+             input_audio.samples *= 2;
         } else if (input_audio.channels != 2) {
              // We can either reject or try to process first 2 channels? 
              // Ideally reject to be safer, or warn.
@@ -150,9 +165,9 @@ int main(int argc, char* argv[]) {
 
             // Prepare AudioBuffer
             AudioBuffer output_audio_buf;
-            output_audio_buf.data = output_stems[i]; // Copy? AudioBuffer uses vector, simple move/copy
+            output_audio_buf.data = std::move(output_stems[i]); // Move to avoid copy
             output_audio_buf.channels = 2; // Output is always stereo
-            output_audio_buf.sampleRate = 44100;
+            output_audio_buf.sampleRate = required_sr;
             output_audio_buf.samples = output_stems[i].size();
             
             std::cout << "Saving output stem " << i << ": " << current_output_path << std::endl;

+ 1 - 0
include/mel_band_roformer/inference.h

@@ -30,6 +30,7 @@ public:
     // Get model's recommended inference defaults
     int GetDefaultChunkSize() const;
     int GetDefaultNumOverlap() const;
+    int GetSampleRate() const;
 
     // Static helper for Overlap-Add logic (matches Python exactly)
     // model_func: input [samples], output [stems][samples] (interleaved stereo)

+ 4 - 0
src/inference.cpp

@@ -46,6 +46,10 @@ int Inference::GetDefaultNumOverlap() const {
     return model_->GetDefaultNumOverlap();
 }
 
+int Inference::GetSampleRate() const {
+    return model_->GetSampleRate();
+}
+
 Inference::~Inference() {
     if (allocr_) ggml_gallocr_free(allocr_);
     if (ctx_) ggml_free(ctx_);

+ 3 - 0
src/model.cpp

@@ -82,6 +82,9 @@ void MelBandRoformer::LoadWeights(const std::string& path) {
 
     kv_idx = gguf_find_key(ctx_gguf, "mel_band_roformer.mask_estimator_depth");
     if (kv_idx >= 0) mask_estimator_depth_ = (int)gguf_get_val_u32(ctx_gguf, kv_idx);
+
+    kv_idx = gguf_find_key(ctx_gguf, "mel_band_roformer.sample_rate");
+    if (kv_idx >= 0) sample_rate_ = (int)gguf_get_val_u32(ctx_gguf, kv_idx);
     
     // Inference defaults (optional, fallback to hardcoded values)
     kv_idx = gguf_find_key(ctx_gguf, "mel_band_roformer.default_chunk_size");

+ 4 - 2
src/model.h

@@ -50,7 +50,8 @@ public:
     int GetNumStems() const { return num_stems_; }
     bool GetSkipConnection() const { return skip_connection_; }
     bool GetSTFTNormalized() const { return stft_normalized_; }
-    bool GetZeroDC() const { return zero_dc_; }
+    int GetZeroDC() const { return zero_dc_; }
+    int GetSampleRate() const { return sample_rate_; }
     
     // Inference defaults (from GGUF, can be overridden at runtime)
     int GetDefaultChunkSize() const { return default_chunk_size_; }
@@ -142,7 +143,8 @@ private:
     bool stft_normalized_ = false;
     bool zero_dc_ = false;
     int mask_estimator_depth_ = 1;
-    
+    int sample_rate_ = 44100;
+
     // Inference defaults
     int default_chunk_size_ = 352800;
     int default_num_overlap_ = 2;