# generate_test_data.py
"""
Generate minimal test data for MelBandRoformer.cpp verification.

This script generates only the tensors needed by the C++ tests. Unless noted
otherwise, files are written to <output>/activations/:
- input_audio.npy (for test_inference)
- output_audio.npy (for test_inference; stem 0 of the model output)
- band_split_in.npy (for test_component_bandsplit)
- after_band_split.npy (for test_component_bandsplit, test_component_layers)
- before_mask_est.npy (for test_component_layers, test_component_mask)
- mask_est0.npy (for test_component_mask)
- stft_raw.npy / istft_raw.npy (raw STFT of the input and its direct ISTFT
  round trip, for checking the C++ STFT/ISTFT)
- chunk_in.npy (for test_chunking_logic; written to <output>/)
- chunk_out.npy (for test_chunking_logic; written to <output>/)

Requirements:
    This script requires the Music-Source-Separation-Training repository:
    https://github.com/ZFTurbo/Music-Source-Separation-Training
    Clone it first:
        git clone https://github.com/ZFTurbo/Music-Source-Separation-Training.git

Usage:
    python generate_test_data.py --model-repo /path/to/Music-Source-Separation-Training \\
        --audio test.wav --checkpoint model.ckpt --output test_data

    --config is optional and defaults to
    <model-repo>/configs/config_vocals_mel_band_roformer.yaml.
"""
import argparse
import copy
import sys
from pathlib import Path

import numpy as np
import soundfile as sf
import torch
import yaml
from einops import pack, rearrange, repeat, unpack
from ml_collections import ConfigDict
  30. # Model imports are deferred until we know the model-repo path
  31. # Model imports are deferred until we know the model-repo path
  32. MelBandRoformer = None
  33. BSRoformer = None
  34. pack_one = None
  35. unpack_one = None
  36. # Inference utility
  37. inference_func = None
  38. MODEL_REPO_URL = "https://github.com/ZFTurbo/Music-Source-Separation-Training"
  39. class MockModel(torch.nn.Module):
  40. """Identity model for testing chunking logic."""
  41. def __init__(self):
  42. super().__init__()
  43. def forward(self, x):
  44. # x shape: [Batch, Channels, Time] or [Batch, Time]
  45. # Return same as input (Identity)
  46. return x
  47. def load_model_module(model_repo_path: Path):
  48. """Dynamically load the MelBandRoformer model from the specified repository."""
  49. global MelBandRoformer, BSRoformer, pack_one, unpack_one, inference_func
  50. if not model_repo_path.exists():
  51. print("\n" + "=" * 70)
  52. print("ERROR: Model repository not found!")
  53. print("=" * 70)
  54. print(f"\nPath: {model_repo_path}")
  55. print("\nThis script requires the Music-Source-Separation-Training repository.")
  56. print("\nPlease clone it first:")
  57. print(f" git clone {MODEL_REPO_URL}")
  58. print(
  59. "\nThen run this script with --model-repo pointing to the cloned directory."
  60. )
  61. print("=" * 70)
  62. sys.exit(1)
  63. models_path = model_repo_path / "models"
  64. if not models_path.exists():
  65. print("\n" + "=" * 70)
  66. print("ERROR: Invalid repository structure!")
  67. print("=" * 70)
  68. print(f"\nThe 'models' directory was not found in: {model_repo_path}")
  69. print("=" * 70)
  70. sys.exit(1)
  71. # Add to path and import
  72. sys.path.insert(0, str(model_repo_path))
  73. # Mock loralib to allow importing model_utils without installing it
  74. from unittest.mock import MagicMock
  75. if "loralib" not in sys.modules:
  76. sys.modules["loralib"] = MagicMock()
  77. # Import from new structure (Music-Source-Separation-Training)
  78. try:
  79. from models.bs_roformer.mel_band_roformer import (
  80. MelBandRoformer as _MelBandRoformer,
  81. )
  82. from models.bs_roformer.mel_band_roformer import (
  83. pack_one as _pack_one,
  84. unpack_one as _unpack_one,
  85. )
  86. pack_one = _pack_one
  87. unpack_one = _unpack_one
  88. MelBandRoformer = _MelBandRoformer
  89. try:
  90. from models.bs_roformer.bs_roformer import BSRoformer as _BSRoformer
  91. BSRoformer = _BSRoformer
  92. except ImportError:
  93. print(" Warning: Could not import BSRoformer from model repo.")
  94. # Import demix from utils.model_utils
  95. from utils.model_utils import demix
  96. inference_func = demix
  97. print(f" Loaded model from: {model_repo_path}")
  98. return
  99. except ImportError as e:
  100. print("\n" + "=" * 70)
  101. print("ERROR: Failed to import model!")
  102. print("=" * 70)
  103. print(f"\nImport error: {e}")
  104. print(
  105. "\nPlease ensure the repository is complete and dependencies are installed."
  106. )
  107. sys.exit(1)
  108. def save_tensor(
  109. output_dir: Path, name: str, tensor, subdir: str = "activations"
  110. ) -> dict:
  111. """Save tensor to .npy file."""
  112. path = output_dir / subdir / f"{name}.npy"
  113. path.parent.mkdir(parents=True, exist_ok=True)
  114. if isinstance(tensor, torch.Tensor):
  115. tensor = tensor.detach().cpu()
  116. if tensor.dtype in [torch.int64, torch.int32, torch.bool]:
  117. tensor = tensor.float()
  118. tensor = tensor.numpy()
  119. if isinstance(tensor, np.ndarray) and tensor.dtype != np.float32:
  120. tensor = tensor.astype(np.float32)
  121. np.save(path, tensor)
  122. print(f" Saved {name}: shape={list(tensor.shape)}")
  123. return {"name": name, "shape": list(tensor.shape), "path": str(path)}
  124. def generate_chunking_data(output_dir: Path, config: ConfigDict):
  125. """Generate input/output data for verifying chunking logic."""
  126. print("\n[Chunking] Generating overlap-add debug data...")
  127. if inference_func is None:
  128. print(
  129. " Warning: Inference function not found, skipping chunking data generation."
  130. )
  131. return
  132. # Create Mock Model (Identity)
  133. model = MockModel()
  134. device = torch.device("cpu")
  135. # Create input: Ramp signal
  136. # Size > 2 chunks to test overlap logic
  137. # Use fixed values to match C++ test_chunking_logic.cpp (lines 76-77)
  138. chunk_size = 352800
  139. num_overlap = 2
  140. print(f" Chunk size: {chunk_size}, Overlap: {num_overlap}")
  141. total_len = chunk_size * 2 + 10000
  142. inputs = np.linspace(0, 1, total_len).astype(np.float32)
  143. # Make stereo [2, T]
  144. inputs = np.stack([inputs, inputs], axis=0)
  145. # Save input (C-order, transposed to [T, 2] for C++ ease if needed, but C++ load_npy handles it)
  146. save_tensor(output_dir, "chunk_in", inputs.T, subdir=".")
  147. # Run Inference
  148. mixture = torch.tensor(inputs, dtype=torch.float32)
  149. # demix(config, model, mix, device, model_type)
  150. # generic mode (not htdemucs) uses 'generic'
  151. # It returns dict {instr: waveform} or array
  152. res = inference_func(config, model, mixture, device, model_type="generic")
  153. if isinstance(res, dict):
  154. # Pick the first instrument
  155. first_key = list(res.keys())[0]
  156. output = res[first_key]
  157. else:
  158. output = res
  159. # Save output
  160. if isinstance(output, torch.Tensor):
  161. output = output.cpu().numpy()
  162. save_tensor(output_dir, "chunk_out", output.T, subdir=".")
  163. def generate_test_data(
  164. model_repo: str,
  165. audio_file: str,
  166. checkpoint: str,
  167. config_file: str,
  168. output_dir: str,
  169. audio_start: float = 2.0,
  170. audio_end: float = 5.0,
  171. ) -> int:
  172. """Generate test data for C++ verification."""
  173. # Load model module from specified repository
  174. model_repo_path = Path(model_repo)
  175. load_model_module(model_repo_path)
  176. output_path = Path(output_dir)
  177. output_path.mkdir(parents=True, exist_ok=True)
  178. print("=" * 70)
  179. print("MelBandRoformer Test Data Generator")
  180. print("=" * 70)
  181. # 1. Load config and model
  182. print(f"\n[1/4] Loading model from {checkpoint}")
  183. with open(config_file) as f:
  184. config = ConfigDict(yaml.load(f, Loader=yaml.FullLoader))
  185. model_type = "mel_band"
  186. if "freqs_per_bands" in config.model:
  187. model_type = "bs"
  188. if BSRoformer is None:
  189. print(
  190. "Error: BSRoformer class not loaded but config looks like BS Roformer."
  191. )
  192. return 1
  193. model = BSRoformer(**dict(config.model))
  194. print(f" Architecture: Band Split Roformer")
  195. else:
  196. model = MelBandRoformer(**dict(config.model))
  197. print(f" Architecture: Mel-Band Roformer")
  198. state_dict = torch.load(checkpoint, map_location="cpu")
  199. # Handle checkpoint structure
  200. if "state_dict" in state_dict:
  201. state_dict = state_dict["state_dict"]
  202. elif "model" in state_dict:
  203. state_dict = state_dict["model"]
  204. model.load_state_dict(state_dict)
  205. model.eval()
  206. print(f" Config: depth={config.model.depth}, dim={config.model.dim}")
  207. # 2. Load audio
  208. print(f"\n[2/4] Loading audio ({audio_start}s - {audio_end}s) from {audio_file}")
  209. audio, sr = sf.read(audio_file)
  210. start_sample = int(audio_start * sr)
  211. end_sample = int(audio_end * sr)
  212. audio_segment = audio[start_sample:end_sample]
  213. if len(audio_segment.shape) == 1:
  214. audio_segment = np.stack([audio_segment, audio_segment], axis=-1)
  215. # [batch, channels, samples]
  216. audio_tensor = torch.tensor(audio_segment.T, dtype=torch.float32).unsqueeze(0)
  217. print(f" Audio shape: {audio_tensor.shape}")
  218. # 3. Run instrumented forward pass
  219. print("\n[3/4] Running instrumented forward pass...")
  220. captured = {}
  221. with torch.no_grad():
  222. device = audio_tensor.device
  223. raw_audio = audio_tensor
  224. if raw_audio.ndim == 2:
  225. raw_audio = rearrange(raw_audio, "b t -> b 1 t")
  226. batch, channels, raw_audio_length = raw_audio.shape
  227. istft_length = raw_audio_length
  228. # STFT
  229. raw_audio_packed, batch_audio_channel_packed_shape = pack_one(raw_audio, "* t")
  230. stft_window = model.stft_window_fn(device=device)
  231. stft_repr = torch.stft(
  232. raw_audio_packed,
  233. **model.stft_kwargs,
  234. window=stft_window,
  235. return_complex=True,
  236. )
  237. stft_repr = torch.view_as_real(stft_repr)
  238. # ===== CAPTURE: Raw STFT/ISTFT for C++ Verification =====
  239. # Unpack to [batch, channels, freq, time, 2]
  240. stft_raw_unpacked = unpack_one(
  241. stft_repr, batch_audio_channel_packed_shape, "* f t c"
  242. )
  243. captured["stft_raw"] = stft_raw_unpacked.clone()
  244. # Compute ISTFT directly on this raw STFT (Identity check)
  245. stft_complex = torch.view_as_complex(stft_repr)
  246. istft_check = torch.istft(
  247. stft_complex,
  248. **model.stft_kwargs,
  249. window=stft_window,
  250. return_complex=False,
  251. length=istft_length,
  252. )
  253. istft_check_unpacked = unpack_one(
  254. istft_check, batch_audio_channel_packed_shape, "* t"
  255. )
  256. captured["istft_raw"] = istft_check_unpacked.clone()
  257. # ========================================================
  258. stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, "* f t c")
  259. stft_repr = rearrange(stft_repr, "b s f t c -> b (f s) t c")
  260. # Frequency indexing
  261. if model_type == "mel_band":
  262. batch_arange = torch.arange(batch, device=device)[..., None]
  263. x = stft_repr[batch_arange, model.freq_indices]
  264. x = rearrange(x, "b f t c -> b t (f c)")
  265. else:
  266. # BS Roformer: Direct usage
  267. x = stft_repr
  268. # If stft_repr is complex (view_as_real result: [b, f, t, 2])
  269. # BS model expects: [b, f, t, 2] -> rearrange to [b, t, (f * 2)]
  270. # Wait, bs_roformer.py: x = rearrange(x, 'b f t c -> b t (f c)')
  271. x = rearrange(x, "b f t c -> b t (f c)")
  272. # ===== CAPTURE: BandSplit Input =====
  273. captured["band_split_in"] = x.clone()
  274. # BandSplit
  275. x = model.band_split(x)
  276. # ===== CAPTURE: After BandSplit (= Transformer Input) =====
  277. captured["after_band_split"] = x.clone()
  278. # Transformer Layers
  279. for layer_idx, (time_transformer, freq_transformer) in enumerate(model.layers):
  280. # Time Transformer
  281. x = rearrange(x, "b t f d -> b f t d")
  282. x, ps = pack([x], "* t d")
  283. x = time_transformer(x)
  284. (x,) = unpack(x, ps, "* t d")
  285. x = rearrange(x, "b f t d -> b t f d")
  286. # Freq Transformer
  287. x, ps = pack([x], "* f d")
  288. x = freq_transformer(x)
  289. (x,) = unpack(x, ps, "* f d")
  290. # BS Roformer: Apply global final_norm after all transformer layers
  291. if model_type == "bs" and hasattr(model, "final_norm"):
  292. x = model.final_norm(x)
  293. # ===== CAPTURE: Before Mask Estimator (= Transformer Output) =====
  294. captured["before_mask_est"] = x.clone()
  295. # Mask Estimator (just first one for testing)
  296. mask0 = model.mask_estimators[0](x)
  297. # ===== CAPTURE: Mask Estimator Output =====
  298. captured["mask_est0"] = mask0.clone()
  299. # Continue with full forward pass for output
  300. num_stems = len(model.mask_estimators)
  301. masks = torch.stack([fn(x) for fn in model.mask_estimators], dim=1)
  302. masks = rearrange(masks, "b n t (f c) -> b n f t c", c=2)
  303. stft_repr = rearrange(stft_repr, "b f t c -> b 1 f t c")
  304. stft_repr = torch.view_as_complex(stft_repr)
  305. masks = torch.view_as_complex(masks)
  306. masks = masks.type(stft_repr.dtype)
  307. from einops import repeat
  308. if model_type == "mel_band":
  309. scatter_indices = repeat(
  310. model.freq_indices,
  311. "f -> b n f t",
  312. b=batch,
  313. n=num_stems,
  314. t=stft_repr.shape[-1],
  315. )
  316. stft_repr_expanded_stems = repeat(
  317. stft_repr, "b 1 ... -> b n ...", n=num_stems
  318. )
  319. masks_summed = torch.zeros_like(stft_repr_expanded_stems).scatter_add_(
  320. 2, scatter_indices, masks
  321. )
  322. denom = repeat(model.num_bands_per_freq, "f -> (f r) 1", r=channels)
  323. masks_averaged = masks_summed / denom.clamp(min=1e-8)
  324. stft_repr = stft_repr * masks_averaged
  325. else:
  326. # BS Roformer: Direct mask application
  327. # masks shape: [b, n, f, t, c] (rearranged above)
  328. # stft_repr shape: [b, 1, f, t, c] (rearranged above)
  329. # BS model output masks are often [b, n, f, t] (complex/real?)
  330. # Wait, bs_roformer.py:
  331. # masks = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
  332. # masks = rearrange(masks, 'b n t (f c) -> b n f t c', c = 2)
  333. # x = x * masks.sum(dim=1) # summation over stems? No, output separate stems.
  334. # return x * masks
  335. # So here: stft_repr * masks is correct.
  336. stft_repr = stft_repr * masks
  337. # ISTFT
  338. if model_type == "mel_band":
  339. stft_repr = rearrange(
  340. stft_repr, "b n (f s) t -> (b n s) f t", s=model.audio_channels
  341. )
  342. else:
  343. # BS Roformer: stft_repr is [b, n, (Freq*Stereo), t] (complex)
  344. # Unpack stereo and flatten batch/stems/stereo for istft
  345. stft_repr = rearrange(
  346. stft_repr, "b n (f s) t -> (b n s) f t", s=model.audio_channels
  347. )
  348. if getattr(model, "zero_dc", False):
  349. # Zero out DC component
  350. stft_repr = stft_repr.clone()
  351. stft_repr[:, 0, :] = 0.0
  352. recon_audio = torch.istft(
  353. stft_repr,
  354. **model.stft_kwargs,
  355. window=stft_window,
  356. return_complex=False,
  357. length=istft_length,
  358. )
  359. recon_audio = rearrange(
  360. recon_audio,
  361. "(b n s) t -> b n s t",
  362. b=batch,
  363. s=model.audio_channels,
  364. n=num_stems,
  365. )
  366. if num_stems == 1:
  367. recon_audio = rearrange(recon_audio, "b 1 s t -> b s t")
  368. captured["output_audio"] = recon_audio.clone()
  369. else:
  370. # Capture Stem 0 for verification
  371. captured["output_audio"] = recon_audio[:, 0, :, :].clone()
  372. # Capture Stem 0 for verification
  373. captured["output_audio"] = recon_audio[:, 0, :, :].clone()
  374. # 4. Generate Chunking Debug Data
  375. generate_chunking_data(output_path, config)
  376. # 5. Save tensors
  377. print(f"\n[4/5] Saving test data to {output_dir}")
  378. # Input audio
  379. save_tensor(output_path, "input_audio", audio_tensor)
  380. # Captured tensors
  381. for name, tensor in captured.items():
  382. save_tensor(output_path, name, tensor)
  383. # Verify outputs match normal forward pass
  384. print("\n[Verification] Checking output matches model.forward()...")
  385. with torch.no_grad():
  386. baseline = model(audio_tensor)
  387. if hasattr(model, "num_stems") and model.num_stems > 1:
  388. baseline = baseline[:, 0, :, :]
  389. diff = (baseline - captured["output_audio"]).abs()
  390. max_diff = diff.max().item()
  391. if max_diff > 1e-6:
  392. print(f" ✗ FAILED: max_diff = {max_diff:.2e}")
  393. return 1
  394. else:
  395. print(f" ✓ PASSED: max_diff = {max_diff:.2e}")
  396. print("\n" + "=" * 70)
  397. print("Test data generation complete!")
  398. print(f" Output: {output_dir}/activations/")
  399. print(f" Files: {len(captured) + 1} tensors")
  400. print("=" * 70)
  401. return 0
  402. def main():
  403. parser = argparse.ArgumentParser(
  404. description="Generate test data for BSRoformer.cpp",
  405. formatter_class=argparse.RawDescriptionHelpFormatter,
  406. epilog=f"""
  407. Requirements:
  408. This script requires the original Mel-Band-Roformer-Vocal-Model repository.
  409. Clone it first:
  410. git clone {MODEL_REPO_URL}
  411. Then specify the path with --model-repo.
  412. Example:
  413. python generate_test_data.py \\
  414. --model-repo /path/to/Mel-Band-Roformer-Vocal-Model \\
  415. --audio test.wav \\
  416. --checkpoint model.ckpt \\
  417. --output test_data
  418. """,
  419. )
  420. parser.add_argument(
  421. "--model-repo",
  422. required=True,
  423. help=f"Path to Mel-Band-Roformer-Vocal-Model repository (clone from {MODEL_REPO_URL})",
  424. )
  425. parser.add_argument("--audio", required=True, help="Input audio file (WAV)")
  426. parser.add_argument(
  427. "--checkpoint", required=True, help="Model checkpoint file (.ckpt)"
  428. )
  429. parser.add_argument(
  430. "--config",
  431. help="Model config YAML file (default: <model-repo>/configs/config_vocals_mel_band_roformer.yaml)",
  432. )
  433. parser.add_argument(
  434. "--output", default="test_data", help="Output directory for test data"
  435. )
  436. parser.add_argument(
  437. "--start",
  438. type=float,
  439. default=2.0,
  440. help="Audio start time in seconds (default: 2.0)",
  441. )
  442. parser.add_argument(
  443. "--end",
  444. type=float,
  445. default=5.0,
  446. help="Audio end time in seconds (default: 5.0)",
  447. )
  448. args = parser.parse_args()
  449. # Resolve paths
  450. model_repo_path = Path(args.model_repo).resolve()
  451. audio_path = Path(args.audio).resolve()
  452. checkpoint_path = Path(args.checkpoint).resolve()
  453. output_path = Path(args.output).resolve()
  454. # Config defaults to model-repo/configs/...
  455. if args.config:
  456. config_path = Path(args.config).resolve()
  457. else:
  458. config_path = (
  459. model_repo_path / "configs" / "config_vocals_mel_band_roformer.yaml"
  460. )
  461. # Validate paths
  462. if not audio_path.exists():
  463. print(f"Error: Audio file not found: {audio_path}")
  464. return 1
  465. if not checkpoint_path.exists():
  466. print(f"Error: Checkpoint not found: {checkpoint_path}")
  467. return 1
  468. if not config_path.exists():
  469. print(f"Error: Config not found: {config_path}")
  470. return 1
  471. return generate_test_data(
  472. str(model_repo_path),
  473. str(audio_path),
  474. str(checkpoint_path),
  475. str(config_path),
  476. str(output_path),
  477. args.start,
  478. args.end,
  479. )
  480. if __name__ == "__main__":
  481. sys.exit(main())