generate_test_data.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508
  1. """
  2. Generate minimal test data for MelBandRoformer.cpp verification.
  3. This script generates ONLY the essential tensors needed for C++ tests:
  4. - input_audio.npy (for test_inference)
  5. - output_audio.npy (for test_inference)
  6. - band_split_in.npy (for test_component_bandsplit)
  7. - after_band_split.npy (for test_component_bandsplit, test_component_layers)
  8. - before_mask_est.npy (for test_component_layers, test_component_mask)
  9. - mask_est0.npy (for test_component_mask)
  10. - chunk_in.npy (for test_chunking_logic)
  11. - chunk_out.npy (for test_chunking_logic)
  12. Requirements:
  13. This script requires the Music-Source-Separation-Training repository:
  14. https://github.com/ZFTurbo/Music-Source-Separation-Training
  15. Clone it first:
  16. git clone https://github.com/ZFTurbo/Music-Source-Separation-Training.git
  17. Usage:
  18. python generate_test_data.py --model-repo /path/to/Music-Source-Separation-Training \\
  19. --audio test.wav --checkpoint model.ckpt --output test_data
  20. """
  21. import argparse
  22. import sys
  23. from pathlib import Path
  24. import numpy as np
  25. import torch
  26. import soundfile as sf
  27. import yaml
  28. from ml_collections import ConfigDict
  29. from einops import rearrange, pack, unpack
  30. # Model imports are deferred until we know the model-repo path
  31. # Model imports are deferred until we know the model-repo path
  32. MelBandRoformer = None
  33. pack_one = None
  34. unpack_one = None
  35. # Inference utility
  36. inference_func = None
  37. MODEL_REPO_URL = "https://github.com/ZFTurbo/Music-Source-Separation-Training"
  38. class MockModel(torch.nn.Module):
  39. """Identity model for testing chunking logic."""
  40. def __init__(self):
  41. super().__init__()
  42. def forward(self, x):
  43. # x shape: [Batch, Channels, Time] or [Batch, Time]
  44. # Return same as input (Identity)
  45. return x
  46. def load_model_module(model_repo_path: Path):
  47. """Dynamically load the MelBandRoformer model from the specified repository."""
  48. global MelBandRoformer, pack_one, unpack_one, inference_func
  49. if not model_repo_path.exists():
  50. print("\n" + "=" * 70)
  51. print("ERROR: Model repository not found!")
  52. print("=" * 70)
  53. print(f"\nPath: {model_repo_path}")
  54. print("\nThis script requires the Music-Source-Separation-Training repository.")
  55. print("\nPlease clone it first:")
  56. print(f" git clone {MODEL_REPO_URL}")
  57. print(
  58. "\nThen run this script with --model-repo pointing to the cloned directory."
  59. )
  60. print("=" * 70)
  61. sys.exit(1)
  62. models_path = model_repo_path / "models"
  63. if not models_path.exists():
  64. print("\n" + "=" * 70)
  65. print("ERROR: Invalid repository structure!")
  66. print("=" * 70)
  67. print(f"\nThe 'models' directory was not found in: {model_repo_path}")
  68. print("=" * 70)
  69. sys.exit(1)
  70. # Add to path and import
  71. sys.path.insert(0, str(model_repo_path))
  72. # Mock loralib to allow importing model_utils without installing it
  73. from unittest.mock import MagicMock
  74. if "loralib" not in sys.modules:
  75. sys.modules["loralib"] = MagicMock()
  76. # Import from new structure (Music-Source-Separation-Training)
  77. try:
  78. from models.bs_roformer.mel_band_roformer import (
  79. MelBandRoformer as _MelBandRoformer,
  80. )
  81. from models.bs_roformer.mel_band_roformer import (
  82. pack_one as _pack_one,
  83. unpack_one as _unpack_one,
  84. )
  85. pack_one = _pack_one
  86. unpack_one = _unpack_one
  87. MelBandRoformer = _MelBandRoformer
  88. # Import demix from utils.model_utils
  89. from utils.model_utils import demix
  90. inference_func = demix
  91. print(f" Loaded model from: {model_repo_path}")
  92. return
  93. except ImportError as e:
  94. print("\n" + "=" * 70)
  95. print("ERROR: Failed to import model!")
  96. print("=" * 70)
  97. print(f"\nImport error: {e}")
  98. print(
  99. "\nPlease ensure the repository is complete and dependencies are installed."
  100. )
  101. sys.exit(1)
  102. def save_tensor(
  103. output_dir: Path, name: str, tensor, subdir: str = "activations"
  104. ) -> dict:
  105. """Save tensor to .npy file."""
  106. path = output_dir / subdir / f"{name}.npy"
  107. path.parent.mkdir(parents=True, exist_ok=True)
  108. if isinstance(tensor, torch.Tensor):
  109. tensor = tensor.detach().cpu()
  110. if tensor.dtype in [torch.int64, torch.int32, torch.bool]:
  111. tensor = tensor.float()
  112. tensor = tensor.numpy()
  113. if isinstance(tensor, np.ndarray) and tensor.dtype != np.float32:
  114. tensor = tensor.astype(np.float32)
  115. np.save(path, tensor)
  116. print(f" Saved {name}: shape={list(tensor.shape)}")
  117. return {"name": name, "shape": list(tensor.shape), "path": str(path)}
  118. def generate_chunking_data(output_dir: Path, config: ConfigDict):
  119. """Generate input/output data for verifying chunking logic."""
  120. print("\n[Chunking] Generating overlap-add debug data...")
  121. if inference_func is None:
  122. print(
  123. " Warning: Inference function not found, skipping chunking data generation."
  124. )
  125. return
  126. # Create Mock Model (Identity)
  127. model = MockModel()
  128. device = torch.device("cpu")
  129. # Create input: Ramp signal
  130. # Size > 2 chunks to test overlap logic
  131. # Use fixed values to match C++ test_chunking_logic.cpp (lines 76-77)
  132. chunk_size = 352800
  133. num_overlap = 2
  134. print(f" Chunk size: {chunk_size}, Overlap: {num_overlap}")
  135. total_len = chunk_size * 2 + 10000
  136. inputs = np.linspace(0, 1, total_len).astype(np.float32)
  137. # Make stereo [2, T]
  138. inputs = np.stack([inputs, inputs], axis=0)
  139. # Save input (C-order, transposed to [T, 2] for C++ ease if needed, but C++ load_npy handles it)
  140. save_tensor(output_dir, "chunk_in", inputs.T, subdir=".")
  141. # Run Inference
  142. mixture = torch.tensor(inputs, dtype=torch.float32)
  143. # demix(config, model, mix, device, model_type)
  144. # generic mode (not htdemucs) uses 'generic'
  145. # It returns dict {instr: waveform} or array
  146. res = inference_func(config, model, mixture, device, model_type="generic")
  147. if isinstance(res, dict):
  148. # Pick the first instrument
  149. first_key = list(res.keys())[0]
  150. output = res[first_key]
  151. else:
  152. output = res
  153. # Save output
  154. if isinstance(output, torch.Tensor):
  155. output = output.cpu().numpy()
  156. save_tensor(output_dir, "chunk_out", output.T, subdir=".")
  157. def generate_test_data(
  158. model_repo: str,
  159. audio_file: str,
  160. checkpoint: str,
  161. config_file: str,
  162. output_dir: str,
  163. audio_start: float = 2.0,
  164. audio_end: float = 5.0,
  165. ) -> int:
  166. """Generate test data for C++ verification."""
  167. # Load model module from specified repository
  168. model_repo_path = Path(model_repo)
  169. load_model_module(model_repo_path)
  170. output_path = Path(output_dir)
  171. output_path.mkdir(parents=True, exist_ok=True)
  172. print("=" * 70)
  173. print("MelBandRoformer Test Data Generator")
  174. print("=" * 70)
  175. # 1. Load config and model
  176. print(f"\n[1/4] Loading model from {checkpoint}")
  177. with open(config_file) as f:
  178. config = ConfigDict(yaml.load(f, Loader=yaml.FullLoader))
  179. model = MelBandRoformer(**dict(config.model))
  180. state_dict = torch.load(checkpoint, map_location="cpu")
  181. model.load_state_dict(state_dict)
  182. model.eval()
  183. print(f" Config: depth={config.model.depth}, dim={config.model.dim}")
  184. # 2. Load audio
  185. print(f"\n[2/4] Loading audio ({audio_start}s - {audio_end}s) from {audio_file}")
  186. audio, sr = sf.read(audio_file)
  187. start_sample = int(audio_start * sr)
  188. end_sample = int(audio_end * sr)
  189. audio_segment = audio[start_sample:end_sample]
  190. if len(audio_segment.shape) == 1:
  191. audio_segment = np.stack([audio_segment, audio_segment], axis=-1)
  192. # [batch, channels, samples]
  193. audio_tensor = torch.tensor(audio_segment.T, dtype=torch.float32).unsqueeze(0)
  194. print(f" Audio shape: {audio_tensor.shape}")
  195. # 3. Run instrumented forward pass
  196. print("\n[3/4] Running instrumented forward pass...")
  197. captured = {}
  198. with torch.no_grad():
  199. device = audio_tensor.device
  200. raw_audio = audio_tensor
  201. if raw_audio.ndim == 2:
  202. raw_audio = rearrange(raw_audio, "b t -> b 1 t")
  203. batch, channels, raw_audio_length = raw_audio.shape
  204. istft_length = raw_audio_length if model.match_input_audio_length else None
  205. # STFT
  206. raw_audio_packed, batch_audio_channel_packed_shape = pack_one(raw_audio, "* t")
  207. stft_window = model.stft_window_fn(device=device)
  208. stft_repr = torch.stft(
  209. raw_audio_packed,
  210. **model.stft_kwargs,
  211. window=stft_window,
  212. return_complex=True,
  213. )
  214. stft_repr = torch.view_as_real(stft_repr)
  215. stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, "* f t c")
  216. stft_repr = rearrange(stft_repr, "b s f t c -> b (f s) t c")
  217. # Frequency indexing
  218. batch_arange = torch.arange(batch, device=device)[..., None]
  219. x = stft_repr[batch_arange, model.freq_indices]
  220. x = rearrange(x, "b f t c -> b t (f c)")
  221. # ===== CAPTURE: BandSplit Input =====
  222. captured["band_split_in"] = x.clone()
  223. # BandSplit
  224. x = model.band_split(x)
  225. # ===== CAPTURE: After BandSplit (= Transformer Input) =====
  226. captured["after_band_split"] = x.clone()
  227. # Transformer Layers
  228. for layer_idx, (time_transformer, freq_transformer) in enumerate(model.layers):
  229. # Time Transformer
  230. x = rearrange(x, "b t f d -> b f t d")
  231. x, ps = pack([x], "* t d")
  232. x = time_transformer(x)
  233. (x,) = unpack(x, ps, "* t d")
  234. x = rearrange(x, "b f t d -> b t f d")
  235. # Freq Transformer
  236. x, ps = pack([x], "* f d")
  237. x = freq_transformer(x)
  238. (x,) = unpack(x, ps, "* f d")
  239. # ===== CAPTURE: Before Mask Estimator (= Transformer Output) =====
  240. captured["before_mask_est"] = x.clone()
  241. # Mask Estimator (just first one for testing)
  242. mask0 = model.mask_estimators[0](x)
  243. # ===== CAPTURE: Mask Estimator Output =====
  244. captured["mask_est0"] = mask0.clone()
  245. # Continue with full forward pass for output
  246. num_stems = len(model.mask_estimators)
  247. masks = torch.stack([fn(x) for fn in model.mask_estimators], dim=1)
  248. masks = rearrange(masks, "b n t (f c) -> b n f t c", c=2)
  249. stft_repr = rearrange(stft_repr, "b f t c -> b 1 f t c")
  250. stft_repr = torch.view_as_complex(stft_repr)
  251. masks = torch.view_as_complex(masks)
  252. masks = masks.type(stft_repr.dtype)
  253. from einops import repeat
  254. scatter_indices = repeat(
  255. model.freq_indices,
  256. "f -> b n f t",
  257. b=batch,
  258. n=num_stems,
  259. t=stft_repr.shape[-1],
  260. )
  261. stft_repr_expanded_stems = repeat(stft_repr, "b 1 ... -> b n ...", n=num_stems)
  262. masks_summed = torch.zeros_like(stft_repr_expanded_stems).scatter_add_(
  263. 2, scatter_indices, masks
  264. )
  265. denom = repeat(model.num_bands_per_freq, "f -> (f r) 1", r=channels)
  266. masks_averaged = masks_summed / denom.clamp(min=1e-8)
  267. stft_repr = stft_repr * masks_averaged
  268. # ISTFT
  269. stft_repr = rearrange(
  270. stft_repr, "b n (f s) t -> (b n s) f t", s=model.audio_channels
  271. )
  272. if getattr(model, "zero_dc", False):
  273. # Zero out DC component
  274. stft_repr = stft_repr.clone()
  275. stft_repr[:, 0, :] = 0.0
  276. recon_audio = torch.istft(
  277. stft_repr,
  278. **model.stft_kwargs,
  279. window=stft_window,
  280. return_complex=False,
  281. length=istft_length,
  282. )
  283. recon_audio = rearrange(
  284. recon_audio,
  285. "(b n s) t -> b n s t",
  286. b=batch,
  287. s=model.audio_channels,
  288. n=num_stems,
  289. )
  290. if num_stems == 1:
  291. recon_audio = rearrange(recon_audio, "b 1 s t -> b s t")
  292. captured["output_audio"] = recon_audio.clone()
  293. else:
  294. # Capture Stem 0 for verification
  295. captured["output_audio"] = recon_audio[:, 0, :, :].clone()
  296. # Capture Stem 0 for verification
  297. captured["output_audio"] = recon_audio[:, 0, :, :].clone()
  298. # 4. Generate Chunking Debug Data
  299. generate_chunking_data(output_path, config)
  300. # 5. Save tensors
  301. print(f"\n[4/5] Saving test data to {output_dir}")
  302. # Input audio
  303. save_tensor(output_path, "input_audio", audio_tensor)
  304. # Captured tensors
  305. for name, tensor in captured.items():
  306. save_tensor(output_path, name, tensor)
  307. # Verify outputs match normal forward pass
  308. print("\n[Verification] Checking output matches model.forward()...")
  309. with torch.no_grad():
  310. baseline = model(audio_tensor)
  311. if hasattr(model, "num_stems") and model.num_stems > 1:
  312. baseline = baseline[:, 0, :, :]
  313. diff = (baseline - captured["output_audio"]).abs()
  314. max_diff = diff.max().item()
  315. if max_diff > 1e-6:
  316. print(f" ✗ FAILED: max_diff = {max_diff:.2e}")
  317. return 1
  318. else:
  319. print(f" ✓ PASSED: max_diff = {max_diff:.2e}")
  320. print("\n" + "=" * 70)
  321. print("Test data generation complete!")
  322. print(f" Output: {output_dir}/activations/")
  323. print(f" Files: {len(captured) + 1} tensors")
  324. print("=" * 70)
  325. return 0
  326. def main():
  327. parser = argparse.ArgumentParser(
  328. description="Generate test data for MelBandRoformer.cpp",
  329. formatter_class=argparse.RawDescriptionHelpFormatter,
  330. epilog=f"""
  331. Requirements:
  332. This script requires the original Mel-Band-Roformer-Vocal-Model repository.
  333. Clone it first:
  334. git clone {MODEL_REPO_URL}
  335. Then specify the path with --model-repo.
  336. Example:
  337. python generate_test_data.py \\
  338. --model-repo /path/to/Mel-Band-Roformer-Vocal-Model \\
  339. --audio test.wav \\
  340. --checkpoint model.ckpt \\
  341. --output test_data
  342. """,
  343. )
  344. parser.add_argument(
  345. "--model-repo",
  346. required=True,
  347. help=f"Path to Mel-Band-Roformer-Vocal-Model repository (clone from {MODEL_REPO_URL})",
  348. )
  349. parser.add_argument("--audio", required=True, help="Input audio file (WAV)")
  350. parser.add_argument(
  351. "--checkpoint", required=True, help="Model checkpoint file (.ckpt)"
  352. )
  353. parser.add_argument(
  354. "--config",
  355. help="Model config YAML file (default: <model-repo>/configs/config_vocals_mel_band_roformer.yaml)",
  356. )
  357. parser.add_argument(
  358. "--output", default="test_data", help="Output directory for test data"
  359. )
  360. parser.add_argument(
  361. "--start",
  362. type=float,
  363. default=2.0,
  364. help="Audio start time in seconds (default: 2.0)",
  365. )
  366. parser.add_argument(
  367. "--end",
  368. type=float,
  369. default=5.0,
  370. help="Audio end time in seconds (default: 5.0)",
  371. )
  372. args = parser.parse_args()
  373. # Resolve paths
  374. model_repo_path = Path(args.model_repo).resolve()
  375. audio_path = Path(args.audio).resolve()
  376. checkpoint_path = Path(args.checkpoint).resolve()
  377. output_path = Path(args.output).resolve()
  378. # Config defaults to model-repo/configs/...
  379. if args.config:
  380. config_path = Path(args.config).resolve()
  381. else:
  382. config_path = (
  383. model_repo_path / "configs" / "config_vocals_mel_band_roformer.yaml"
  384. )
  385. # Validate paths
  386. if not audio_path.exists():
  387. print(f"Error: Audio file not found: {audio_path}")
  388. return 1
  389. if not checkpoint_path.exists():
  390. print(f"Error: Checkpoint not found: {checkpoint_path}")
  391. return 1
  392. if not config_path.exists():
  393. print(f"Error: Config not found: {config_path}")
  394. return 1
  395. return generate_test_data(
  396. str(model_repo_path),
  397. str(audio_path),
  398. str(checkpoint_path),
  399. str(config_path),
  400. str(output_path),
  401. args.start,
  402. args.end,
  403. )
  404. if __name__ == "__main__":
  405. sys.exit(main())