// main.cpp
  #include "mel_band_roformer/inference.h"
  #include "mel_band_roformer/audio.h"

  #include <cerrno>
  #include <chrono>
  #include <cstddef>
  #include <cstdlib>
  #include <iostream>
  #include <limits>
  #include <stdexcept>
  #include <string>
  #include <utility>
  #include <vector>
  /// Writes the command-line synopsis followed by the option list to standard error.
  ///
  /// @param program_name Name shown after "Usage: " (callers pass argv[0]).
  void print_usage(const char* program_name) {
      // One entry per supported option, emitted in this order.
      static const char* const option_lines[] = {
          " --chunk-size <N> Chunk size in samples (default: from model, fallback 352800)",
          " --overlap <N> Number of overlaps for crossfade (default: from model, fallback 2)",
          " --help, -h Show this help message",
      };

      std::ostream& err = std::cerr;
      err << "Usage: " << program_name << " <model.gguf> <input.wav> <output.wav> [options]" << std::endl;
      err << std::endl;
      err << "Options:" << std::endl;
      for (const char* line : option_lines) {
          err << line << std::endl;
      }
  }
  15. int main(int argc, char* argv[]) {
  16. // Default values (will be overridden by model defaults if not explicitly set)
  17. int chunk_size = -1; // -1 means use model default
  18. int num_overlap = -1; // -1 means use model default
  19. bool chunk_size_set = false;
  20. bool num_overlap_set = false;
  21. // Check for help flag first
  22. for (int i = 1; i < argc; ++i) {
  23. std::string arg = argv[i];
  24. if (arg == "--help" || arg == "-h") {
  25. print_usage(argv[0]);
  26. return 0;
  27. }
  28. }
  29. if (argc < 4) {
  30. print_usage(argv[0]);
  31. return 1;
  32. }
  33. std::string model_path = argv[1];
  34. std::string input_path = argv[2];
  35. std::string output_path = argv[3];
  36. // Parse optional arguments
  37. for (int i = 4; i < argc; ++i) {
  38. std::string arg = argv[i];
  39. if (arg == "--chunk-size" && i + 1 < argc) {
  40. chunk_size = std::atoi(argv[++i]);
  41. if (chunk_size <= 0) {
  42. std::cerr << "Error: chunk-size must be a positive integer" << std::endl;
  43. return 1;
  44. }
  45. chunk_size_set = true;
  46. } else if (arg == "--overlap" && i + 1 < argc) {
  47. num_overlap = std::atoi(argv[++i]);
  48. if (num_overlap < 1) {
  49. std::cerr << "Error: overlap must be at least 1" << std::endl;
  50. return 1;
  51. }
  52. num_overlap_set = true;
  53. } else {
  54. std::cerr << "Unknown option: " << arg << std::endl;
  55. print_usage(argv[0]);
  56. return 1;
  57. }
  58. }
  59. try {
  60. std::cout << "Initializing MelBandRoformer..." << std::endl;
  61. auto start_time = std::chrono::high_resolution_clock::now();
  62. Inference engine(model_path);
  63. // Use model defaults if not explicitly set by user
  64. if (!chunk_size_set) {
  65. chunk_size = engine.GetDefaultChunkSize();
  66. }
  67. if (!num_overlap_set) {
  68. num_overlap = engine.GetDefaultNumOverlap();
  69. }
  70. std::cout << "Loading audio: " << input_path << std::endl;
  71. AudioBuffer input_audio = AudioFile::Load(input_path);
  72. std::cout << "Audio loaded: " << input_audio.samples << " samples, "
  73. << input_audio.channels << " channels, "
  74. << input_audio.sampleRate << " Hz" << std::endl;
  75. // 1. Check Sample Rate
  76. if (input_audio.sampleRate != 44100) {
  77. throw std::runtime_error("Input audio sample rate must be 44100 Hz. Current: " + std::to_string(input_audio.sampleRate));
  78. }
  79. // 2. Check Channels & Auto-Expand Mono
  80. if (input_audio.channels == 1) {
  81. std::cout << "[Info] Input is Mono. Expanding to Stereo..." << std::endl;
  82. std::vector<float> stereo_data(input_audio.samples * 2);
  83. for(size_t i=0; i<input_audio.samples; ++i) {
  84. stereo_data[i*2 + 0] = input_audio.data[i];
  85. stereo_data[i*2 + 1] = input_audio.data[i];
  86. }
  87. input_audio.data = std::move(stereo_data);
  88. input_audio.channels = 2;
  89. } else if (input_audio.channels != 2) {
  90. // We can either reject or try to process first 2 channels?
  91. // Ideally reject to be safer, or warn.
  92. throw std::runtime_error("Input audio must be Stereo (2 channels) or Mono (1 channel). Current: " + std::to_string(input_audio.channels));
  93. }
  94. std::cout << "Processing with chunk_size=" << chunk_size
  95. << ", overlap=" << num_overlap << std::endl;
  96. auto process_start = std::chrono::high_resolution_clock::now();
  97. // Progress Bar Callback
  98. auto progress_callback = [](float progress) {
  99. int barWidth = 50;
  100. std::cout << "[";
  101. int pos = barWidth * progress;
  102. for (int i = 0; i < barWidth; ++i) {
  103. if (i < pos) std::cout << "=";
  104. else if (i == pos) std::cout << ">";
  105. else std::cout << " ";
  106. }
  107. std::cout << "] " << int(progress * 100.0) << " %\r";
  108. std::cout.flush();
  109. };
  110. std::vector<std::vector<float>> output_stems = engine.Process(input_audio.data, chunk_size, num_overlap, progress_callback);
  111. // Clear progress line
  112. std::cout << std::string(70, ' ') << "\r";
  113. auto process_end = std::chrono::high_resolution_clock::now();
  114. std::chrono::duration<double> diff = process_end - process_start;
  115. std::cout << "Processed in " << diff.count() << " seconds." << std::endl;
  116. int num_stems = output_stems.size();
  117. std::cout << "Model returned " << num_stems << " stems." << std::endl;
  118. for (int i = 0; i < num_stems; ++i) {
  119. // Prepare output filename
  120. std::string current_output_path = output_path;
  121. if (num_stems > 1) {
  122. // Insert _stem_i before extension
  123. size_t dot_pos = output_path.find_last_of(".");
  124. if (dot_pos != std::string::npos) {
  125. current_output_path = output_path.substr(0, dot_pos) + "_stem_" + std::to_string(i) + output_path.substr(dot_pos);
  126. } else {
  127. current_output_path = output_path + "_stem_" + std::to_string(i);
  128. }
  129. }
  130. // Prepare AudioBuffer
  131. AudioBuffer output_audio_buf;
  132. output_audio_buf.data = output_stems[i]; // Copy? AudioBuffer uses vector, simple move/copy
  133. output_audio_buf.channels = 2; // Output is always stereo
  134. output_audio_buf.sampleRate = 44100;
  135. output_audio_buf.samples = output_stems[i].size();
  136. std::cout << "Saving output stem " << i << ": " << current_output_path << std::endl;
  137. AudioFile::Save(current_output_path, output_audio_buf);
  138. }
  139. std::cout << "Done!" << std::endl;
  140. } catch (const std::exception& e) {
  141. std::cerr << "Error: " << e.what() << std::endl;
  142. return 1;
  143. }
  144. return 0;
  145. }