separate.py
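# Separate a mixture into vocal and instrumental stems with an MDX-Net
# model exported to ONNX. Usage sketch (model and file names illustrative):
#
#   python separate.py -m UVR-MDX-NET-vocals.onnx song.wav -o separated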
import soundfile as sf
import torch
import os
import librosa
import numpy as np
import onnxruntime as ort
from pathlib import Path
from argparse import ArgumentParser
from tqdm import tqdm
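
# STFT/iSTFT helper matching the spectrogram layout MDX-Net models expect:
# stereo real/imag parts packed into dim_c = 4 channels, cropped to dim_f
# frequency bins.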
class ConvTDFNet:
    def __init__(self, target_name, L, dim_f, dim_t, n_fft, hop=1024):
        self.dim_c = 4  # two channels x (real, imag)
        self.dim_f = dim_f
        self.dim_t = 2**dim_t
        self.n_fft = n_fft
        self.hop = hop
        self.n_bins = self.n_fft // 2 + 1
        self.chunk_size = hop * (self.dim_t - 1)
        self.window = torch.hann_window(window_length=self.n_fft, periodic=True)
        self.target_name = target_name
        out_c = self.dim_c * 4 if target_name == "*" else self.dim_c
        self.freq_pad = torch.zeros([1, out_c, self.n_bins - self.dim_f, self.dim_t])
        self.n = L // 2
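
    # Short-time Fourier transform (STFT): (B, 2, chunk_size) waveforms to
    # (B, dim_c, dim_f, dim_t) spectrograms with real/imag as channels.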
    def stft(self, x):
        x = x.reshape([-1, self.chunk_size])
        x = torch.stft(
            x,
            n_fft=self.n_fft,
            hop_length=self.hop,
            window=self.window,
            center=True,
            return_complex=True,
        )
        x = torch.view_as_real(x)
        x = x.permute([0, 3, 1, 2])
        x = x.reshape([-1, 2, 2, self.n_bins, self.dim_t]).reshape(
            [-1, self.dim_c, self.n_bins, self.dim_t]
        )
        return x[:, :, : self.dim_f]
    # Inverse short-time Fourier transform (iSTFT).
    def istft(self, x, freq_pad=None):
        freq_pad = (
            self.freq_pad.repeat([x.shape[0], 1, 1, 1])
            if freq_pad is None
            else freq_pad
        )
        x = torch.cat([x, freq_pad], -2)
        c = 4 * 2 if self.target_name == "*" else 2
        x = x.reshape([-1, c, 2, self.n_bins, self.dim_t]).reshape(
            [-1, 2, self.n_bins, self.dim_t]
        )
        x = x.permute([0, 2, 3, 1])
        x = x.contiguous()
        x = torch.view_as_complex(x)
        x = torch.istft(
            x, n_fft=self.n_fft, hop_length=self.hop, window=self.window, center=True
        )
        return x.reshape([-1, c, self.chunk_size])
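
# Wraps an ONNX Runtime session and runs chunked separation over a mixture.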
class Predictor:
    def __init__(self, args):
        self.args = args
        self.model_ = ConvTDFNet(
            target_name="vocals",
            L=11,
            dim_f=args["dim_f"],
            dim_t=args["dim_t"],
            n_fft=args["n_fft"],
        )
        # Prefer CUDA when available. str() keeps older onnxruntime releases
        # happy, since some of them reject pathlib.Path objects.
        provider = "CUDAExecutionProvider" if torch.cuda.is_available() else "CPUExecutionProvider"
        self.model = ort.InferenceSession(str(args["model_path"]), providers=[provider])
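
    # Split the mixture into chunks with `margin` samples of context on each
    # side, then separate each chunk.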
    def demix(self, mix):
        samples = mix.shape[-1]
        margin = self.args["margin"]
        chunk_size = self.args["chunks"] * 44100
        assert margin != 0, "margin cannot be zero!"
        if margin > chunk_size:
            margin = chunk_size

        segmented_mix = {}
        if self.args["chunks"] == 0 or samples < chunk_size:
            chunk_size = samples

        counter = -1
        for skip in range(0, samples, chunk_size):
            counter += 1
            s_margin = 0 if counter == 0 else margin
            end = min(skip + chunk_size + margin, samples)
            start = skip - s_margin
            segmented_mix[skip] = mix[:, start:end].copy()
            if end == samples:
                break

        return self.demix_base(segmented_mix, margin_size=margin)
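
    # Run the model over one dict of chunks: zero-pad each chunk to a whole
    # number of model windows, batch the windows through the network, then
    # trim the STFT padding and chunk margins before stitching.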
    def demix_base(self, mixes, margin_size):
        chunked_sources = []
        progress_bar = tqdm(total=len(mixes))
        progress_bar.set_description("Processing")

        for mix in mixes:
            cmix = mixes[mix]
            sources = []
            n_sample = cmix.shape[1]
            model = self.model_
            trim = model.n_fft // 2
            gen_size = model.chunk_size - 2 * trim
            pad = gen_size - n_sample % gen_size

            # Zero-pad so the chunk divides evenly into model windows.
            mix_p = np.concatenate(
                (np.zeros((2, trim)), cmix, np.zeros((2, pad)), np.zeros((2, trim))), 1
            )

            mix_waves = []
            i = 0
            while i < n_sample + pad:
                waves = np.array(mix_p[:, i : i + model.chunk_size])
                mix_waves.append(waves)
                i += gen_size

            mix_waves = torch.tensor(np.array(mix_waves), dtype=torch.float32)

            with torch.no_grad():
                _ort = self.model
                spek = model.stft(mix_waves)
                if self.args["denoise"]:
                    # Denoising trick: average the model's output on the
                    # spectrogram and on its negation to cancel additive noise.
                    spec_pred = (
                        -_ort.run(None, {"input": -spek.cpu().numpy()})[0] * 0.5
                        + _ort.run(None, {"input": spek.cpu().numpy()})[0] * 0.5
                    )
                    tar_waves = model.istft(torch.tensor(spec_pred))
                else:
                    tar_waves = model.istft(
                        torch.tensor(_ort.run(None, {"input": spek.cpu().numpy()})[0])
                    )

                # Drop the STFT padding on both ends and the zero-padding tail.
                tar_signal = (
                    tar_waves[:, :, trim:-trim]
                    .transpose(0, 1)
                    .reshape(2, -1)
                    .numpy()[:, :-pad]
                )

            # Trim the margins, keeping the leading edge of the first chunk
            # and the trailing edge of the last.
            start = 0 if mix == 0 else margin_size
            end = None if mix == list(mixes.keys())[-1] else -margin_size
            if margin_size == 0:
                end = None

            sources.append(tar_signal[:, start:end])
            progress_bar.update(1)
            chunked_sources.append(sources)

        _sources = np.concatenate(chunked_sources, axis=-1)
        progress_bar.close()
        return _sources
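
    # Load a file, separate it, and return (vocals, accompaniment, rate).
    # This assumes the ONNX model predicts the vocal stem, matching the
    # hard-coded target_name="vocals" above.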
    def predict(self, file_path):
        mix, rate = librosa.load(file_path, mono=False, sr=44100)
        if mix.ndim == 1:
            # Mono input: duplicate the channel to form a stereo mixture.
            mix = np.asfortranarray([mix, mix])
        mix = mix.T
        sources = self.demix(mix.T)
        opt = sources[0].T  # the model's target stem (vocals)
        return (opt, mix - opt, rate)
def main():
    parser = ArgumentParser()
    parser.add_argument("files", nargs="+", type=Path, default=[], help="Source audio path")
    parser.add_argument("-o", "--output", type=Path, default=Path("separated"), help="Output folder")
    parser.add_argument("-m", "--model_path", type=Path, required=True, help="MDX Net ONNX Model path")
    parser.add_argument("-d", "--no-denoise", dest="denoise", action="store_false", default=True, help="Disable denoising")
    parser.add_argument("-M", "--margin", type=int, default=44100, help="Margin size in samples")
    parser.add_argument("-c", "--chunks", type=int, default=15, help="Chunk size in seconds (0 disables chunking)")
    parser.add_argument("-F", "--n_fft", type=int, default=6144)
    parser.add_argument("-t", "--dim_t", type=int, default=8)
    parser.add_argument("-f", "--dim_f", type=int, default=2048)
    args = parser.parse_args()
    dict_args = vars(args)

    os.makedirs(args.output, exist_ok=True)

    # Build the predictor once and reuse it, so the ONNX model is not
    # reloaded for every input file.
    predictor = Predictor(args=dict_args)
    for file_path in args.files:
        vocals, no_vocals, sampling_rate = predictor.predict(file_path)
        filename = file_path.stem
        sf.write(os.path.join(args.output, filename + "_no_vocals.wav"), no_vocals, sampling_rate)
        sf.write(os.path.join(args.output, filename + "_vocals.wav"), vocals, sampling_rate)

if __name__ == "__main__":
    main()