Annuvin committed on
Commit
fdc7ab6
·
verified ·
1 Parent(s): 3c1e441

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +10 -16
README.md CHANGED
@@ -5,7 +5,6 @@ base_model:
5
 
6
  # Sample Inference Script
7
  ```py
8
- import re
9
  from argparse import ArgumentParser
10
 
11
  import torch
@@ -35,8 +34,8 @@ parser.add_argument("-a", "--audio", default="")
35
  parser.add_argument("-t", "--transcript", default="")
36
  parser.add_argument("-o", "--output", default="output.wav")
37
  parser.add_argument("-d", "--debug", action="store_true")
38
- parser.add_argument("--max_seq_len", type=int, default=2048)
39
  parser.add_argument("--sample_rate", type=int, default=16000)
 
40
  parser.add_argument("--temperature", type=float, default=0.8)
41
  parser.add_argument("--top_p", type=float, default=1.0)
42
  args = parser.parse_args()
@@ -105,24 +104,18 @@ with Timer() as timer:
105
  input = template.render(messages=messages, eos_token="")
106
  input_ids = tokenizer.encode(input, add_bos=True, encode_special_tokens=True)
107
 
108
- if args.debug:
109
- print(input)
110
-
111
  print(f"Encoded input in {timer.interval:.2f} seconds.")
112
 
113
  with Timer() as timer:
114
- max_new_tokens = config.max_seq_len - input_ids.shape[-1]
115
  gen_settings = ExLlamaV2Sampler.Settings()
116
  gen_settings.temperature = args.temperature
117
  gen_settings.top_p = args.top_p
118
- stop_conditions = ["<|SPEECH_GENERATION_END|>"]
119
 
120
  job = ExLlamaV2DynamicJob(
121
  input_ids=input_ids,
122
- max_new_tokens=max_new_tokens,
123
  gen_settings=gen_settings,
124
- stop_conditions=stop_conditions,
125
- decode_special_tokens=True,
126
  )
127
 
128
  generator.enqueue(job)
@@ -131,11 +124,13 @@ with Timer() as timer:
131
  while generator.num_remaining_jobs():
132
  for result in generator.iterate():
133
  if result.get("stage") == "streaming":
134
- text = result.get("text", "")
135
- output.append(text)
136
 
137
- if args.debug:
138
- print(text, end="", flush=True)
 
 
 
139
 
140
  if result.get("eos"):
141
  generator.clear_queue()
@@ -146,8 +141,7 @@ with Timer() as timer:
146
  print(f"Generated {len(output)} tokens in {timer.interval:.2f} seconds.")
147
 
148
  with Timer() as timer:
149
- output = "".join(output)
150
- output = [int(o) for o in re.findall(r"<\|s_(\d+)\|>", output)]
151
  output = torch.tensor([[output]]).cuda()
152
  output = vocoder.decode_code(output)
153
  output = output[0, 0, :]
 
5
 
6
  # Sample Inference Script
7
  ```py
 
8
  from argparse import ArgumentParser
9
 
10
  import torch
 
34
  parser.add_argument("-t", "--transcript", default="")
35
  parser.add_argument("-o", "--output", default="output.wav")
36
  parser.add_argument("-d", "--debug", action="store_true")
 
37
  parser.add_argument("--sample_rate", type=int, default=16000)
38
+ parser.add_argument("--max_seq_len", type=int, default=2048)
39
  parser.add_argument("--temperature", type=float, default=0.8)
40
  parser.add_argument("--top_p", type=float, default=1.0)
41
  args = parser.parse_args()
 
104
  input = template.render(messages=messages, eos_token="")
105
  input_ids = tokenizer.encode(input, add_bos=True, encode_special_tokens=True)
106
 
 
 
 
107
  print(f"Encoded input in {timer.interval:.2f} seconds.")
108
 
109
  with Timer() as timer:
 
110
  gen_settings = ExLlamaV2Sampler.Settings()
111
  gen_settings.temperature = args.temperature
112
  gen_settings.top_p = args.top_p
 
113
 
114
  job = ExLlamaV2DynamicJob(
115
  input_ids=input_ids,
116
+ max_new_tokens=config.max_seq_len - input_ids.shape[-1],
117
  gen_settings=gen_settings,
118
+ stop_conditions=["<|SPEECH_GENERATION_END|>"],
 
119
  )
120
 
121
  generator.enqueue(job)
 
124
  while generator.num_remaining_jobs():
125
  for result in generator.iterate():
126
  if result.get("stage") == "streaming":
127
+ text = result.get("text")
 
128
 
129
+ if text:
130
+ output.append(text)
131
+
132
+ if args.debug:
133
+ print(text, end="", flush=True)
134
 
135
  if result.get("eos"):
136
  generator.clear_queue()
 
141
  print(f"Generated {len(output)} tokens in {timer.interval:.2f} seconds.")
142
 
143
  with Timer() as timer:
144
+ output = [int(o[4:-2]) for o in output]
 
145
  output = torch.tensor([[output]]).cuda()
146
  output = vocoder.decode_code(output)
147
  output = output[0, 0, :]