cmagganas committed on
Commit
9017fab
1 Parent(s): 4ecaedb

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +14 -29
model.py CHANGED
@@ -57,69 +57,54 @@ def load_model():
57
 
58
@torch.no_grad()
def invert_audio(
        model, processor, input_audio_path, out_path,
        normalize=True, flip_input=True, flip_output=False):
    """Read an audio file, invert its latent deviation from the mean code,
    and write the decoded result to disk.

    Args:
        model: Encodec-style model exposing ``encode()``/``decode()`` and a
            ``config`` with ``normalize`` and ``target_bandwidths``.
        processor: feature extractor producing ``input_values`` and
            ``padding_mask`` tensors from raw audio.
        input_audio_path: path of the audio file to read.
        out_path: output path stem passed to ``audio_write``.
        normalize: forwarded to ``model.config.normalize``.
        flip_input: reverse the waveform in time before encoding.
        flip_output: reverse the decoded waveform in time before saving.

    Returns:
        The path returned by ``audio_write`` for the saved file.
    """
    model.config.normalize = normalize

    # Load and, if needed, resample to the model's expected rate.
    audio_sample_1, sampling_rate_1 = audio_read(input_audio_path)
    if sampling_rate_1 != MODEL_SAMPLING_RATE:
        audio_sample_1 = julius.resample_frac(audio_sample_1, sampling_rate_1, MODEL_SAMPLING_RATE)

    # Optionally reverse in time; assumes (channels, samples) layout — dim 1
    # is the sample axis (TODO confirm against audio_read).
    if flip_input:
        audio_sample_1 = torch.flip(audio_sample_1, dims=(1,))

    # Pre-process the inputs and move them to the model's own device rather
    # than a hard-coded "cuda:0", so CPU-only hosts also work.
    inputs_1 = processor(raw_audio=audio_sample_1, sampling_rate=MODEL_SAMPLING_RATE, return_tensors="pt")
    device = next(model.parameters()).device
    inputs_1["input_values"] = inputs_1["input_values"].to(device)
    inputs_1["padding_mask"] = inputs_1["padding_mask"].to(device)

    # Explicitly encode then decode the audio inputs.
    print("Encoding...")
    encoder_outputs_1 = model.encode(
        inputs_1["input_values"],
        inputs_1["padding_mask"],
        bandwidth=max(model.config.target_bandwidths))

    # Mean code over the chunk (0) and time (3) axes, kept broadcastable;
    # broadcasting the subtraction is exactly equivalent to the explicit
    # .repeat() it replaces, without materializing the repeated tensor.
    avg = torch.mean(encoder_outputs_1.audio_codes, (0, 3), True)
    diff_repeat = encoder_outputs_1.audio_codes - avg

    # Re-shape the magnitude of the deviation by POWER_FACTOR while keeping
    # its sign.  With POWER_FACTOR == 1 this is the identity mapping.
    POWER_FACTOR = 1
    max_abs_diff = torch.max(torch.abs(diff_repeat))
    if max_abs_diff > 0:
        diff_abs_power = ((torch.abs(diff_repeat) / max_abs_diff) ** POWER_FACTOR) * max_abs_diff
    else:
        # Guard: an all-constant code tensor would otherwise produce 0/0 -> NaN.
        diff_abs_power = torch.zeros_like(diff_repeat)
    latents = (diff_repeat >= 0) * diff_abs_power - (diff_repeat < 0) * diff_abs_power

    # Difference inversion done here!
    latents = latents * -1.0

    print("Decoding...")
    audio_values = model.decode(latents, encoder_outputs_1.audio_scales, inputs_1["padding_mask"])[0]

    # Optionally reverse the decoded waveform; dim 2 is the sample axis of
    # the (batch, channels, samples) decode output.
    if flip_output:
        audio_values = torch.flip(audio_values, dims=(2,))

    decoded_wav = audio_values.squeeze(0).to("cpu")

    print("Saving output file...")
    out_path_ = audio_write(
        out_path,
        sample_rate=MODEL_SAMPLING_RATE,
        wav=decoded_wav,
        normalize=False)

    return out_path_
 
57
 
58
@torch.no_grad()
def invert_audio(
        model, processor, input_audio, sampling_rate,
        normalize=True, flip_input=True, flip_output=False):
    """Encode audio, invert its latent deviation from the mean code, and
    decode back to a waveform tensor.

    Args:
        model: Encodec-style model exposing ``encode()``/``decode()`` and a
            ``config`` with ``normalize`` and ``target_bandwidths``.
        processor: feature extractor producing ``input_values`` and
            ``padding_mask`` tensors from raw audio.
        input_audio: waveform tensor, assumed (channels, samples) — TODO
            confirm against the caller.
        sampling_rate: sample rate of ``input_audio`` in Hz.
        normalize: forwarded to ``model.config.normalize``.
        flip_input: reverse the waveform in time before encoding.
        flip_output: reverse the decoded waveform in time before returning.

    Returns:
        The decoded waveform on CPU with the batch dimension squeezed out.
    """
    model.config.normalize = normalize

    # Check and resample the input audio if necessary
    if sampling_rate != MODEL_SAMPLING_RATE:
        input_audio = julius.resample_frac(input_audio, sampling_rate, MODEL_SAMPLING_RATE)

    # Flip the audio in time if required (dim 1 = sample axis).
    if flip_input:
        input_audio = torch.flip(input_audio, dims=(1,))

    # Pre-process the inputs and move them to the model's own device rather
    # than a hard-coded "cuda:0", so CPU-only hosts also work.
    inputs_1 = processor(raw_audio=input_audio, sampling_rate=MODEL_SAMPLING_RATE, return_tensors="pt")
    device = next(model.parameters()).device
    inputs_1["input_values"] = inputs_1["input_values"].to(device)
    inputs_1["padding_mask"] = inputs_1["padding_mask"].to(device)

    # Explicitly encode then decode the audio inputs
    print("Encoding...")
    encoder_outputs_1 = model.encode(
        inputs_1["input_values"],
        inputs_1["padding_mask"],
        bandwidth=max(model.config.target_bandwidths))

    # Mean code over the chunk (0) and time (3) axes, kept broadcastable;
    # broadcasting the subtraction is exactly equivalent to the explicit
    # .repeat() it replaces, without materializing the repeated tensor.
    avg = torch.mean(encoder_outputs_1.audio_codes, (0, 3), True)
    diff_repeat = encoder_outputs_1.audio_codes - avg

    # Re-shape the magnitude of the deviation by POWER_FACTOR while keeping
    # its sign.  With POWER_FACTOR == 1 this is the identity mapping.
    POWER_FACTOR = 1
    max_abs_diff = torch.max(torch.abs(diff_repeat))
    if max_abs_diff > 0:
        diff_abs_power = ((torch.abs(diff_repeat) / max_abs_diff) ** POWER_FACTOR) * max_abs_diff
    else:
        # Guard: an all-constant code tensor would otherwise produce 0/0 -> NaN.
        diff_abs_power = torch.zeros_like(diff_repeat)
    latents = (diff_repeat >= 0) * diff_abs_power - (diff_repeat < 0) * diff_abs_power

    # Inversion of difference
    latents = latents * -1.0

    print("Decoding...")
    audio_values = model.decode(latents, encoder_outputs_1.audio_scales, inputs_1["padding_mask"])[0]

    # Optionally reverse the decoded waveform; dim 2 is the sample axis of
    # the (batch, channels, samples) decode output.
    if flip_output:
        audio_values = torch.flip(audio_values, dims=(2,))

    # Return the decoded audio tensor (or NumPy array, based on your audio_write function)
    decoded_wav = audio_values.squeeze(0).to("cpu")

    return decoded_wav