Update model.py
model.py CHANGED
@@ -57,69 +57,54 @@ def load_model():
 
 @torch.no_grad()
 def invert_audio(
-        model, processor,
+        model, processor, input_audio, sampling_rate,
         normalize=True, flip_input=True, flip_output=False):
 
     model.config.normalize = normalize
 
-
-    if
-
+    # Check and resample the input audio if necessary
+    if sampling_rate != MODEL_SAMPLING_RATE:
+        input_audio = julius.resample_frac(input_audio, sampling_rate, MODEL_SAMPLING_RATE)
 
-    #
+    # Flip the audio if required
     if flip_input:
-
-
-    #
-    inputs_1 = processor(raw_audio=
+        input_audio = torch.flip(input_audio, dims=(1,))
+
+    # Pre-process the inputs
+    inputs_1 = processor(raw_audio=input_audio, sampling_rate=MODEL_SAMPLING_RATE, return_tensors="pt")
     inputs_1["input_values"] = inputs_1["input_values"].to("cuda:0")
     inputs_1["padding_mask"] = inputs_1["padding_mask"].to("cuda:0")
 
-    #
+    # Explicitly encode then decode the audio inputs
     print("Encoding...")
     encoder_outputs_1 = model.encode(
         inputs_1["input_values"],
         inputs_1["padding_mask"],
         bandwidth=max(model.config.target_bandwidths))
 
-    # EMBEDDINGS (no quantized):
-    # encoder_outputs.audio_codes.shape
-    # [216, 1, 128, 150]
-
     avg = torch.mean(encoder_outputs_1.audio_codes, (0, 3), True)
-    # [1, 1, 128, 1]
     avg_repeat = avg.repeat(
         encoder_outputs_1.audio_codes.shape[0],
         encoder_outputs_1.audio_codes.shape[1],
         1,
         encoder_outputs_1.audio_codes.shape[3])
-    # [216, 1, 128, 150]
     diff_repeat = encoder_outputs_1.audio_codes - avg_repeat
-
-    # TODO: power factor calculations kinda useless if we keep the factor one???
+
     POWER_FACTOR = 1
     max_abs_diff = torch.max(torch.abs(diff_repeat))
     diff_abs_power = ((torch.abs(diff_repeat) / max_abs_diff) ** POWER_FACTOR) * max_abs_diff
     latents = (diff_repeat >= 0) * diff_abs_power - (diff_repeat < 0) * diff_abs_power
 
-    #
+    # Inversion of difference
     latents = latents * -1.0
 
     print("Decoding...")
     audio_values = model.decode(latents, encoder_outputs_1.audio_scales, inputs_1["padding_mask"])[0]
 
-    # [1, 2, 10264800]
     if flip_output:
         audio_values = torch.flip(audio_values, dims=(2,))
 
-
+    # Return the decoded audio tensor (or NumPy array, based on your audio_write function)
     decoded_wav = audio_values.squeeze(0).to("cpu")
 
-
-    out_path_ = audio_write(
-        out_path,
-        sample_rate=MODEL_SAMPLING_RATE,
-        wav=decoded_wav,
-        normalize=False)
-
-    return out_path_
+    return decoded_wav
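With this change, invert_audio takes the raw audio and its sample rate as arguments and returns the decoded tensor instead of writing a file, so the audio_write call moves to the call site. A minimal caller-side sketch, under assumptions not in the diff: model and processor come from load_model(), the input is loaded with torchaudio, and the "input.wav" path and "inverted" output stem are purely illustrative. The audio_write keyword arguments mirror the call the old version made internally.

import torchaudio

# Hypothetical usage sketch: load audio, invert it, then write the result
# at the call site (the old version wrote the file inside invert_audio).
input_audio, sampling_rate = torchaudio.load("input.wav")  # illustrative path

decoded_wav = invert_audio(
    model, processor, input_audio, sampling_rate,
    normalize=True, flip_input=True, flip_output=False)

# Same audio_write call the previous version performed internally,
# now issued by the caller on the returned tensor.
audio_write(
    "inverted",  # illustrative output stem
    wav=decoded_wav,
    sample_rate=MODEL_SAMPLING_RATE,
    normalize=False)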