generate_test_data.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572
  1. """
  2. Generate minimal test data for MelBandRoformer.cpp verification.
  3. This script generates ONLY the essential tensors needed for C++ tests:
  4. - input_audio.npy (for test_inference)
  5. - output_audio.npy (for test_inference)
  6. - band_split_in.npy (for test_component_bandsplit)
  7. - after_band_split.npy (for test_component_bandsplit, test_component_layers)
  8. - before_mask_est.npy (for test_component_layers, test_component_mask)
  9. - mask_est0.npy (for test_component_mask)
  10. - chunk_in.npy (for test_chunking_logic)
  11. - chunk_out.npy (for test_chunking_logic)
  12. Requirements:
  13. This script requires the Music-Source-Separation-Training repository:
  14. https://github.com/ZFTurbo/Music-Source-Separation-Training
  15. Clone it first:
  16. git clone https://github.com/ZFTurbo/Music-Source-Separation-Training.git
  17. Usage:
  18. python generate_test_data.py --model-repo /path/to/Music-Source-Separation-Training \\
  19. --audio test.wav --checkpoint model.ckpt --output test_data
  20. """
  21. import argparse
  22. import sys
  23. from pathlib import Path
  24. import numpy as np
  25. import torch
  26. import soundfile as sf
  27. import yaml
  28. from ml_collections import ConfigDict
  29. from einops import rearrange, pack, unpack
  30. # Model imports are deferred until we know the model-repo path
  31. # Model imports are deferred until we know the model-repo path
  32. MelBandRoformer = None
  33. BSRoformer = None
  34. pack_one = None
  35. unpack_one = None
  36. # Inference utility
  37. inference_func = None
  38. MODEL_REPO_URL = "https://github.com/ZFTurbo/Music-Source-Separation-Training"
  39. class MockModel(torch.nn.Module):
  40. """Identity model for testing chunking logic."""
  41. def __init__(self):
  42. super().__init__()
  43. def forward(self, x):
  44. # x shape: [Batch, Channels, Time] or [Batch, Time]
  45. # Return same as input (Identity)
  46. return x
  47. def load_model_module(model_repo_path: Path):
  48. """Dynamically load the MelBandRoformer model from the specified repository."""
  49. global MelBandRoformer, BSRoformer, pack_one, unpack_one, inference_func
  50. if not model_repo_path.exists():
  51. print("\n" + "=" * 70)
  52. print("ERROR: Model repository not found!")
  53. print("=" * 70)
  54. print(f"\nPath: {model_repo_path}")
  55. print("\nThis script requires the Music-Source-Separation-Training repository.")
  56. print("\nPlease clone it first:")
  57. print(f" git clone {MODEL_REPO_URL}")
  58. print(
  59. "\nThen run this script with --model-repo pointing to the cloned directory."
  60. )
  61. print("=" * 70)
  62. sys.exit(1)
  63. models_path = model_repo_path / "models"
  64. if not models_path.exists():
  65. print("\n" + "=" * 70)
  66. print("ERROR: Invalid repository structure!")
  67. print("=" * 70)
  68. print(f"\nThe 'models' directory was not found in: {model_repo_path}")
  69. print("=" * 70)
  70. sys.exit(1)
  71. # Add to path and import
  72. sys.path.insert(0, str(model_repo_path))
  73. # Mock loralib to allow importing model_utils without installing it
  74. from unittest.mock import MagicMock
  75. if "loralib" not in sys.modules:
  76. sys.modules["loralib"] = MagicMock()
  77. # Import from new structure (Music-Source-Separation-Training)
  78. try:
  79. from models.bs_roformer.mel_band_roformer import (
  80. MelBandRoformer as _MelBandRoformer,
  81. )
  82. from models.bs_roformer.mel_band_roformer import (
  83. pack_one as _pack_one,
  84. unpack_one as _unpack_one,
  85. )
  86. pack_one = _pack_one
  87. unpack_one = _unpack_one
  88. MelBandRoformer = _MelBandRoformer
  89. try:
  90. from models.bs_roformer.bs_roformer import BSRoformer as _BSRoformer
  91. BSRoformer = _BSRoformer
  92. except ImportError:
  93. print(" Warning: Could not import BSRoformer from model repo.")
  94. # Import demix from utils.model_utils
  95. from utils.model_utils import demix
  96. inference_func = demix
  97. print(f" Loaded model from: {model_repo_path}")
  98. return
  99. except ImportError as e:
  100. print("\n" + "=" * 70)
  101. print("ERROR: Failed to import model!")
  102. print("=" * 70)
  103. print(f"\nImport error: {e}")
  104. print(
  105. "\nPlease ensure the repository is complete and dependencies are installed."
  106. )
  107. sys.exit(1)
  108. def save_tensor(
  109. output_dir: Path, name: str, tensor, subdir: str = "activations"
  110. ) -> dict:
  111. """Save tensor to .npy file."""
  112. path = output_dir / subdir / f"{name}.npy"
  113. path.parent.mkdir(parents=True, exist_ok=True)
  114. if isinstance(tensor, torch.Tensor):
  115. tensor = tensor.detach().cpu()
  116. if tensor.dtype in [torch.int64, torch.int32, torch.bool]:
  117. tensor = tensor.float()
  118. tensor = tensor.numpy()
  119. if isinstance(tensor, np.ndarray) and tensor.dtype != np.float32:
  120. tensor = tensor.astype(np.float32)
  121. np.save(path, tensor)
  122. print(f" Saved {name}: shape={list(tensor.shape)}")
  123. return {"name": name, "shape": list(tensor.shape), "path": str(path)}
  124. def generate_chunking_data(output_dir: Path, config: ConfigDict):
  125. """Generate input/output data for verifying chunking logic."""
  126. print("\n[Chunking] Generating overlap-add debug data...")
  127. if inference_func is None:
  128. print(
  129. " Warning: Inference function not found, skipping chunking data generation."
  130. )
  131. return
  132. # Create Mock Model (Identity)
  133. model = MockModel()
  134. device = torch.device("cpu")
  135. # Create input: Ramp signal
  136. # Size > 2 chunks to test overlap logic
  137. # Use fixed values to match C++ test_chunking_logic.cpp (lines 76-77)
  138. chunk_size = 352800
  139. num_overlap = 2
  140. print(f" Chunk size: {chunk_size}, Overlap: {num_overlap}")
  141. total_len = chunk_size * 2 + 10000
  142. inputs = np.linspace(0, 1, total_len).astype(np.float32)
  143. # Make stereo [2, T]
  144. inputs = np.stack([inputs, inputs], axis=0)
  145. # Save input (C-order, transposed to [T, 2] for C++ ease if needed, but C++ load_npy handles it)
  146. save_tensor(output_dir, "chunk_in", inputs.T, subdir=".")
  147. # Run Inference
  148. mixture = torch.tensor(inputs, dtype=torch.float32)
  149. # demix(config, model, mix, device, model_type)
  150. # generic mode (not htdemucs) uses 'generic'
  151. # It returns dict {instr: waveform} or array
  152. res = inference_func(config, model, mixture, device, model_type="generic")
  153. if isinstance(res, dict):
  154. # Pick the first instrument
  155. first_key = list(res.keys())[0]
  156. output = res[first_key]
  157. else:
  158. output = res
  159. # Save output
  160. if isinstance(output, torch.Tensor):
  161. output = output.cpu().numpy()
  162. save_tensor(output_dir, "chunk_out", output.T, subdir=".")
  163. def generate_test_data(
  164. model_repo: str,
  165. audio_file: str,
  166. checkpoint: str,
  167. config_file: str,
  168. output_dir: str,
  169. audio_start: float = 2.0,
  170. audio_end: float = 5.0,
  171. ) -> int:
  172. """Generate test data for C++ verification."""
  173. # Load model module from specified repository
  174. model_repo_path = Path(model_repo)
  175. load_model_module(model_repo_path)
  176. output_path = Path(output_dir)
  177. output_path.mkdir(parents=True, exist_ok=True)
  178. print("=" * 70)
  179. print("MelBandRoformer Test Data Generator")
  180. print("=" * 70)
  181. # 1. Load config and model
  182. print(f"\n[1/4] Loading model from {checkpoint}")
  183. with open(config_file) as f:
  184. config = ConfigDict(yaml.load(f, Loader=yaml.FullLoader))
  185. model_type = "mel_band"
  186. if "freqs_per_bands" in config.model:
  187. model_type = "bs"
  188. if BSRoformer is None:
  189. print(
  190. "Error: BSRoformer class not loaded but config looks like BS Roformer."
  191. )
  192. return 1
  193. model = BSRoformer(**dict(config.model))
  194. print(f" Architecture: Band Split Roformer")
  195. else:
  196. model = MelBandRoformer(**dict(config.model))
  197. print(f" Architecture: Mel-Band Roformer")
  198. state_dict = torch.load(checkpoint, map_location="cpu")
  199. # Handle checkpoint structure
  200. if "state_dict" in state_dict:
  201. state_dict = state_dict["state_dict"]
  202. elif "model" in state_dict:
  203. state_dict = state_dict["model"]
  204. model.load_state_dict(state_dict)
  205. model.eval()
  206. print(f" Config: depth={config.model.depth}, dim={config.model.dim}")
  207. # 2. Load audio
  208. print(f"\n[2/4] Loading audio ({audio_start}s - {audio_end}s) from {audio_file}")
  209. audio, sr = sf.read(audio_file)
  210. start_sample = int(audio_start * sr)
  211. end_sample = int(audio_end * sr)
  212. audio_segment = audio[start_sample:end_sample]
  213. if len(audio_segment.shape) == 1:
  214. audio_segment = np.stack([audio_segment, audio_segment], axis=-1)
  215. # [batch, channels, samples]
  216. audio_tensor = torch.tensor(audio_segment.T, dtype=torch.float32).unsqueeze(0)
  217. print(f" Audio shape: {audio_tensor.shape}")
  218. # 3. Run instrumented forward pass
  219. print("\n[3/4] Running instrumented forward pass...")
  220. captured = {}
  221. with torch.no_grad():
  222. device = audio_tensor.device
  223. raw_audio = audio_tensor
  224. if raw_audio.ndim == 2:
  225. raw_audio = rearrange(raw_audio, "b t -> b 1 t")
  226. batch, channels, raw_audio_length = raw_audio.shape
  227. istft_length = raw_audio_length
  228. # STFT
  229. raw_audio_packed, batch_audio_channel_packed_shape = pack_one(raw_audio, "* t")
  230. stft_window = model.stft_window_fn(device=device)
  231. stft_repr = torch.stft(
  232. raw_audio_packed,
  233. **model.stft_kwargs,
  234. window=stft_window,
  235. return_complex=True,
  236. )
  237. stft_repr = torch.view_as_real(stft_repr)
  238. stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, "* f t c")
  239. stft_repr = rearrange(stft_repr, "b s f t c -> b (f s) t c")
  240. # Frequency indexing
  241. if model_type == "mel_band":
  242. batch_arange = torch.arange(batch, device=device)[..., None]
  243. x = stft_repr[batch_arange, model.freq_indices]
  244. x = rearrange(x, "b f t c -> b t (f c)")
  245. else:
  246. # BS Roformer: Direct usage
  247. x = stft_repr
  248. # If stft_repr is complex (view_as_real result: [b, f, t, 2])
  249. # BS model expects: [b, f, t, 2] -> rearrange to [b, t, (f * 2)]
  250. # Wait, bs_roformer.py: x = rearrange(x, 'b f t c -> b t (f c)')
  251. x = rearrange(x, "b f t c -> b t (f c)")
  252. # ===== CAPTURE: BandSplit Input =====
  253. captured["band_split_in"] = x.clone()
  254. # BandSplit
  255. x = model.band_split(x)
  256. # ===== CAPTURE: After BandSplit (= Transformer Input) =====
  257. captured["after_band_split"] = x.clone()
  258. # Transformer Layers
  259. for layer_idx, (time_transformer, freq_transformer) in enumerate(model.layers):
  260. # Time Transformer
  261. x = rearrange(x, "b t f d -> b f t d")
  262. x, ps = pack([x], "* t d")
  263. x = time_transformer(x)
  264. (x,) = unpack(x, ps, "* t d")
  265. x = rearrange(x, "b f t d -> b t f d")
  266. # Freq Transformer
  267. x, ps = pack([x], "* f d")
  268. x = freq_transformer(x)
  269. (x,) = unpack(x, ps, "* f d")
  270. # BS Roformer: Apply global final_norm after all transformer layers
  271. if model_type == "bs" and hasattr(model, "final_norm"):
  272. x = model.final_norm(x)
  273. # ===== CAPTURE: Before Mask Estimator (= Transformer Output) =====
  274. captured["before_mask_est"] = x.clone()
  275. # Mask Estimator (just first one for testing)
  276. mask0 = model.mask_estimators[0](x)
  277. # ===== CAPTURE: Mask Estimator Output =====
  278. captured["mask_est0"] = mask0.clone()
  279. # Continue with full forward pass for output
  280. num_stems = len(model.mask_estimators)
  281. masks = torch.stack([fn(x) for fn in model.mask_estimators], dim=1)
  282. masks = rearrange(masks, "b n t (f c) -> b n f t c", c=2)
  283. stft_repr = rearrange(stft_repr, "b f t c -> b 1 f t c")
  284. stft_repr = torch.view_as_complex(stft_repr)
  285. masks = torch.view_as_complex(masks)
  286. masks = masks.type(stft_repr.dtype)
  287. from einops import repeat
  288. if model_type == "mel_band":
  289. scatter_indices = repeat(
  290. model.freq_indices,
  291. "f -> b n f t",
  292. b=batch,
  293. n=num_stems,
  294. t=stft_repr.shape[-1],
  295. )
  296. stft_repr_expanded_stems = repeat(
  297. stft_repr, "b 1 ... -> b n ...", n=num_stems
  298. )
  299. masks_summed = torch.zeros_like(stft_repr_expanded_stems).scatter_add_(
  300. 2, scatter_indices, masks
  301. )
  302. denom = repeat(model.num_bands_per_freq, "f -> (f r) 1", r=channels)
  303. masks_averaged = masks_summed / denom.clamp(min=1e-8)
  304. stft_repr = stft_repr * masks_averaged
  305. else:
  306. # BS Roformer: Direct mask application
  307. # masks shape: [b, n, f, t, c] (rearranged above)
  308. # stft_repr shape: [b, 1, f, t, c] (rearranged above)
  309. # BS model output masks are often [b, n, f, t] (complex/real?)
  310. # Wait, bs_roformer.py:
  311. # masks = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
  312. # masks = rearrange(masks, 'b n t (f c) -> b n f t c', c = 2)
  313. # x = x * masks.sum(dim=1) # summation over stems? No, output separate stems.
  314. # return x * masks
  315. # So here: stft_repr * masks is correct.
  316. stft_repr = stft_repr * masks
  317. # ISTFT
  318. if model_type == "mel_band":
  319. stft_repr = rearrange(
  320. stft_repr, "b n (f s) t -> (b n s) f t", s=model.audio_channels
  321. )
  322. else:
  323. # BS Roformer: stft_repr is [b, n, (Freq*Stereo), t] (complex)
  324. # Unpack stereo and flatten batch/stems/stereo for istft
  325. stft_repr = rearrange(
  326. stft_repr, "b n (f s) t -> (b n s) f t", s=model.audio_channels
  327. )
  328. if getattr(model, "zero_dc", False):
  329. # Zero out DC component
  330. stft_repr = stft_repr.clone()
  331. stft_repr[:, 0, :] = 0.0
  332. recon_audio = torch.istft(
  333. stft_repr,
  334. **model.stft_kwargs,
  335. window=stft_window,
  336. return_complex=False,
  337. length=istft_length,
  338. )
  339. recon_audio = rearrange(
  340. recon_audio,
  341. "(b n s) t -> b n s t",
  342. b=batch,
  343. s=model.audio_channels,
  344. n=num_stems,
  345. )
  346. if num_stems == 1:
  347. recon_audio = rearrange(recon_audio, "b 1 s t -> b s t")
  348. captured["output_audio"] = recon_audio.clone()
  349. else:
  350. # Capture Stem 0 for verification
  351. captured["output_audio"] = recon_audio[:, 0, :, :].clone()
  352. # Capture Stem 0 for verification
  353. captured["output_audio"] = recon_audio[:, 0, :, :].clone()
  354. # 4. Generate Chunking Debug Data
  355. generate_chunking_data(output_path, config)
  356. # 5. Save tensors
  357. print(f"\n[4/5] Saving test data to {output_dir}")
  358. # Input audio
  359. save_tensor(output_path, "input_audio", audio_tensor)
  360. # Captured tensors
  361. for name, tensor in captured.items():
  362. save_tensor(output_path, name, tensor)
  363. # Verify outputs match normal forward pass
  364. print("\n[Verification] Checking output matches model.forward()...")
  365. with torch.no_grad():
  366. baseline = model(audio_tensor)
  367. if hasattr(model, "num_stems") and model.num_stems > 1:
  368. baseline = baseline[:, 0, :, :]
  369. diff = (baseline - captured["output_audio"]).abs()
  370. max_diff = diff.max().item()
  371. if max_diff > 1e-6:
  372. print(f" ✗ FAILED: max_diff = {max_diff:.2e}")
  373. return 1
  374. else:
  375. print(f" ✓ PASSED: max_diff = {max_diff:.2e}")
  376. print("\n" + "=" * 70)
  377. print("Test data generation complete!")
  378. print(f" Output: {output_dir}/activations/")
  379. print(f" Files: {len(captured) + 1} tensors")
  380. print("=" * 70)
  381. return 0
  382. def main():
  383. parser = argparse.ArgumentParser(
  384. description="Generate test data for MelBandRoformer.cpp",
  385. formatter_class=argparse.RawDescriptionHelpFormatter,
  386. epilog=f"""
  387. Requirements:
  388. This script requires the original Mel-Band-Roformer-Vocal-Model repository.
  389. Clone it first:
  390. git clone {MODEL_REPO_URL}
  391. Then specify the path with --model-repo.
  392. Example:
  393. python generate_test_data.py \\
  394. --model-repo /path/to/Mel-Band-Roformer-Vocal-Model \\
  395. --audio test.wav \\
  396. --checkpoint model.ckpt \\
  397. --output test_data
  398. """,
  399. )
  400. parser.add_argument(
  401. "--model-repo",
  402. required=True,
  403. help=f"Path to Mel-Band-Roformer-Vocal-Model repository (clone from {MODEL_REPO_URL})",
  404. )
  405. parser.add_argument("--audio", required=True, help="Input audio file (WAV)")
  406. parser.add_argument(
  407. "--checkpoint", required=True, help="Model checkpoint file (.ckpt)"
  408. )
  409. parser.add_argument(
  410. "--config",
  411. help="Model config YAML file (default: <model-repo>/configs/config_vocals_mel_band_roformer.yaml)",
  412. )
  413. parser.add_argument(
  414. "--output", default="test_data", help="Output directory for test data"
  415. )
  416. parser.add_argument(
  417. "--start",
  418. type=float,
  419. default=2.0,
  420. help="Audio start time in seconds (default: 2.0)",
  421. )
  422. parser.add_argument(
  423. "--end",
  424. type=float,
  425. default=5.0,
  426. help="Audio end time in seconds (default: 5.0)",
  427. )
  428. args = parser.parse_args()
  429. # Resolve paths
  430. model_repo_path = Path(args.model_repo).resolve()
  431. audio_path = Path(args.audio).resolve()
  432. checkpoint_path = Path(args.checkpoint).resolve()
  433. output_path = Path(args.output).resolve()
  434. # Config defaults to model-repo/configs/...
  435. if args.config:
  436. config_path = Path(args.config).resolve()
  437. else:
  438. config_path = (
  439. model_repo_path / "configs" / "config_vocals_mel_band_roformer.yaml"
  440. )
  441. # Validate paths
  442. if not audio_path.exists():
  443. print(f"Error: Audio file not found: {audio_path}")
  444. return 1
  445. if not checkpoint_path.exists():
  446. print(f"Error: Checkpoint not found: {checkpoint_path}")
  447. return 1
  448. if not config_path.exists():
  449. print(f"Error: Config not found: {config_path}")
  450. return 1
  451. return generate_test_data(
  452. str(model_repo_path),
  453. str(audio_path),
  454. str(checkpoint_path),
  455. str(config_path),
  456. str(output_path),
  457. args.start,
  458. args.end,
  459. )
  460. if __name__ == "__main__":
  461. sys.exit(main())