gpt-omni commited on
Commit
008f4a1
·
1 Parent(s): e6f5492
Files changed (3) hide show
  1. inference.py +1 -1
  2. snac_utils.py +0 -143
  3. utils/snac_utils.py +11 -8
inference.py CHANGED
@@ -494,7 +494,7 @@ class OmniInference:
494
  if current_index == nums_generate:
495
  current_index = 0
496
  snac = get_snac(list_output, index, nums_generate)
497
- audio_stream = generate_audio_data(snac, self.snacmodel)
498
  yield audio_stream
499
 
500
  input_pos = input_pos.add_(1)
 
494
  if current_index == nums_generate:
495
  current_index = 0
496
  snac = get_snac(list_output, index, nums_generate)
497
+ audio_stream = generate_audio_data(snac, self.snacmodel, self.device)
498
  yield audio_stream
499
 
500
  input_pos = input_pos.add_(1)
snac_utils.py DELETED
@@ -1,143 +0,0 @@
1
- import torch
2
- import time
3
- import numpy as np
4
-
5
-
6
- class SnacConfig:
7
- audio_vocab_size = 4096
8
- padded_vocab_size = 4160
9
- end_of_audio = 4097
10
-
11
-
12
- snac_config = SnacConfig()
13
-
14
-
15
- def get_time_str():
16
- time_str = time.strftime("%Y%m%d_%H%M%S", time.localtime())
17
- return time_str
18
-
19
-
20
- def layershift(input_id, layer, stride=4160, shift=152000):
21
- return input_id + shift + layer * stride
22
-
23
-
24
- def generate_audio_data(snac_tokens, snacmodel):
25
- audio = reconstruct_tensors(snac_tokens)
26
- with torch.inference_mode():
27
- audio_hat = snacmodel.decode(audio)
28
- audio_data = audio_hat.cpu().numpy().astype(np.float64) * 32768.0
29
- audio_data = audio_data.astype(np.int16)
30
- audio_data = audio_data.tobytes()
31
- return audio_data
32
-
33
-
34
- def get_snac(list_output, index, nums_generate):
35
-
36
- snac = []
37
- start = index
38
- for i in range(nums_generate):
39
- snac.append("#")
40
- for j in range(7):
41
- snac.append(list_output[j][start - nums_generate - 5 + j + i])
42
- return snac
43
-
44
-
45
- def reconscruct_snac(output_list):
46
- if len(output_list) == 8:
47
- output_list = output_list[:-1]
48
- output = []
49
- for i in range(7):
50
- output_list[i] = output_list[i][i + 1 :]
51
- for i in range(len(output_list[-1])):
52
- output.append("#")
53
- for j in range(7):
54
- output.append(output_list[j][i])
55
- return output
56
-
57
-
58
- def reconstruct_tensors(flattened_output):
59
- """Reconstructs the list of tensors from the flattened output."""
60
-
61
- def count_elements_between_hashes(lst):
62
- try:
63
- # Find the index of the first '#'
64
- first_index = lst.index("#")
65
- # Find the index of the second '#' after the first
66
- second_index = lst.index("#", first_index + 1)
67
- # Count the elements between the two indices
68
- return second_index - first_index - 1
69
- except ValueError:
70
- # Handle the case where there aren't enough '#' symbols
71
- return "List does not contain two '#' symbols"
72
-
73
- def remove_elements_before_hash(flattened_list):
74
- try:
75
- # Find the index of the first '#'
76
- first_hash_index = flattened_list.index("#")
77
- # Return the list starting from the first '#'
78
- return flattened_list[first_hash_index:]
79
- except ValueError:
80
- # Handle the case where there is no '#'
81
- return "List does not contain the symbol '#'"
82
-
83
- def list_to_torch_tensor(tensor1):
84
- # Convert the list to a torch tensor
85
- tensor = torch.tensor(tensor1)
86
- # Reshape the tensor to have size (1, n)
87
- tensor = tensor.unsqueeze(0)
88
- return tensor
89
-
90
- flattened_output = remove_elements_before_hash(flattened_output)
91
- codes = []
92
- tensor1 = []
93
- tensor2 = []
94
- tensor3 = []
95
- tensor4 = []
96
-
97
- n_tensors = count_elements_between_hashes(flattened_output)
98
- if n_tensors == 7:
99
- for i in range(0, len(flattened_output), 8):
100
-
101
- tensor1.append(flattened_output[i + 1])
102
- tensor2.append(flattened_output[i + 2])
103
- tensor3.append(flattened_output[i + 3])
104
- tensor3.append(flattened_output[i + 4])
105
-
106
- tensor2.append(flattened_output[i + 5])
107
- tensor3.append(flattened_output[i + 6])
108
- tensor3.append(flattened_output[i + 7])
109
- codes = [
110
- list_to_torch_tensor(tensor1).cuda(),
111
- list_to_torch_tensor(tensor2).cuda(),
112
- list_to_torch_tensor(tensor3).cuda(),
113
- ]
114
-
115
- if n_tensors == 15:
116
- for i in range(0, len(flattened_output), 16):
117
-
118
- tensor1.append(flattened_output[i + 1])
119
- tensor2.append(flattened_output[i + 2])
120
- tensor3.append(flattened_output[i + 3])
121
- tensor4.append(flattened_output[i + 4])
122
- tensor4.append(flattened_output[i + 5])
123
- tensor3.append(flattened_output[i + 6])
124
- tensor4.append(flattened_output[i + 7])
125
- tensor4.append(flattened_output[i + 8])
126
-
127
- tensor2.append(flattened_output[i + 9])
128
- tensor3.append(flattened_output[i + 10])
129
- tensor4.append(flattened_output[i + 11])
130
- tensor4.append(flattened_output[i + 12])
131
- tensor3.append(flattened_output[i + 13])
132
- tensor4.append(flattened_output[i + 14])
133
- tensor4.append(flattened_output[i + 15])
134
-
135
- codes = [
136
- list_to_torch_tensor(tensor1).cuda(),
137
- list_to_torch_tensor(tensor2).cuda(),
138
- list_to_torch_tensor(tensor3).cuda(),
139
- list_to_torch_tensor(tensor4).cuda(),
140
- ]
141
-
142
- return codes
143
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
utils/snac_utils.py CHANGED
@@ -55,9 +55,12 @@ def reconscruct_snac(output_list):
55
  return output
56
 
57
 
58
- def reconstruct_tensors(flattened_output):
59
  """Reconstructs the list of tensors from the flattened output."""
60
 
 
 
 
61
  def count_elements_between_hashes(lst):
62
  try:
63
  # Find the index of the first '#'
@@ -107,9 +110,9 @@ def reconstruct_tensors(flattened_output):
107
  tensor3.append(flattened_output[i + 6])
108
  tensor3.append(flattened_output[i + 7])
109
  codes = [
110
- list_to_torch_tensor(tensor1).cuda(),
111
- list_to_torch_tensor(tensor2).cuda(),
112
- list_to_torch_tensor(tensor3).cuda(),
113
  ]
114
 
115
  if n_tensors == 15:
@@ -133,10 +136,10 @@ def reconstruct_tensors(flattened_output):
133
  tensor4.append(flattened_output[i + 15])
134
 
135
  codes = [
136
- list_to_torch_tensor(tensor1).cuda(),
137
- list_to_torch_tensor(tensor2).cuda(),
138
- list_to_torch_tensor(tensor3).cuda(),
139
- list_to_torch_tensor(tensor4).cuda(),
140
  ]
141
 
142
  return codes
 
55
  return output
56
 
57
 
58
+ def reconstruct_tensors(flattened_output, device=None):
59
  """Reconstructs the list of tensors from the flattened output."""
60
 
61
+ if device is None:
62
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
63
+
64
  def count_elements_between_hashes(lst):
65
  try:
66
  # Find the index of the first '#'
 
110
  tensor3.append(flattened_output[i + 6])
111
  tensor3.append(flattened_output[i + 7])
112
  codes = [
113
+ list_to_torch_tensor(tensor1).to(device),
114
+ list_to_torch_tensor(tensor2).to(device),
115
+ list_to_torch_tensor(tensor3).to(device),
116
  ]
117
 
118
  if n_tensors == 15:
 
136
  tensor4.append(flattened_output[i + 15])
137
 
138
  codes = [
139
+ list_to_torch_tensor(tensor1).to(device),
140
+ list_to_torch_tensor(tensor2).to(device),
141
+ list_to_torch_tensor(tensor3).to(device),
142
+ list_to_torch_tensor(tensor4).to(device),
143
  ]
144
 
145
  return codes