main.cpp 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185
  #include "bs_roformer/inference.h"
  #include "bs_roformer/audio.h"
  #include <chrono>
  #include <cstddef>
  #include <cstdlib>
  #include <exception>
  #include <iostream>
  #include <stdexcept>
  #include <string>
  #include <vector>
  /// Writes the command-line synopsis and the list of supported options to
  /// standard error.
  ///
  /// @param program_name Name shown in the "Usage:" line (callers pass argv[0]).
  void print_usage(const char* program_name) {
      // One entry per output line that follows the synopsis and its blank line.
      static const char* const kOptionLines[] = {
          "Options:",
          " --chunk-size <N> Chunk size in samples (default: from model, fallback 352800)",
          " --overlap <N> Number of overlaps for crossfade (default: from model, fallback 2)",
          " --help, -h Show this help message",
      };

      std::ostream& err = std::cerr;
      err << "Usage: " << program_name
          << " <model.gguf> <input.wav> <output.wav> [options]" << '\n'
          << '\n';
      for (const char* option_line : kOptionLines) {
          err << option_line << '\n';
      }
      err.flush();
  }
  15. int main(int argc, char* argv[]) {
  16. // Default values (will be overridden by model defaults if not explicitly set)
  17. int chunk_size = -1; // -1 means use model default
  18. int num_overlap = -1; // -1 means use model default
  19. bool chunk_size_set = false;
  20. bool num_overlap_set = false;
  21. // Check for help flag first
  22. for (int i = 1; i < argc; ++i) {
  23. std::string arg = argv[i];
  24. if (arg == "--help" || arg == "-h") {
  25. print_usage(argv[0]);
  26. return 0;
  27. }
  28. }
  29. if (argc < 4) {
  30. print_usage(argv[0]);
  31. return 1;
  32. }
  33. std::string model_path = argv[1];
  34. std::string input_path = argv[2];
  35. std::string output_path = argv[3];
  36. // Parse optional arguments
  37. for (int i = 4; i < argc; ++i) {
  38. std::string arg = argv[i];
  39. if (arg == "--chunk-size" && i + 1 < argc) {
  40. try {
  41. chunk_size = std::stoi(argv[++i]);
  42. if (chunk_size <= 0) {
  43. std::cerr << "Error: chunk-size must be a positive integer" << std::endl;
  44. return 1;
  45. }
  46. chunk_size_set = true;
  47. } catch (...) {
  48. std::cerr << "Error: invalid chunk-size" << std::endl;
  49. return 1;
  50. }
  51. } else if (arg == "--overlap" && i + 1 < argc) {
  52. try {
  53. num_overlap = std::stoi(argv[++i]);
  54. if (num_overlap < 1) {
  55. std::cerr << "Error: overlap must be at least 1" << std::endl;
  56. return 1;
  57. }
  58. num_overlap_set = true;
  59. } catch (...) {
  60. std::cerr << "Error: invalid overlap" << std::endl;
  61. return 1;
  62. }
  63. } else {
  64. std::cerr << "Unknown option: " << arg << std::endl;
  65. print_usage(argv[0]);
  66. return 1;
  67. }
  68. }
  69. try {
  70. std::cout << "Initializing BSRoformer..." << std::endl;
  71. auto start_time = std::chrono::high_resolution_clock::now();
  72. Inference engine(model_path);
  73. // Use model defaults if not explicitly set by user
  74. if (!chunk_size_set) {
  75. chunk_size = engine.GetDefaultChunkSize();
  76. }
  77. if (!num_overlap_set) {
  78. num_overlap = engine.GetDefaultNumOverlap();
  79. }
  80. std::cout << "Loading audio: " << input_path << std::endl;
  81. AudioBuffer input_audio = AudioFile::Load(input_path);
  82. std::cout << "Audio loaded: " << input_audio.samples << " samples, "
  83. << input_audio.channels << " channels, "
  84. << input_audio.sampleRate << " Hz" << std::endl;
  85. // 1. Check Sample Rate
  86. int required_sr = engine.GetSampleRate();
  87. std::cout << "Model expects sample rate: " << required_sr << " Hz" << std::endl;
  88. if (input_audio.sampleRate != required_sr) {
  89. throw std::runtime_error("Input audio sample rate must be " + std::to_string(required_sr) +
  90. " Hz. Current: " + std::to_string(input_audio.sampleRate));
  91. }
  92. // 2. Check Channels & Auto-Expand Mono
  93. if (input_audio.channels == 1) {
  94. std::cout << "[Info] Input is Mono. Expanding to Stereo..." << std::endl;
  95. std::vector<float> stereo_data(input_audio.samples * 2);
  96. for(size_t i=0; i<input_audio.samples; ++i) {
  97. stereo_data[i*2 + 0] = input_audio.data[i];
  98. stereo_data[i*2 + 1] = input_audio.data[i];
  99. }
  100. input_audio.data = std::move(stereo_data);
  101. input_audio.channels = 2;
  102. input_audio.samples *= 2;
  103. } else if (input_audio.channels != 2) {
  104. // We can either reject or try to process first 2 channels?
  105. // Ideally reject to be safer, or warn.
  106. throw std::runtime_error("Input audio must be Stereo (2 channels) or Mono (1 channel). Current: " + std::to_string(input_audio.channels));
  107. }
  108. std::cout << "Processing with chunk_size=" << chunk_size
  109. << ", overlap=" << num_overlap << std::endl;
  110. auto process_start = std::chrono::high_resolution_clock::now();
  111. // Progress Bar Callback
  112. auto progress_callback = [](float progress) {
  113. int barWidth = 50;
  114. std::cout << "[";
  115. int pos = barWidth * progress;
  116. for (int i = 0; i < barWidth; ++i) {
  117. if (i < pos) std::cout << "=";
  118. else if (i == pos) std::cout << ">";
  119. else std::cout << " ";
  120. }
  121. std::cout << "] " << int(progress * 100.0) << " %\r";
  122. std::cout.flush();
  123. };
  124. std::vector<std::vector<float>> output_stems = engine.Process(input_audio.data, chunk_size, num_overlap, progress_callback);
  125. // Clear progress line
  126. std::cout << std::string(70, ' ') << "\r";
  127. auto process_end = std::chrono::high_resolution_clock::now();
  128. std::chrono::duration<double> diff = process_end - process_start;
  129. std::cout << "Processed in " << diff.count() << " seconds." << std::endl;
  130. int num_stems = output_stems.size();
  131. std::cout << "Model returned " << num_stems << " stems." << std::endl;
  132. for (int i = 0; i < num_stems; ++i) {
  133. // Prepare output filename
  134. std::string current_output_path = output_path;
  135. if (num_stems > 1) {
  136. // Insert _stem_i before extension
  137. size_t dot_pos = output_path.find_last_of(".");
  138. if (dot_pos != std::string::npos) {
  139. current_output_path = output_path.substr(0, dot_pos) + "_stem_" + std::to_string(i) + output_path.substr(dot_pos);
  140. } else {
  141. current_output_path = output_path + "_stem_" + std::to_string(i);
  142. }
  143. }
  144. // Prepare AudioBuffer
  145. AudioBuffer output_audio_buf;
  146. output_audio_buf.data = std::move(output_stems[i]); // Move to avoid copy
  147. output_audio_buf.channels = 2; // Output is always stereo
  148. output_audio_buf.sampleRate = required_sr;
  149. output_audio_buf.samples = output_audio_buf.data.size();
  150. std::cout << "Saving output stem " << i << ": " << current_output_path << std::endl;
  151. AudioFile::Save(current_output_path, output_audio_buf);
  152. }
  153. std::cout << "Done!" << std::endl;
  154. } catch (const std::exception& e) {
  155. std::cerr << "Error: " << e.what() << std::endl;
  156. return 1;
  157. }
  158. return 0;
  159. }