main.cpp 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165
  1. #include "bs_roformer/inference.h"
  2. #include "bs_roformer/audio.h"
  3. #include <iostream>
  4. #include <string>
  5. #include <chrono>
  6. #include <cstdlib>
  7. void print_usage(const char* program_name) {
  8. std::cerr << "Usage: " << program_name << " <model.gguf> <input_audio> <output.wav> [options]" << std::endl;
  9. std::cerr << std::endl;
  10. std::cerr << "Input audio can be any common format (WAV, MP3, FLAC, OGG, etc.)" << std::endl;
  11. std::cerr << "Audio is automatically resampled to 44100 Hz if needed." << std::endl;
  12. std::cerr << std::endl;
  13. std::cerr << "Options:" << std::endl;
  14. std::cerr << " --chunk-size <N> Chunk size in samples (default: from model, fallback 352800)" << std::endl;
  15. std::cerr << " --overlap <N> Number of overlaps for crossfade (default: from model, fallback 2)" << std::endl;
  16. std::cerr << " --help, -h Show this help message" << std::endl;
  17. }
  18. int main(int argc, char* argv[]) {
  19. // Default values (will be overridden by model defaults if not explicitly set)
  20. int chunk_size = -1; // -1 means use model default
  21. int num_overlap = -1; // -1 means use model default
  22. bool chunk_size_set = false;
  23. bool num_overlap_set = false;
  24. // Check for help flag first
  25. for (int i = 1; i < argc; ++i) {
  26. std::string arg = argv[i];
  27. if (arg == "--help" || arg == "-h") {
  28. print_usage(argv[0]);
  29. return 0;
  30. }
  31. }
  32. if (argc < 4) {
  33. print_usage(argv[0]);
  34. return 1;
  35. }
  36. std::string model_path = argv[1];
  37. std::string input_path = argv[2];
  38. std::string output_path = argv[3];
  39. // Parse optional arguments
  40. for (int i = 4; i < argc; ++i) {
  41. std::string arg = argv[i];
  42. if (arg == "--chunk-size" && i + 1 < argc) {
  43. try {
  44. chunk_size = std::stoi(argv[++i]);
  45. if (chunk_size <= 0) {
  46. std::cerr << "Error: chunk-size must be a positive integer" << std::endl;
  47. return 1;
  48. }
  49. chunk_size_set = true;
  50. } catch (...) {
  51. std::cerr << "Error: invalid chunk-size" << std::endl;
  52. return 1;
  53. }
  54. } else if (arg == "--overlap" && i + 1 < argc) {
  55. try {
  56. num_overlap = std::stoi(argv[++i]);
  57. if (num_overlap < 1) {
  58. std::cerr << "Error: overlap must be at least 1" << std::endl;
  59. return 1;
  60. }
  61. num_overlap_set = true;
  62. } catch (...) {
  63. std::cerr << "Error: invalid overlap" << std::endl;
  64. return 1;
  65. }
  66. } else {
  67. std::cerr << "Unknown option: " << arg << std::endl;
  68. print_usage(argv[0]);
  69. return 1;
  70. }
  71. }
  72. try {
  73. std::cout << "Initializing BSRoformer..." << std::endl;
  74. auto start_time = std::chrono::high_resolution_clock::now();
  75. Inference engine(model_path);
  76. // Use model defaults if not explicitly set by user
  77. if (!chunk_size_set) {
  78. chunk_size = engine.GetDefaultChunkSize();
  79. }
  80. if (!num_overlap_set) {
  81. num_overlap = engine.GetDefaultNumOverlap();
  82. }
  83. std::cout << "Loading audio: " << input_path << std::endl;
  84. AudioBuffer input_audio = AudioFile::Load(input_path);
  85. std::cout << "Audio loaded: " << input_audio.samples << " samples, "
  86. << input_audio.channels << " channels, "
  87. << input_audio.sampleRate << " Hz" << std::endl;
  88. // AudioFile::Load automatically resamples to 44100 Hz and converts to stereo
  89. // No need for manual sample rate check or mono expansion
  90. std::cout << "Processing with chunk_size=" << chunk_size
  91. << ", overlap=" << num_overlap << std::endl;
  92. auto process_start = std::chrono::high_resolution_clock::now();
  93. // Progress Bar Callback
  94. auto progress_callback = [](float progress) {
  95. int barWidth = 50;
  96. std::cout << "[";
  97. int pos = barWidth * progress;
  98. for (int i = 0; i < barWidth; ++i) {
  99. if (i < pos) std::cout << "=";
  100. else if (i == pos) std::cout << ">";
  101. else std::cout << " ";
  102. }
  103. std::cout << "] " << int(progress * 100.0) << " %\r";
  104. std::cout.flush();
  105. };
  106. std::vector<std::vector<float>> output_stems = engine.Process(input_audio.data, chunk_size, num_overlap, progress_callback);
  107. // Clear progress line
  108. std::cout << std::string(70, ' ') << "\r";
  109. auto process_end = std::chrono::high_resolution_clock::now();
  110. std::chrono::duration<double> diff = process_end - process_start;
  111. std::cout << "Processed in " << diff.count() << " seconds." << std::endl;
  112. int num_stems = output_stems.size();
  113. std::cout << "Model returned " << num_stems << " stems." << std::endl;
  114. for (int i = 0; i < num_stems; ++i) {
  115. // Prepare output filename
  116. std::string current_output_path = output_path;
  117. if (num_stems > 1) {
  118. // Insert _stem_i before extension
  119. size_t dot_pos = output_path.find_last_of(".");
  120. if (dot_pos != std::string::npos) {
  121. current_output_path = output_path.substr(0, dot_pos) + "_stem_" + std::to_string(i) + output_path.substr(dot_pos);
  122. } else {
  123. current_output_path = output_path + "_stem_" + std::to_string(i);
  124. }
  125. }
  126. // Prepare AudioBuffer
  127. AudioBuffer output_audio_buf;
  128. output_audio_buf.data = std::move(output_stems[i]); // Move to avoid copy
  129. output_audio_buf.channels = 2; // Output is always stereo
  130. output_audio_buf.sampleRate = 44100;
  131. output_audio_buf.samples = output_audio_buf.data.size();
  132. std::cout << "Saving output stem " << i << ": " << current_output_path << std::endl;
  133. AudioFile::Save(current_output_path, output_audio_buf);
  134. }
  135. std::cout << "Done!" << std::endl;
  136. } catch (const std::exception& e) {
  137. std::cerr << "Error: " << e.what() << std::endl;
  138. return 1;
  139. }
  140. return 0;
  141. }