Spaces:
Sleeping
Sleeping
update loss
Browse files- app.py +23 -4
- inference.py +17 -9
- modules/__pycache__/loss.cpython-311.pyc +0 -0
- modules/loss.py +49 -20
app.py
CHANGED
@@ -94,7 +94,8 @@ def process_audio(input_audio, reference_audio):
|
|
94 |
|
95 |
return (sr, output_audio), param_output, (sr, normalized_input)
|
96 |
|
97 |
-
def perform_ito(input_audio, reference_audio, ito_reference_audio, num_steps, optimizer, learning_rate, af_weights):
|
|
|
98 |
if ito_reference_audio is None:
|
99 |
ito_reference_audio = reference_audio
|
100 |
af_weights = [float(w.strip()) for w in af_weights.split(',')]
|
@@ -104,7 +105,10 @@ def perform_ito(input_audio, reference_audio, ito_reference_audio, num_steps, op
|
|
104 |
'learning_rate': learning_rate,
|
105 |
'num_steps': num_steps,
|
106 |
'af_weights': af_weights,
|
107 |
-
'sample_rate': args.sample_rate
|
|
|
|
|
|
|
108 |
}
|
109 |
|
110 |
input_tensor = mastering_transfer.preprocess_audio(input_audio, args.sample_rate)
|
@@ -219,7 +223,22 @@ with gr.Blocks() as demo:
|
|
219 |
optimizer = gr.Dropdown(["Adam", "RAdam", "SGD"], value="RAdam", label="Optimizer")
|
220 |
learning_rate = gr.Slider(minimum=0.0001, maximum=0.1, value=0.001, step=0.0001, label="Learning Rate")
|
221 |
af_weights = gr.Textbox(label="AudioFeatureLoss Weights (comma-separated)", value="0.1,0.001,1.0,1.0,0.1")
|
222 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
223 |
ito_button = gr.Button("Perform ITO")
|
224 |
|
225 |
with gr.Row():
|
@@ -243,7 +262,7 @@ with gr.Blocks() as demo:
|
|
243 |
|
244 |
ito_button.click(
|
245 |
perform_ito,
|
246 |
-
inputs=[normalized_input, reference_audio, ito_reference_audio, num_steps, optimizer, learning_rate, af_weights],
|
247 |
outputs=[ito_output_audio, ito_param_output, ito_step_slider, ito_log, ito_loss_plot, all_results]
|
248 |
).then(
|
249 |
update_ito_output,
|
|
|
94 |
|
95 |
return (sr, output_audio), param_output, (sr, normalized_input)
|
96 |
|
97 |
+
# def perform_ito(input_audio, reference_audio, ito_reference_audio, num_steps, optimizer, learning_rate, af_weights):
|
98 |
+
def perform_ito(input_audio, reference_audio, ito_reference_audio, num_steps, optimizer, learning_rate, af_weights, loss_function, clap_target_type, clap_text_prompt):
|
99 |
if ito_reference_audio is None:
|
100 |
ito_reference_audio = reference_audio
|
101 |
af_weights = [float(w.strip()) for w in af_weights.split(',')]
|
|
|
105 |
'learning_rate': learning_rate,
|
106 |
'num_steps': num_steps,
|
107 |
'af_weights': af_weights,
|
108 |
+
'sample_rate': args.sample_rate,
|
109 |
+
'loss_function': loss_function,
|
110 |
+
'clap_target_type': clap_target_type,
|
111 |
+
'clap_text_prompt': clap_text_prompt
|
112 |
}
|
113 |
|
114 |
input_tensor = mastering_transfer.preprocess_audio(input_audio, args.sample_rate)
|
|
|
223 |
optimizer = gr.Dropdown(["Adam", "RAdam", "SGD"], value="RAdam", label="Optimizer")
|
224 |
learning_rate = gr.Slider(minimum=0.0001, maximum=0.1, value=0.001, step=0.0001, label="Learning Rate")
|
225 |
af_weights = gr.Textbox(label="AudioFeatureLoss Weights (comma-separated)", value="0.1,0.001,1.0,1.0,0.1")
|
226 |
+
loss_function = gr.Radio(["AudioFeatureLoss", "CLAPFeatureLoss"], label="Loss Function", value="AudioFeatureLoss")
|
227 |
+
clap_target_type = gr.Radio(["Audio", "Text"], label="CLAP Target Type", value="Audio", visible=False)
|
228 |
+
clap_text_prompt = gr.Textbox(label="CLAP Text Prompt", visible=False)
|
229 |
+
|
230 |
+
def update_clap_options(loss_function):
|
231 |
+
if loss_function == "CLAPFeatureLoss":
|
232 |
+
return gr.update(visible=True), gr.update(visible=True)
|
233 |
+
else:
|
234 |
+
return gr.update(visible=False), gr.update(visible=False)
|
235 |
+
|
236 |
+
loss_function.change(
|
237 |
+
update_clap_options,
|
238 |
+
inputs=[loss_function],
|
239 |
+
outputs=[clap_target_type, clap_text_prompt]
|
240 |
+
)
|
241 |
+
|
242 |
ito_button = gr.Button("Perform ITO")
|
243 |
|
244 |
with gr.Row():
|
|
|
262 |
|
263 |
ito_button.click(
|
264 |
perform_ito,
|
265 |
+
inputs=[normalized_input, reference_audio, ito_reference_audio, num_steps, optimizer, learning_rate, af_weights, loss_function, clap_target_type, clap_text_prompt],
|
266 |
outputs=[ito_output_audio, ito_param_output, ito_step_slider, ito_log, ito_loss_plot, all_results]
|
267 |
).then(
|
268 |
update_ito_output,
|
inference.py
CHANGED
@@ -34,7 +34,14 @@ class MasteringStyleTransfer:
|
|
34 |
self.fx_normalizer = Audio_Effects_Normalizer(precomputed_feature_path=args.fx_norm_feature_path, \
|
35 |
STEMS=['mixture'], \
|
36 |
EFFECTS=['eq', 'imager', 'loudness'])
|
37 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
38 |
|
39 |
def load_effects_encoder(self):
|
40 |
effects_encoder = Effects_Encoder(self.args.cfg_enc)
|
@@ -70,13 +77,6 @@ class MasteringStyleTransfer:
|
|
70 |
fit_embedding = torch.nn.Parameter(initial_reference_feature)
|
71 |
optimizer = getattr(torch.optim, ito_config['optimizer'])([fit_embedding], lr=ito_config['learning_rate'])
|
72 |
|
73 |
-
af_loss = AudioFeatureLoss(
|
74 |
-
weights=ito_config['af_weights'],
|
75 |
-
sample_rate=ito_config['sample_rate'],
|
76 |
-
stem_separation=False,
|
77 |
-
use_clap=False
|
78 |
-
)
|
79 |
-
|
80 |
min_loss = float('inf')
|
81 |
min_loss_step = 0
|
82 |
all_results = []
|
@@ -87,7 +87,15 @@ class MasteringStyleTransfer:
|
|
87 |
output_audio = self.mastering_converter(input_tensor, fit_embedding)
|
88 |
current_params = self.mastering_converter.get_last_predicted_params()
|
89 |
|
90 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
91 |
total_loss = sum(losses.values())
|
92 |
|
93 |
if total_loss < min_loss:
|
|
|
34 |
self.fx_normalizer = Audio_Effects_Normalizer(precomputed_feature_path=args.fx_norm_feature_path, \
|
35 |
STEMS=['mixture'], \
|
36 |
EFFECTS=['eq', 'imager', 'loudness'])
|
37 |
+
# Loss functions
|
38 |
+
self.af_loss = AudioFeatureLoss(
|
39 |
+
weights=ito_config['af_weights'],
|
40 |
+
sample_rate=ito_config['sample_rate'],
|
41 |
+
stem_separation=False,
|
42 |
+
use_clap=False
|
43 |
+
)
|
44 |
+
self.clap_loss = CLAPFeatureLoss(distance_fn='cosine')
|
45 |
|
46 |
def load_effects_encoder(self):
|
47 |
effects_encoder = Effects_Encoder(self.args.cfg_enc)
|
|
|
77 |
fit_embedding = torch.nn.Parameter(initial_reference_feature)
|
78 |
optimizer = getattr(torch.optim, ito_config['optimizer'])([fit_embedding], lr=ito_config['learning_rate'])
|
79 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
80 |
min_loss = float('inf')
|
81 |
min_loss_step = 0
|
82 |
all_results = []
|
|
|
87 |
output_audio = self.mastering_converter(input_tensor, fit_embedding)
|
88 |
current_params = self.mastering_converter.get_last_predicted_params()
|
89 |
|
90 |
+
# Compute loss
|
91 |
+
if ito_config['loss_function'] == 'AudioFeatureLoss':
|
92 |
+
losses = self.af_loss(output_audio, reference_tensor)
|
93 |
+
elif ito_config['loss_function'] == 'CLAPFeatureLoss':
|
94 |
+
if ito_config['clap_target_type'] == 'Audio':
|
95 |
+
target = ito_reference_tensor
|
96 |
+
else:
|
97 |
+
target = ito_config['clap_text_prompt']
|
98 |
+
losses = self.clap_loss(est_targets, target, self.args.sample_rate)
|
99 |
total_loss = sum(losses.values())
|
100 |
|
101 |
if total_loss < min_loss:
|
modules/__pycache__/loss.cpython-311.pyc
CHANGED
Binary files a/modules/__pycache__/loss.cpython-311.pyc and b/modules/__pycache__/loss.cpython-311.pyc differ
|
|
modules/loss.py
CHANGED
@@ -196,36 +196,50 @@ class CLAPFeatureLoss(nn.Module):
|
|
196 |
else:
|
197 |
raise ValueError(f"Unsupported distance function: {distance_fn}")
|
198 |
|
199 |
-
def forward(self, input_audio,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
200 |
# Ensure input is in the correct shape (N, C, T)
|
201 |
-
if
|
202 |
-
|
203 |
-
if target_audio.dim() == 2:
|
204 |
-
target_audio = target_audio.unsqueeze(1)
|
205 |
|
206 |
# Convert to mono if stereo
|
207 |
-
if
|
208 |
-
|
209 |
-
if target_audio.shape[1] > 1:
|
210 |
-
target_audio = target_audio.mean(dim=1, keepdim=True)
|
211 |
|
212 |
# Resample if necessary
|
213 |
if sample_rate != self.target_sample_rate:
|
214 |
-
|
215 |
-
target_audio = self.resample(target_audio, sample_rate)
|
216 |
|
217 |
# Quantize audio data
|
218 |
-
|
219 |
-
target_audio = self.quantize(target_audio)
|
220 |
|
221 |
# Get CLAP embeddings
|
222 |
-
|
223 |
-
|
224 |
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
-
|
|
|
|
|
|
|
229 |
|
230 |
def quantize(self, audio):
|
231 |
audio = audio.squeeze(1) # Remove channel dimension
|
@@ -490,4 +504,19 @@ class AudioFeatureLoss(torch.nn.Module):
|
|
490 |
val = torch.nn.functional.mse_loss(input_transform, target_transform)
|
491 |
losses[key] = weight * val * self.source_weights[stem_idx]
|
492 |
|
493 |
-
return losses
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
196 |
else:
|
197 |
raise ValueError(f"Unsupported distance function: {distance_fn}")
|
198 |
|
199 |
+
def forward(self, input_audio, target, sample_rate):
|
200 |
+
# Process input audio
|
201 |
+
input_embed = self.process_audio(input_audio, sample_rate)
|
202 |
+
|
203 |
+
# Process target (audio or text)
|
204 |
+
if isinstance(target, torch.Tensor):
|
205 |
+
target_embed = self.process_audio(target, sample_rate)
|
206 |
+
elif isinstance(target, str) or (isinstance(target, list) and isinstance(target[0], str)):
|
207 |
+
target_embed = self.process_text(target)
|
208 |
+
else:
|
209 |
+
raise ValueError("Target must be either audio tensor or text (string or list of strings)")
|
210 |
+
|
211 |
+
# Compute loss using the specified distance function
|
212 |
+
loss = self.compute_distance(input_embed, target_embed)
|
213 |
+
|
214 |
+
return loss
|
215 |
+
|
216 |
+
def process_audio(self, audio, sample_rate):
|
217 |
# Ensure input is in the correct shape (N, C, T)
|
218 |
+
if audio.dim() == 2:
|
219 |
+
audio = audio.unsqueeze(1)
|
|
|
|
|
220 |
|
221 |
# Convert to mono if stereo
|
222 |
+
if audio.shape[1] > 1:
|
223 |
+
audio = audio.mean(dim=1, keepdim=True)
|
|
|
|
|
224 |
|
225 |
# Resample if necessary
|
226 |
if sample_rate != self.target_sample_rate:
|
227 |
+
audio = self.resample(audio, sample_rate)
|
|
|
228 |
|
229 |
# Quantize audio data
|
230 |
+
audio = self.quantize(audio)
|
|
|
231 |
|
232 |
# Get CLAP embeddings
|
233 |
+
embed = self.model.get_audio_embedding_from_data(x=audio, use_tensor=True)
|
234 |
+
return embed
|
235 |
|
236 |
+
def process_text(self, text):
|
237 |
+
# Get CLAP embeddings for text
|
238 |
+
# ensure input is a list of strings
|
239 |
+
if not isinstance(text, list):
|
240 |
+
text = [text]
|
241 |
+
embed = self.model.get_text_embedding(text, use_tensor=True)
|
242 |
+
return embed
|
243 |
|
244 |
def quantize(self, audio):
|
245 |
audio = audio.squeeze(1) # Remove channel dimension
|
|
|
504 |
val = torch.nn.functional.mse_loss(input_transform, target_transform)
|
505 |
losses[key] = weight * val * self.source_weights[stem_idx]
|
506 |
|
507 |
+
return losses
|
508 |
+
|
509 |
+
|
510 |
+
|
511 |
+
if __name__ == "__main__":
|
512 |
+
clap_loss = CLAPFeatureLoss(distance_fn='cosine')
|
513 |
+
|
514 |
+
input_audio = torch.randn(1, 2, 44100)
|
515 |
+
target_audio = torch.randn(1, 2, 44100)
|
516 |
+
target_text = "This is a test"
|
517 |
+
sample_rate = 44100
|
518 |
+
loss = clap_loss(input_audio, target_audio, sample_rate)
|
519 |
+
print(loss)
|
520 |
+
loss = clap_loss(input_audio, target_text, sample_rate)
|
521 |
+
print(loss)
|
522 |
+
|