Update README.md
README.md
````diff
@@ -5,7 +5,6 @@ base_model:
 
 # Sample Inference Script
 ```py
-import re
 from argparse import ArgumentParser
 
 import torch
@@ -35,8 +34,8 @@ parser.add_argument("-a", "--audio", default="")
 parser.add_argument("-t", "--transcript", default="")
 parser.add_argument("-o", "--output", default="output.wav")
 parser.add_argument("-d", "--debug", action="store_true")
-parser.add_argument("--max_seq_len", type=int, default=2048)
 parser.add_argument("--sample_rate", type=int, default=16000)
+parser.add_argument("--max_seq_len", type=int, default=2048)
 parser.add_argument("--temperature", type=float, default=0.8)
 parser.add_argument("--top_p", type=float, default=1.0)
 args = parser.parse_args()
@@ -105,24 +104,18 @@ with Timer() as timer:
     input = template.render(messages=messages, eos_token="")
     input_ids = tokenizer.encode(input, add_bos=True, encode_special_tokens=True)
 
-    if args.debug:
-        print(input)
-
 print(f"Encoded input in {timer.interval:.2f} seconds.")
 
 with Timer() as timer:
-    max_new_tokens = config.max_seq_len - input_ids.shape[-1]
     gen_settings = ExLlamaV2Sampler.Settings()
     gen_settings.temperature = args.temperature
     gen_settings.top_p = args.top_p
-    stop_conditions = ["<|SPEECH_GENERATION_END|>"]
 
     job = ExLlamaV2DynamicJob(
         input_ids=input_ids,
-        max_new_tokens=max_new_tokens,
+        max_new_tokens=config.max_seq_len - input_ids.shape[-1],
         gen_settings=gen_settings,
-        stop_conditions=stop_conditions,
-        decode_special_tokens=True,
+        stop_conditions=["<|SPEECH_GENERATION_END|>"],
     )
 
     generator.enqueue(job)
@@ -131,11 +124,13 @@ with Timer() as timer:
     while generator.num_remaining_jobs():
         for result in generator.iterate():
             if result.get("stage") == "streaming":
-                text = result.get("text", "")
-                output.append(text)
+                text = result.get("text")
 
-                if args.debug:
-                    print(text, end="", flush=True)
+                if text:
+                    output.append(text)
+
+                    if args.debug:
+                        print(text, end="", flush=True)
 
             if result.get("eos"):
                 generator.clear_queue()
@@ -146,8 +141,7 @@ with Timer() as timer:
 print(f"Generated {len(output)} tokens in {timer.interval:.2f} seconds.")
 
 with Timer() as timer:
-    output = "".join(output)
-    output = [int(o) for o in re.findall(r"<\|s_(\d+)\|>", output)]
+    output = [int(o[4:-2]) for o in output]
     output = torch.tensor([[output]]).cuda()
     output = vocoder.decode_code(output)
     output = output[0, 0, :]
````
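The last hunk trades the `re`-based parse for plain string slicing, which is why `import re` can go too. A minimal sketch of why the two are equivalent, assuming each streamed chunk is exactly one speech token of the form `<|s_N|>`:

```py
import re

# Hypothetical chunks, standing in for the `output` list collected while streaming.
chunks = ["<|s_12|>", "<|s_345|>", "<|s_6789|>"]

# New version: slice off the 4-char prefix "<|s_" and the 2-char suffix "|>".
sliced = [int(o[4:-2]) for o in chunks]

# Old version: join everything and pull the ids back out with a regex.
matched = [int(o) for o in re.findall(r"<\|s_(\d+)\|>", "".join(chunks))]

assert sliced == matched == [12, 345, 6789]
```

The slicing version only holds if every chunk really is a single well-formed token; the `if text:` guard added to the streaming loop is what keeps empty chunks out of the list.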
|
|
5 |
|
6 |
# Sample Inference Script
|
7 |
```py
|
|
|
8 |
from argparse import ArgumentParser
|
9 |
|
10 |
import torch
|
|
|
34 |
parser.add_argument("-t", "--transcript", default="")
|
35 |
parser.add_argument("-o", "--output", default="output.wav")
|
36 |
parser.add_argument("-d", "--debug", action="store_true")
|
|
|
37 |
parser.add_argument("--sample_rate", type=int, default=16000)
|
38 |
+
parser.add_argument("--max_seq_len", type=int, default=2048)
|
39 |
parser.add_argument("--temperature", type=float, default=0.8)
|
40 |
parser.add_argument("--top_p", type=float, default=1.0)
|
41 |
args = parser.parse_args()
|
|
|
104 |
input = template.render(messages=messages, eos_token="")
|
105 |
input_ids = tokenizer.encode(input, add_bos=True, encode_special_tokens=True)
|
106 |
|
|
|
|
|
|
|
107 |
print(f"Encoded input in {timer.interval:.2f} seconds.")
|
108 |
|
109 |
with Timer() as timer:
|
|
|
110 |
gen_settings = ExLlamaV2Sampler.Settings()
|
111 |
gen_settings.temperature = args.temperature
|
112 |
gen_settings.top_p = args.top_p
|
|
|
113 |
|
114 |
job = ExLlamaV2DynamicJob(
|
115 |
input_ids=input_ids,
|
116 |
+
max_new_tokens=config.max_seq_len - input_ids.shape[-1],
|
117 |
gen_settings=gen_settings,
|
118 |
+
stop_conditions=["<|SPEECH_GENERATION_END|>"],
|
|
|
119 |
)
|
120 |
|
121 |
generator.enqueue(job)
|
|
|
124 |
while generator.num_remaining_jobs():
|
125 |
for result in generator.iterate():
|
126 |
if result.get("stage") == "streaming":
|
127 |
+
text = result.get("text")
|
|
|
128 |
|
129 |
+
if text:
|
130 |
+
output.append(text)
|
131 |
+
|
132 |
+
if args.debug:
|
133 |
+
print(text, end="", flush=True)
|
134 |
|
135 |
if result.get("eos"):
|
136 |
generator.clear_queue()
|
|
|
141 |
print(f"Generated {len(output)} tokens in {timer.interval:.2f} seconds.")
|
142 |
|
143 |
with Timer() as timer:
|
144 |
+
output = [int(o[4:-2]) for o in output]
|
|
|
145 |
output = torch.tensor([[output]]).cuda()
|
146 |
output = vocoder.decode_code(output)
|
147 |
output = output[0, 0, :]
|
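Inlining `max_new_tokens` also makes the budget arithmetic easy to check in isolation. A worked example, assuming the script wires `--max_seq_len` into `config.max_seq_len` (that wiring is not shown in these hunks) and a hypothetical prompt length:

```py
max_seq_len = 2048   # default of --max_seq_len, assumed to back config.max_seq_len
prompt_len = 300     # hypothetical input_ids.shape[-1]

# The job may generate at most the part of the context window the prompt leaves free.
max_new_tokens = max_seq_len - prompt_len
assert max_new_tokens == 1748
```

In practice generation usually ends earlier, at the `<|SPEECH_GENERATION_END|>` stop condition.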