Manmay committed on
Commit
840a99d
1 Parent(s): d08b1d4

added streaming for faster inference with hifidecoder

Browse files
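A minimal sketch of how the new streaming path is consumed, mirroring the app.py changes below. The TextToSpeech constructor arguments and the helper function name are assumptions for illustration, not part of this commit; what the diff does establish is that tts_with_preset is now a generator yielding 24 kHz waveform chunks as the HiFi-GAN head decodes GPT latents.

from tortoise.api import TextToSpeech

# Assumed setup; kv_cache=True is an illustrative choice.
tts = TextToSpeech(kv_cache=True)

def stream_speech(text, voice_samples=None, conditioning_latents=None):
    # Each iteration yields one decoded waveform chunk at 24 kHz,
    # ready to feed a streaming gr.Audio output.
    for audio_frame in tts.tts_with_preset(
        text,
        voice_samples=voice_samples,
        conditioning_latents=conditioning_latents,
        preset="ultra_fast",
        k=1,
    ):
        yield (24000, audio_frame.cpu().detach().numpy())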
app.py CHANGED
@@ -73,31 +73,23 @@ def inference(
73
 
74
  start_time = time.time()
75
 
76
- all_parts = []
77
  for j, text in enumerate(texts):
78
- gen = tts.tts_with_preset(
79
  text,
80
  voice_samples=voice_samples,
81
  conditioning_latents=conditioning_latents,
82
  preset="ultra_fast",
83
  k=1
84
- )
85
-
86
- audio_ = gen.squeeze(0).cpu()
87
- all_parts.append(audio_)
88
-
89
- full_audio = torch.cat(all_parts, dim=-1)
90
-
91
- with open("Tortoise_TTS_Runs_Scripts.log", "a") as f:
92
- f.write(
93
- f"{datetime.now()} | Voice: {','.join(voices)} | Text: {text} | Time Taken (s): {time.time()-start_time} | Seed: {seed}\n"
94
- )
95
-
96
- output_texts = [f"({j+1}) {texts[j]}" for j in range(len(texts))]
97
-
98
- return ((24000, full_audio.squeeze().cpu().numpy()), "\n".join(output_texts))
99
-
100
-
101
  def main():
102
  title = "Tortoise TTS 🐢"
103
  description = """
@@ -130,9 +122,8 @@ def main():
130
  value="No",
131
  )
132
 
133
- output_audio = gr.Audio(label="Combined audio:")
134
- output_text = gr.Textbox(label="Split texts with indices:", lines=10)
135
-
136
  interface = gr.Interface(
137
  fn=inference,
138
  inputs=[
@@ -144,9 +135,9 @@ def main():
144
  ],
145
  title=title,
146
  description=description,
147
- outputs=[output_audio, output_text],
148
  )
149
- interface.launch()
150
 
151
 
152
  if __name__ == "__main__":
 
73
 
74
  start_time = time.time()
75
 
76
+ # all_parts = []
77
  for j, text in enumerate(texts):
78
+ for audio_frame in tts.tts_with_preset(
79
  text,
80
  voice_samples=voice_samples,
81
  conditioning_latents=conditioning_latents,
82
  preset="ultra_fast",
83
  k=1
84
+ ):
85
+ # print("Time taken: ", time.time() - start_time)
86
+ # all_parts.append(audio_frame)
87
+ yield (24000, audio_frame.cpu().detach().numpy())
88
+
89
+ # wav = torch.cat(all_parts, dim=0).unsqueeze(0)
90
+ # print(wav.shape)
91
+ # torchaudio.save("output.wav", wav.cpu(), 24000)
92
+ # yield (None, gr.make_waveform(audio="output.wav",))
93
  def main():
94
  title = "Tortoise TTS 🐢"
95
  description = """
 
122
  value="No",
123
  )
124
 
125
+ output_audio = gr.Audio(label="streaming audio:", streaming=True, autoplay=True)
126
+ # download_audio = gr.Audio(label="download audio:")
 
127
  interface = gr.Interface(
128
  fn=inference,
129
  inputs=[
 
135
  ],
136
  title=title,
137
  description=description,
138
+ outputs=[output_audio],
139
  )
140
+ interface.queue().launch()
141
 
142
 
143
  if __name__ == "__main__":
tortoise/api.py CHANGED
@@ -8,7 +8,7 @@ import torch
8
  import torch.nn.functional as F
9
  import progressbar
10
  import torchaudio
11
-
12
  from tortoise.models.classifier import AudioMiniEncoderWithClassifierHead
13
  from tortoise.models.diffusion_decoder import DiffusionTts
14
  from tortoise.models.autoregressive import UnifiedVoice
@@ -16,6 +16,7 @@ from tqdm import tqdm
16
  from tortoise.models.arch_util import TorchMelSpectrogram
17
  from tortoise.models.clvp import CLVP
18
  from tortoise.models.cvvp import CVVP
 
19
  from tortoise.models.random_latent_generator import RandomLatentConverter
20
  from tortoise.models.vocoder import UnivNetGenerator
21
  from tortoise.utils.audio import wav_to_univnet_mel, denormalize_tacotron_mel
@@ -23,19 +24,18 @@ from tortoise.utils.diffusion import SpacedDiffusion, space_timesteps, get_named
23
  from tortoise.utils.tokenizer import VoiceBpeTokenizer
24
  from tortoise.utils.wav2vec_alignment import Wav2VecAlignment
25
  from contextlib import contextmanager
 
 
26
  pbar = None
27
-
28
  DEFAULT_MODELS_DIR = os.path.join(os.path.expanduser('~'), '.cache', 'tortoise', 'models')
29
  MODELS_DIR = os.environ.get('TORTOISE_MODELS_DIR', DEFAULT_MODELS_DIR)
 
30
  MODELS = {
31
- 'autoregressive.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/autoregressive.pth',
32
- 'classifier.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/classifier.pth',
33
- 'clvp2.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/clvp2.pth',
34
- 'cvvp.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/cvvp.pth',
35
- 'diffusion_decoder.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/diffusion_decoder.pth',
36
- 'vocoder.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/vocoder.pth',
37
- 'rlg_auto.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_auto.pth',
38
- 'rlg_diffuser.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_diffuser.pth',
39
  }
40
 
41
  def download_models(specific_models=None):
@@ -64,6 +64,7 @@ def download_models(specific_models=None):
64
  continue
65
  print(f'Downloading {model_name} from {url}...')
66
  request.urlretrieve(url, model_path, show_progress)
 
67
  print('Done.')
68
 
69
 
@@ -238,7 +239,6 @@ class TextToSpeech:
238
  if os.path.exists(f'{models_dir}/autoregressive.ptt'):
239
  # Assume this is a traced directory.
240
  self.autoregressive = torch.jit.load(f'{models_dir}/autoregressive.ptt')
241
- self.diffusion = torch.jit.load(f'{models_dir}/diffusion_decoder.ptt')
242
  else:
243
  self.autoregressive = UnifiedVoice(max_mel_tokens=604, max_text_tokens=402, max_conditioning_inputs=2, layers=30,
244
  model_dim=1024,
@@ -246,19 +246,15 @@ class TextToSpeech:
246
  train_solo_embeddings=False).cuda().eval()
247
  self.autoregressive.load_state_dict(torch.load(get_model_path('autoregressive.pth', models_dir)), strict=False)
248
  self.autoregressive.post_init_gpt2_config(use_deepspeed=use_deepspeed, kv_cache=kv_cache, half=self.half)
249
-
250
- self.diffusion = DiffusionTts(model_channels=1024, num_layers=10, in_channels=100, out_channels=200,
251
- in_latent_channels=1024, in_tokens=8193, dropout=0, use_fp16=False, num_heads=16,
252
- layer_drop=0, unconditioned_percentage=0).cuda().eval()
253
- self.diffusion.load_state_dict(torch.load(get_model_path('diffusion_decoder.pth', models_dir)))
254
-
255
- self.vocoder = UnivNetGenerator().cuda()
256
- self.vocoder.load_state_dict(torch.load(get_model_path('vocoder.pth', models_dir), map_location=torch.device('cpu'))['model_g'])
257
- self.vocoder.eval(inference=True)
258
 
 
 
 
 
 
 
259
  # Random latent generators (RLGs) are loaded lazily.
260
  self.rlg_auto = None
261
- self.rlg_diffusion = None
262
  def get_conditioning_latents(self, voice_samples, return_mels=False):
263
  """
264
  Transforms one or more voice_samples into a tuple (autoregressive_conditioning_latent, diffusion_conditioning_latent).
@@ -277,31 +273,18 @@ class TextToSpeech:
277
  auto_conds = torch.stack(auto_conds, dim=1)
278
  auto_latent = self.autoregressive.get_conditioning(auto_conds)
279
 
280
- diffusion_conds = []
281
- for sample in voice_samples:
282
- # The diffuser operates at a sample rate of 24000 (except for the latent inputs)
283
- sample = torchaudio.functional.resample(sample, 22050, 24000)
284
- sample = pad_or_truncate(sample, 102400)
285
- cond_mel = wav_to_univnet_mel(sample.to(self.device), do_normalization=False, device=self.device)
286
- diffusion_conds.append(cond_mel)
287
- diffusion_conds = torch.stack(diffusion_conds, dim=1)
288
-
289
- diffusion_latent = self.diffusion.get_conditioning(diffusion_conds)
290
-
291
  if return_mels:
292
- return auto_latent, diffusion_latent, auto_conds, diffusion_conds
293
  else:
294
- return auto_latent, diffusion_latent
295
 
296
  def get_random_conditioning_latents(self):
297
  # Lazy-load the RLG models.
298
  if self.rlg_auto is None:
299
  self.rlg_auto = RandomLatentConverter(1024).eval()
300
  self.rlg_auto.load_state_dict(torch.load(get_model_path('rlg_auto.pth', self.models_dir), map_location=torch.device('cpu')))
301
- self.rlg_diffusion = RandomLatentConverter(2048).eval()
302
- self.rlg_diffusion.load_state_dict(torch.load(get_model_path('rlg_diffuser.pth', self.models_dir), map_location=torch.device('cpu')))
303
  with torch.no_grad():
304
- return self.rlg_auto(torch.tensor([0.0])), self.rlg_diffusion(torch.tensor([0.0]))
305
 
306
  def tts_with_preset(self, text, preset='fast', **kwargs):
307
  """
@@ -317,17 +300,33 @@ class TextToSpeech:
317
  'cond_free_k': 2.0, 'diffusion_temperature': 1.0}
318
  # Presets are defined here.
319
  presets = {
320
- 'ultra_fast': {'num_autoregressive_samples': 1, 'diffusion_iterations': 15},
321
  'fast': {'num_autoregressive_samples': 32, 'diffusion_iterations': 50},
322
  'standard': {'num_autoregressive_samples': 256, 'diffusion_iterations': 200},
323
  'high_quality': {'num_autoregressive_samples': 256, 'diffusion_iterations': 400},
324
  }
325
  settings.update(presets[preset])
326
  settings.update(kwargs) # allow overriding of preset settings with kwargs
327
- return self.tts(text, **settings)
328
 
329
  def tts(self, text, voice_samples=None, conditioning_latents=None, k=1, verbose=True, use_deterministic_seed=None,
330
- return_deterministic_state=False,
331
  # autoregressive generation parameters follow
332
  num_autoregressive_samples=512, temperature=.8, length_penalty=1, repetition_penalty=2.0, top_p=.8, max_mel_tokens=500,
333
  # CVVP parameters follow
@@ -353,13 +352,6 @@ class TextToSpeech:
353
  of long silences or "uhhhhhhs", etc.
354
  :param top_p: P value used in nucleus sampling. (0,1]. Lower values mean the decoder produces more "likely" (aka boring) outputs.
355
  :param max_mel_tokens: Restricts the output length. (0,600] integer. Each unit is 1/20 of a second.
356
- :param typical_sampling: Turns typical sampling on or off. This sampling mode is discussed in this paper: https://arxiv.org/abs/2202.00666
357
- I was interested in the premise, but the results were not as good as I was hoping. This is off by default, but
358
- could use some tuning.
359
- :param typical_mass: The typical_mass parameter from the typical_sampling algorithm.
360
- ~~CLVP-CVVP KNOBS~~
361
- :param cvvp_amount: Controls the influence of the CVVP model in selecting the best output from the autoregressive model.
362
- [0,1]. Values closer to 1 mean the CVVP model is more important, 0 disables the CVVP model.
363
  ~~DIFFUSION KNOBS~~
364
  :param diffusion_iterations: Number of diffusion steps to perform. [0,4000]. More steps means the network has more chances to iteratively refine
365
  the output, which should theoretically mean a higher quality output. Generally a value above 250 is not noticeably better,
@@ -385,17 +377,11 @@ class TextToSpeech:
385
  text_tokens = torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0).to(self.device)
386
  text_tokens = F.pad(text_tokens, (0, 1)) # This may not be necessary.
387
  assert text_tokens.shape[-1] < 400, 'Too much text provided. Break the text up into separate segments and re-try inference.'
388
- auto_conds = None
389
  if voice_samples is not None:
390
- auto_conditioning, diffusion_conditioning, auto_conds, _ = self.get_conditioning_latents(voice_samples, return_mels=True)
391
- elif conditioning_latents is not None:
392
- auto_conditioning, diffusion_conditioning = conditioning_latents
393
  else:
394
- auto_conditioning, diffusion_conditioning = self.get_random_conditioning_latents()
395
  auto_conditioning = auto_conditioning.to(self.device)
396
- diffusion_conditioning = diffusion_conditioning.to(self.device)
397
-
398
- diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=diffusion_iterations, cond_free=cond_free, cond_free_k=cond_free_k)
399
 
400
  with torch.no_grad():
401
  calm_token = 83 # This is the token for coding silence, which is fixed in place with "fix_autoregressive_output"
@@ -404,58 +390,52 @@ class TextToSpeech:
404
  with torch.autocast(
405
  device_type="cuda" , dtype=torch.float16, enabled=self.half
406
  ):
407
- codes = self.autoregressive.inference_speech(auto_conditioning, text_tokens,
408
- do_sample=True,
409
- top_p=top_p,
410
- temperature=temperature,
411
- num_return_sequences=num_autoregressive_samples,
412
- length_penalty=length_penalty,
413
- repetition_penalty=repetition_penalty,
414
- max_generate_length=max_mel_tokens,
415
- **hf_generate_kwargs)
416
- # The diffusion model actually wants the last hidden layer from the autoregressive model as conditioning
417
- # inputs. Re-produce those for the top results. This could be made more efficient by storing all of these
418
- # results, but will increase memory usage.
419
- best_latents = self.autoregressive(auto_conditioning.repeat(k, 1), text_tokens.repeat(k, 1),
420
- torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), codes,
421
- torch.tensor([codes.shape[-1]*self.autoregressive.mel_length_compression], device=text_tokens.device),
422
- return_latent=True, clip_inputs=False)
423
- del auto_conditioning
424
 
425
- if verbose:
426
- print("Transforming autoregressive outputs into audio..")
427
- wav_candidates = []
428
- latents = best_latents
429
- # Find the first occurrence of the "calm" token and trim the codes to that.
430
- ctokens = 0
431
- for k in range(codes.shape[-1]):
432
- if codes[0, k] == calm_token:
433
- ctokens += 1
434
- else:
435
- ctokens = 0
436
- if ctokens > 8: # 8 tokens gives the diffusion model some "breathing room" to terminate speech.
437
- latents = latents[:, :k]
438
- break
439
- mel = do_spectrogram_diffusion(self.diffusion, diffuser, latents, diffusion_conditioning, temperature=diffusion_temperature,
440
- verbose=verbose)
441
- wav = self.vocoder.inference(mel)
442
- wav_candidates.append(wav.cpu())
443
-
444
- def potentially_redact(clip, text):
445
- if self.enable_redaction:
446
- return self.aligner.redact(clip.squeeze(1), text).unsqueeze(1)
447
- return clip
448
- wav_candidates = [potentially_redact(wav_candidate, text) for wav_candidate in wav_candidates]
449
-
450
- if len(wav_candidates) > 1:
451
- res = wav_candidates
452
- else:
453
- res = wav_candidates[0]
454
-
455
- if return_deterministic_state:
456
- return res, (deterministic_seed, text, voice_samples, conditioning_latents)
457
- else:
458
- return res
459
  def deterministic_state(self, seed=None):
460
  """
461
  Sets the random seeds that tortoise uses to the current time() and returns that seed so results can be
 
8
  import torch.nn.functional as F
9
  import progressbar
10
  import torchaudio
11
+ import numpy as np
12
  from tortoise.models.classifier import AudioMiniEncoderWithClassifierHead
13
  from tortoise.models.diffusion_decoder import DiffusionTts
14
  from tortoise.models.autoregressive import UnifiedVoice
 
16
  from tortoise.models.arch_util import TorchMelSpectrogram
17
  from tortoise.models.clvp import CLVP
18
  from tortoise.models.cvvp import CVVP
19
+ from tortoise.models.hifigan_decoder import HifiganGenerator
20
  from tortoise.models.random_latent_generator import RandomLatentConverter
21
  from tortoise.models.vocoder import UnivNetGenerator
22
  from tortoise.utils.audio import wav_to_univnet_mel, denormalize_tacotron_mel
 
24
  from tortoise.utils.tokenizer import VoiceBpeTokenizer
25
  from tortoise.utils.wav2vec_alignment import Wav2VecAlignment
26
  from contextlib import contextmanager
27
+ from tortoise.models.stream_generator import init_stream_support
28
+ # from huggingface_hub import hf_hub_download
29
  pbar = None
30
+ init_stream_support()
31
  DEFAULT_MODELS_DIR = os.path.join(os.path.expanduser('~'), '.cache', 'tortoise', 'models')
32
  MODELS_DIR = os.environ.get('TORTOISE_MODELS_DIR', DEFAULT_MODELS_DIR)
33
+
34
  MODELS = {
35
+ 'autoregressive.pth': 'https://huggingface.co/Manmay/tortoise-tts/blob/main/autoregressive.pth',
36
+ 'classifier.pth': 'https://huggingface.co/Manmay/tortoise-tts/blob/main/classifier.pth',
37
+ 'rlg_auto.pth': 'https://huggingface.co/Manmay/tortoise-tts/blob/main/rlg_auto.pth',
38
+ 'hifidecoder.pth': 'https://huggingface.co/Manmay/tortoise-tts/blob/main/hifidecoder.pth',
39
  }
40
 
41
  def download_models(specific_models=None):
 
64
  continue
65
  print(f'Downloading {model_name} from {url}...')
66
  request.urlretrieve(url, model_path, show_progress)
67
+ # hf_hub_download(repo_id="Manmay/tortoise-tts", filename=model_name, cache_dir=MODELS_DIR)
68
  print('Done.')
69
 
70
 
 
239
  if os.path.exists(f'{models_dir}/autoregressive.ptt'):
240
  # Assume this is a traced directory.
241
  self.autoregressive = torch.jit.load(f'{models_dir}/autoregressive.ptt')
 
242
  else:
243
  self.autoregressive = UnifiedVoice(max_mel_tokens=604, max_text_tokens=402, max_conditioning_inputs=2, layers=30,
244
  model_dim=1024,
 
246
  train_solo_embeddings=False).cuda().eval()
247
  self.autoregressive.load_state_dict(torch.load(get_model_path('autoregressive.pth', models_dir)), strict=False)
248
  self.autoregressive.post_init_gpt2_config(use_deepspeed=use_deepspeed, kv_cache=kv_cache, half=self.half)
 
 
 
 
 
 
 
 
 
249
 
250
+ self.hifi_decoder = HifiganGenerator(in_channels=1024, out_channels = 1, resblock_type = "1",
251
+ resblock_dilation_sizes = [[1, 3, 5], [1, 3, 5], [1, 3, 5]], resblock_kernel_sizes = [3, 7, 11],
252
+ upsample_kernel_sizes = [16, 16, 4, 4], upsample_initial_channel = 512, upsample_factors = [8, 8, 2, 2],
253
+ cond_channels=1024).cuda().eval()
254
+ hifi_model = torch.load(get_model_path('hifidecoder.pth'))
255
+ self.hifi_decoder.load_state_dict(hifi_model, strict=False)
256
  # Random latent generators (RLGs) are loaded lazily.
257
  self.rlg_auto = None
 
258
  def get_conditioning_latents(self, voice_samples, return_mels=False):
259
  """
260
  Transforms one or more voice_samples into a tuple (autoregressive_conditioning_latent, diffusion_conditioning_latent).
 
273
  auto_conds = torch.stack(auto_conds, dim=1)
274
  auto_latent = self.autoregressive.get_conditioning(auto_conds)
275

276
  if return_mels:
277
+ return auto_latent
278
  else:
279
+ return auto_latent
280
 
281
  def get_random_conditioning_latents(self):
282
  # Lazy-load the RLG models.
283
  if self.rlg_auto is None:
284
  self.rlg_auto = RandomLatentConverter(1024).eval()
285
  self.rlg_auto.load_state_dict(torch.load(get_model_path('rlg_auto.pth', self.models_dir), map_location=torch.device('cpu')))
 
 
286
  with torch.no_grad():
287
+ return self.rlg_auto(torch.tensor([0.0]))
288
 
289
  def tts_with_preset(self, text, preset='fast', **kwargs):
290
  """
 
300
  'cond_free_k': 2.0, 'diffusion_temperature': 1.0}
301
  # Presets are defined here.
302
  presets = {
303
+ 'ultra_fast': {'num_autoregressive_samples': 1, 'diffusion_iterations': 10},
304
  'fast': {'num_autoregressive_samples': 32, 'diffusion_iterations': 50},
305
  'standard': {'num_autoregressive_samples': 256, 'diffusion_iterations': 200},
306
  'high_quality': {'num_autoregressive_samples': 256, 'diffusion_iterations': 400},
307
  }
308
  settings.update(presets[preset])
309
  settings.update(kwargs) # allow overriding of preset settings with kwargs
310
+ for audio_frame in self.tts(text, **settings):
311
+ yield audio_frame
312
+
313
+ def handle_chunks(self, wav_gen, wav_gen_prev, wav_overlap, overlap_len):
314
+ """Handle chunk formatting in streaming mode"""
315
+ wav_chunk = wav_gen[:-overlap_len]
316
+ if wav_gen_prev is not None:
317
+ wav_chunk = wav_gen[(wav_gen_prev.shape[0] - overlap_len) : -overlap_len]
318
+ if wav_overlap is not None:
319
+ crossfade_wav = wav_chunk[:overlap_len]
320
+ crossfade_wav = crossfade_wav * torch.linspace(0.0, 1.0, overlap_len).to(crossfade_wav.device)
321
+ wav_chunk[:overlap_len] = wav_overlap * torch.linspace(1.0, 0.0, overlap_len).to(wav_overlap.device)
322
+ wav_chunk[:overlap_len] += crossfade_wav
323
+ wav_overlap = wav_gen[-overlap_len:]
324
+ wav_gen_prev = wav_gen
325
+ return wav_chunk, wav_gen_prev, wav_overlap
326
+
327
 
328
  def tts(self, text, voice_samples=None, conditioning_latents=None, k=1, verbose=True, use_deterministic_seed=None,
329
+ return_deterministic_state=False, overlap_wav_len=1024, stream_chunk_size=40,
330
  # autoregressive generation parameters follow
331
  num_autoregressive_samples=512, temperature=.8, length_penalty=1, repetition_penalty=2.0, top_p=.8, max_mel_tokens=500,
332
  # CVVP parameters follow
 
352
  of long silences or "uhhhhhhs", etc.
353
  :param top_p: P value used in nucleus sampling. (0,1]. Lower values mean the decoder produces more "likely" (aka boring) outputs.
354
  :param max_mel_tokens: Restricts the output length. (0,600] integer. Each unit is 1/20 of a second.
355
  ~~DIFFUSION KNOBS~~
356
  :param diffusion_iterations: Number of diffusion steps to perform. [0,4000]. More steps means the network has more chances to iteratively refine
357
  the output, which should theoretically mean a higher quality output. Generally a value above 250 is not noticeably better,
 
377
  text_tokens = torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0).to(self.device)
378
  text_tokens = F.pad(text_tokens, (0, 1)) # This may not be necessary.
379
  assert text_tokens.shape[-1] < 400, 'Too much text provided. Break the text up into separate segments and re-try inference.'
 
380
  if voice_samples is not None:
381
+ auto_conditioning = self.get_conditioning_latents(voice_samples, return_mels=False)
 
 
382
  else:
383
+ auto_conditioning = self.get_random_conditioning_latents()
384
  auto_conditioning = auto_conditioning.to(self.device)
385
 
386
  with torch.no_grad():
387
  calm_token = 83 # This is the token for coding silence, which is fixed in place with "fix_autoregressive_output"
 
390
  with torch.autocast(
391
  device_type="cuda" , dtype=torch.float16, enabled=self.half
392
  ):
393
+ fake_inputs = self.autoregressive.compute_embeddings(
394
+ auto_conditioning,
395
+ text_tokens,
396
+ )
397
+ gpt_generator = self.autoregressive.get_generator(
398
+ fake_inputs=fake_inputs,
399
+ top_k=50,
400
+ top_p=top_p,
401
+ temperature=temperature,
402
+ do_sample=True,
403
+ num_beams=1,
404
+ num_return_sequences=1,
405
+ length_penalty=float(length_penalty),
406
+ repetition_penalty=float(repetition_penalty),
407
+ output_attentions=False,
408
+ output_hidden_states=True,
409
+ **hf_generate_kwargs,
410
+ )
411
+ all_latents = []
412
+ codes_ = []
413
+ wav_gen_prev = None
414
+ wav_overlap = None
415
+ is_end = False
416
+ first_buffer = 60
417
+ while not is_end:
418
+ try:
419
+ with torch.autocast(
420
+ device_type="cuda", dtype=torch.float16, enabled=self.half
421
+ ):
422
+ codes, latent = next(gpt_generator)
423
+ all_latents += [latent]
424
+ codes_ += [codes]
425
+ except StopIteration:
426
+ is_end = True
427
+
428
+ if is_end or (stream_chunk_size > 0 and len(codes_) >= max(stream_chunk_size, first_buffer)):
429
+ first_buffer = 0
430
+ gpt_latents = torch.cat(all_latents, dim=0)[None, :]
431
+ wav_gen = self.hifi_decoder.inference(gpt_latents.cuda(), auto_conditioning)
432
+ wav_gen = wav_gen.squeeze()
433
+ wav_chunk, wav_gen_prev, wav_overlap = self.handle_chunks(
434
+ wav_gen.squeeze(), wav_gen_prev, wav_overlap, overlap_wav_len
435
+ )
436
+ codes_ = []
437
+ yield wav_chunk
438

439
  def deterministic_state(self, seed=None):
440
  """
441
  Sets the random seeds that tortoise uses to the current time() and returns that seed so results can be
tortoise/models/autoregressive.py CHANGED
@@ -38,6 +38,7 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
38
  self.transformer = gpt
39
  self.text_pos_embedding = text_pos_emb
40
  self.embeddings = embeddings
 
41
  self.lm_head = nn.Sequential(norm, linear)
42
  self.kv_cache = kv_cache
43
 
@@ -509,7 +510,28 @@ class UnifiedVoice(nn.Module):
509
  loss_text = F.cross_entropy(text_logits, text_targets.long())
510
  loss_mel = F.cross_entropy(mel_logits, mel_targets.long())
511
  return loss_text.mean(), loss_mel.mean(), mel_logits
512
-
513
  def inference_speech(self, speech_conditioning_latent, text_inputs, input_tokens=None, num_return_sequences=1,
514
  max_generate_length=None, typical_sampling=False, typical_mass=.9, **hf_generate_kwargs):
515
 
@@ -540,7 +562,16 @@ class UnifiedVoice(nn.Module):
540
  num_return_sequences=num_return_sequences, **hf_generate_kwargs)
541
  return gen[:, trunc_index:]
542
 
543
-
544
  if __name__ == '__main__':
545
  gpt = UnifiedVoice(model_dim=256, heads=4, train_solo_embeddings=True, use_mel_codes_as_input=True, max_conditioning_inputs=4)
546
  l = gpt(torch.randn(2, 3, 80, 800),
 
38
  self.transformer = gpt
39
  self.text_pos_embedding = text_pos_emb
40
  self.embeddings = embeddings
41
+ self.final_norm = norm
42
  self.lm_head = nn.Sequential(norm, linear)
43
  self.kv_cache = kv_cache
44
 
 
510
  loss_text = F.cross_entropy(text_logits, text_targets.long())
511
  loss_mel = F.cross_entropy(mel_logits, mel_targets.long())
512
  return loss_text.mean(), loss_mel.mean(), mel_logits
513
+ def compute_embeddings(
514
+ self,
515
+ cond_latents,
516
+ text_inputs,
517
+ ):
518
+ text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
519
+ text_inputs = F.pad(text_inputs, (1, 0), value=self.start_text_token)
520
+ emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
521
+ conds = cond_latents.unsqueeze(1)
522
+ emb = torch.cat([conds, emb], dim=1)
523
+ self.inference_model.store_mel_emb(emb)
524
+ gpt_inputs = torch.full(
525
+ (
526
+ emb.shape[0],
527
+ emb.shape[1] + 1, # +1 for the start_mel_token
528
+ ),
529
+ fill_value=1,
530
+ dtype=torch.long,
531
+ device=text_inputs.device,
532
+ )
533
+ gpt_inputs[:, -1] = self.start_mel_token
534
+ return gpt_inputs
535
  def inference_speech(self, speech_conditioning_latent, text_inputs, input_tokens=None, num_return_sequences=1,
536
  max_generate_length=None, typical_sampling=False, typical_mass=.9, **hf_generate_kwargs):
537
 
 
562
  num_return_sequences=num_return_sequences, **hf_generate_kwargs)
563
  return gen[:, trunc_index:]
564
 
565
+ def get_generator(self, fake_inputs, **hf_generate_kwargs):
566
+ return self.inference_model.generate_stream(
567
+ fake_inputs,
568
+ bos_token_id=self.start_mel_token,
569
+ pad_token_id=self.stop_mel_token,
570
+ eos_token_id=self.stop_mel_token,
571
+ max_length=500,
572
+ do_stream=True,
573
+ **hf_generate_kwargs,
574
+ )
575
  if __name__ == '__main__':
576
  gpt = UnifiedVoice(model_dim=256, heads=4, train_solo_embeddings=True, use_mel_codes_as_input=True, max_conditioning_inputs=4)
577
  l = gpt(torch.randn(2, 3, 80, 800),
tortoise/models/hifigan_decoder.py ADDED
@@ -0,0 +1,299 @@
1
+ # adapted from https://github.com/jik876/hifi-gan/blob/master/models.py
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import Conv1d, ConvTranspose1d
5
+ from torch.nn import functional as F
6
+ from torch.nn.utils import remove_weight_norm, weight_norm
7
+
8
+ LRELU_SLOPE = 0.1
9
+
10
+
11
+ def get_padding(k, d):
12
+ return int((k * d - d) / 2)
13
+
14
+
15
+ class ResBlock1(torch.nn.Module):
16
+ """Residual Block Type 1. It has 3 convolutional layers in each convolutional block.
17
+
18
+ Network::
19
+
20
+ x -> lrelu -> conv1_1 -> conv1_2 -> conv1_3 -> z -> lrelu -> conv2_1 -> conv2_2 -> conv2_3 -> o -> + -> o
21
+ |--------------------------------------------------------------------------------------------------|
22
+
23
+
24
+ Args:
25
+ channels (int): number of hidden channels for the convolutional layers.
26
+ kernel_size (int): size of the convolution filter in each layer.
27
+ dilations (list): list of dilation value for each conv layer in a block.
28
+ """
29
+
30
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
31
+ super().__init__()
32
+ self.convs1 = nn.ModuleList(
33
+ [
34
+ weight_norm(
35
+ Conv1d(
36
+ channels,
37
+ channels,
38
+ kernel_size,
39
+ 1,
40
+ dilation=dilation[0],
41
+ padding=get_padding(kernel_size, dilation[0]),
42
+ )
43
+ ),
44
+ weight_norm(
45
+ Conv1d(
46
+ channels,
47
+ channels,
48
+ kernel_size,
49
+ 1,
50
+ dilation=dilation[1],
51
+ padding=get_padding(kernel_size, dilation[1]),
52
+ )
53
+ ),
54
+ weight_norm(
55
+ Conv1d(
56
+ channels,
57
+ channels,
58
+ kernel_size,
59
+ 1,
60
+ dilation=dilation[2],
61
+ padding=get_padding(kernel_size, dilation[2]),
62
+ )
63
+ ),
64
+ ]
65
+ )
66
+
67
+ self.convs2 = nn.ModuleList(
68
+ [
69
+ weight_norm(
70
+ Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))
71
+ ),
72
+ weight_norm(
73
+ Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))
74
+ ),
75
+ weight_norm(
76
+ Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))
77
+ ),
78
+ ]
79
+ )
80
+
81
+ def forward(self, x):
82
+ """
83
+ Args:
84
+ x (Tensor): input tensor.
85
+ Returns:
86
+ Tensor: output tensor.
87
+ Shapes:
88
+ x: [B, C, T]
89
+ """
90
+ for c1, c2 in zip(self.convs1, self.convs2):
91
+ xt = F.leaky_relu(x, LRELU_SLOPE)
92
+ xt = c1(xt)
93
+ xt = F.leaky_relu(xt, LRELU_SLOPE)
94
+ xt = c2(xt)
95
+ x = xt + x
96
+ return x
97
+
98
+ def remove_weight_norm(self):
99
+ for l in self.convs1:
100
+ remove_weight_norm(l)
101
+ for l in self.convs2:
102
+ remove_weight_norm(l)
103
+
104
+
105
+ class ResBlock2(torch.nn.Module):
106
+ """Residual Block Type 2. It has 1 convolutional layers in each convolutional block.
107
+
108
+ Network::
109
+
110
+ x -> lrelu -> conv1-> -> z -> lrelu -> conv2-> o -> + -> o
111
+ |---------------------------------------------------|
112
+
113
+
114
+ Args:
115
+ channels (int): number of hidden channels for the convolutional layers.
116
+ kernel_size (int): size of the convolution filter in each layer.
117
+ dilations (list): list of dilation value for each conv layer in a block.
118
+ """
119
+
120
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
121
+ super().__init__()
122
+ self.convs = nn.ModuleList(
123
+ [
124
+ weight_norm(
125
+ Conv1d(
126
+ channels,
127
+ channels,
128
+ kernel_size,
129
+ 1,
130
+ dilation=dilation[0],
131
+ padding=get_padding(kernel_size, dilation[0]),
132
+ )
133
+ ),
134
+ weight_norm(
135
+ Conv1d(
136
+ channels,
137
+ channels,
138
+ kernel_size,
139
+ 1,
140
+ dilation=dilation[1],
141
+ padding=get_padding(kernel_size, dilation[1]),
142
+ )
143
+ ),
144
+ ]
145
+ )
146
+
147
+ def forward(self, x):
148
+ for c in self.convs:
149
+ xt = F.leaky_relu(x, LRELU_SLOPE)
150
+ xt = c(xt)
151
+ x = xt + x
152
+ return x
153
+
154
+ def remove_weight_norm(self):
155
+ for l in self.convs:
156
+ remove_weight_norm(l)
157
+
158
+
159
+ class HifiganGenerator(torch.nn.Module):
160
+ def __init__(
161
+ self,
162
+ in_channels,
163
+ out_channels,
164
+ resblock_type,
165
+ resblock_dilation_sizes,
166
+ resblock_kernel_sizes,
167
+ upsample_kernel_sizes,
168
+ upsample_initial_channel,
169
+ upsample_factors,
170
+ inference_padding=5,
171
+ cond_channels=0,
172
+ conv_pre_weight_norm=True,
173
+ conv_post_weight_norm=True,
174
+ conv_post_bias=True,
175
+ ):
176
+ r"""HiFiGAN Generator with Multi-Receptive Field Fusion (MRF)
177
+
178
+ Network:
179
+ x -> lrelu -> upsampling_layer -> resblock1_k1x1 -> z1 -> + -> z_sum / #resblocks -> lrelu -> conv_post_7x1 -> tanh -> o
180
+ .. -> zI ---|
181
+ resblockN_kNx1 -> zN ---'
182
+
183
+ Args:
184
+ in_channels (int): number of input tensor channels.
185
+ out_channels (int): number of output tensor channels.
186
+ resblock_type (str): type of the `ResBlock`. '1' or '2'.
187
+ resblock_dilation_sizes (List[List[int]]): list of dilation values in each layer of a `ResBlock`.
188
+ resblock_kernel_sizes (List[int]): list of kernel sizes for each `ResBlock`.
189
+ upsample_kernel_sizes (List[int]): list of kernel sizes for each transposed convolution.
190
+ upsample_initial_channel (int): number of channels for the first upsampling layer. This is divided by 2
191
+ for each consecutive upsampling layer.
192
+ upsample_factors (List[int]): upsampling factors (stride) for each upsampling layer.
193
+ inference_padding (int): constant padding applied to the input at inference time. Defaults to 5.
194
+ """
195
+ super().__init__()
196
+ self.inference_padding = inference_padding
197
+ self.num_kernels = len(resblock_kernel_sizes)
198
+ self.num_upsamples = len(upsample_factors)
199
+ # initial upsampling layers
200
+ self.conv_pre = weight_norm(Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3))
201
+ resblock = ResBlock1 if resblock_type == "1" else ResBlock2
202
+ # upsampling layers
203
+ self.ups = nn.ModuleList()
204
+ for i, (u, k) in enumerate(zip(upsample_factors, upsample_kernel_sizes)):
205
+ self.ups.append(
206
+ weight_norm(
207
+ ConvTranspose1d(
208
+ upsample_initial_channel // (2**i),
209
+ upsample_initial_channel // (2 ** (i + 1)),
210
+ k,
211
+ u,
212
+ padding=(k - u) // 2,
213
+ )
214
+ )
215
+ )
216
+ # MRF blocks
217
+ self.resblocks = nn.ModuleList()
218
+ for i in range(len(self.ups)):
219
+ ch = upsample_initial_channel // (2 ** (i + 1))
220
+ for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
221
+ self.resblocks.append(resblock(ch, k, d))
222
+ # post convolution layer
223
+ self.conv_post = weight_norm(Conv1d(ch, out_channels, 7, 1, padding=3, bias=conv_post_bias))
224
+ if cond_channels > 0:
225
+ self.cond_layer = nn.Conv1d(cond_channels, upsample_initial_channel, 1)
226
+
227
+ if not conv_pre_weight_norm:
228
+ remove_weight_norm(self.conv_pre)
229
+
230
+ if not conv_post_weight_norm:
231
+ remove_weight_norm(self.conv_post)
232
+
233
+ def forward(self, x, g=None):
234
+ """
235
+ Args:
236
+ x (Tensor): feature input tensor.
237
+ g (Tensor): global conditioning input tensor.
238
+
239
+ Returns:
240
+ Tensor: output waveform.
241
+
242
+ Shapes:
243
+ x: [B, C, T]
244
+ Tensor: [B, 1, T]
245
+ """
246
+ o = self.conv_pre(x)
247
+ if hasattr(self, "cond_layer"):
248
+ o = o + self.cond_layer(g)
249
+ for i in range(self.num_upsamples):
250
+ o = F.leaky_relu(o, LRELU_SLOPE)
251
+ o = self.ups[i](o)
252
+ z_sum = None
253
+ for j in range(self.num_kernels):
254
+ if z_sum is None:
255
+ z_sum = self.resblocks[i * self.num_kernels + j](o)
256
+ else:
257
+ z_sum += self.resblocks[i * self.num_kernels + j](o)
258
+ o = z_sum / self.num_kernels
259
+ o = F.leaky_relu(o)
260
+ o = self.conv_post(o)
261
+ o = torch.tanh(o)
262
+ return o
263
+
264
+ @torch.no_grad()
265
+ def inference(self, c, g=None):
266
+ """
267
+ Args:
268
+ x (Tensor): conditioning input tensor.
269
+
270
+ Returns:
271
+ Tensor: output waveform.
272
+
273
+ Shapes:
274
+ x: [B, C, T]
275
+ Tensor: [B, 1, T]
276
+ """
277
+ # c = c.to(self.conv_pre.weight.device)
278
+ # c = torch.nn.functional.pad(c, (self.inference_padding, self.inference_padding), "replicate")
279
+ up_1 = torch.nn.functional.interpolate(
280
+ c.transpose(1,2),
281
+ scale_factor=[1024 / 256],
282
+ mode="linear",
283
+ )
284
+ up_2 = torch.nn.functional.interpolate(
285
+ up_1,
286
+ scale_factor=[24000 / 22050],
287
+ mode="linear",
288
+ )
289
+ g = g.unsqueeze(0)
290
+ return self.forward(up_2.to("cuda"), g.transpose(1,2))
291
+
292
+ def remove_weight_norm(self):
293
+ print("Removing weight norm...")
294
+ for l in self.ups:
295
+ remove_weight_norm(l)
296
+ for l in self.resblocks:
297
+ l.remove_weight_norm()
298
+ remove_weight_norm(self.conv_pre)
299
+ remove_weight_norm(self.conv_post)
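A shape sketch for the new HiFi-GAN head, using the same hyperparameters that tortoise/api.py passes in this commit. The random tensors are stand-ins for a chunk of GPT hidden states and the autoregressive speaker latent (real weights come from hifidecoder.pth), and a CUDA device is assumed because HifiganGenerator.inference() moves its input to "cuda".

import torch
from tortoise.models.hifigan_decoder import HifiganGenerator

decoder = HifiganGenerator(
    in_channels=1024, out_channels=1, resblock_type="1",
    resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
    resblock_kernel_sizes=[3, 7, 11],
    upsample_kernel_sizes=[16, 16, 4, 4],
    upsample_initial_channel=512,
    upsample_factors=[8, 8, 2, 2],
    cond_channels=1024,
).cuda().eval()

gpt_latents = torch.randn(1, 40, 1024).cuda()   # [B, T, 1024] chunk of GPT hidden states
speaker_latent = torch.randn(1, 1024).cuda()    # autoregressive conditioning latent

wav = decoder.inference(gpt_latents, speaker_latent)  # -> [1, 1, n_samples] at 24 kHz
print(wav.shape)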
tortoise/models/stream_generator.py ADDED
@@ -0,0 +1,1057 @@
1
+ # Adapted from: https://github.com/LowinLi/transformers-stream-generator
2
+
3
+ from transformers import (
4
+ GenerationConfig,
5
+ GenerationMixin,
6
+ LogitsProcessorList,
7
+ StoppingCriteriaList,
8
+ DisjunctiveConstraint,
9
+ BeamSearchScorer,
10
+ PhrasalConstraint,
11
+ ConstrainedBeamSearchScorer,
12
+ PreTrainedModel,
13
+ )
14
+ import numpy as np
15
+ import random
16
+ import warnings
17
+ import inspect
18
+ from transformers.generation.utils import GenerateOutput, SampleOutput, logger
19
+ import torch
20
+ from typing import Callable, List, Optional, Union
21
+ from torch import nn
22
+ import torch.distributed as dist
23
+ import copy
24
+
25
+
26
+ def setup_seed(seed):
27
+ if seed == -1:
28
+ return
29
+ torch.manual_seed(seed)
30
+ if torch.cuda.is_available():
31
+ torch.cuda.manual_seed_all(seed)
32
+ np.random.seed(seed)
33
+ random.seed(seed)
34
+ torch.backends.cudnn.deterministic = True
35
+
36
+
37
+ class StreamGenerationConfig(GenerationConfig):
38
+ def __init__(self, **kwargs):
39
+ super().__init__(**kwargs)
40
+ self.do_stream = kwargs.pop("do_stream", False)
41
+
42
+
43
+ class NewGenerationMixin(GenerationMixin):
44
+ @torch.no_grad()
45
+ def generate(
46
+ self,
47
+ inputs: Optional[torch.Tensor] = None,
48
+ generation_config: Optional[StreamGenerationConfig] = None,
49
+ logits_processor: Optional[LogitsProcessorList] = None,
50
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
51
+ prefix_allowed_tokens_fn: Optional[
52
+ Callable[[int, torch.Tensor], List[int]]
53
+ ] = None,
54
+ synced_gpus: Optional[bool] = False,
55
+ seed=0,
56
+ **kwargs,
57
+ ) -> Union[GenerateOutput, torch.LongTensor]:
58
+ r"""
59
+
60
+ Generates sequences of token ids for models with a language modeling head.
61
+
62
+ <Tip warning={true}>
63
+
64
+ Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the
65
+ model's default generation configuration. You can override any `generation_config` by passing the corresponding
66
+ parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`.
67
+
68
+ For an overview of generation strategies and code examples, check out the [following
69
+ guide](./generation_strategies).
70
+
71
+ </Tip>
72
+
73
+ Parameters:
74
+ inputs (`torch.Tensor` of varying shape depending on the modality, *optional*):
75
+ The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the
76
+ method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs`
77
+ should be in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of
78
+ `input_ids`, `input_values`, `input_features`, or `pixel_values`.
79
+ generation_config (`~generation.GenerationConfig`, *optional*):
80
+ The generation configuration to be used as base parametrization for the generation call. `**kwargs`
81
+ passed to generate matching the attributes of `generation_config` will override them. If
82
+ `generation_config` is not provided, the default will be used, which had the following loading
83
+ priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
84
+ configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
85
+ default values, whose documentation should be checked to parameterize generation.
86
+ logits_processor (`LogitsProcessorList`, *optional*):
87
+ Custom logits processors that complement the default logits processors built from arguments and
88
+ generation config. If a logit processor is passed that is already created with the arguments or a
89
+ generation config an error is thrown. This feature is intended for advanced users.
90
+ stopping_criteria (`StoppingCriteriaList`, *optional*):
91
+ Custom stopping criteria that complement the default stopping criteria built from arguments and a
92
+ generation config. If a stopping criteria is passed that is already created with the arguments or a
93
+ generation config an error is thrown. This feature is intended for advanced users.
94
+ prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*):
95
+ If provided, this function constrains the beam search to allowed tokens only at each step. If not
96
+ provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and
97
+ `input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned
98
+ on the batch ID `batch_id` and the previously generated tokens `inputs_ids`. This argument is useful
99
+ for constrained generation conditioned on the prefix, as described in [Autoregressive Entity
100
+ Retrieval](https://arxiv.org/abs/2010.00904).
101
+ synced_gpus (`bool`, *optional*, defaults to `False`):
102
+ Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
103
+ kwargs:
104
+ Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
105
+ forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
106
+ specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*.
107
+
108
+ Return:
109
+ [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
110
+ or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`.
111
+
112
+ If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible
113
+ [`~utils.ModelOutput`] types are:
114
+
115
+ - [`~generation.GreedySearchDecoderOnlyOutput`],
116
+ - [`~generation.SampleDecoderOnlyOutput`],
117
+ - [`~generation.BeamSearchDecoderOnlyOutput`],
118
+ - [`~generation.BeamSampleDecoderOnlyOutput`]
119
+
120
+ If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible
121
+ [`~utils.ModelOutput`] types are:
122
+
123
+ - [`~generation.GreedySearchEncoderDecoderOutput`],
124
+ - [`~generation.SampleEncoderDecoderOutput`],
125
+ - [`~generation.BeamSearchEncoderDecoderOutput`],
126
+ - [`~generation.BeamSampleEncoderDecoderOutput`]
127
+ """
128
+ setup_seed(seed)
129
+ # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
130
+ self._validate_model_class()
131
+
132
+ # priority: `generation_config` argument > `model.generation_config` (the default generation config)
133
+ if generation_config is None:
134
+ # legacy: users may modify the model configuration to control generation -- update the generation config
135
+ # model attribute accordingly, if it was created from the model config
136
+ if self.generation_config._from_model_config:
137
+ new_generation_config = StreamGenerationConfig.from_model_config(
138
+ self.config
139
+ )
140
+ if new_generation_config != self.generation_config:
141
+ warnings.warn(
142
+ "You have modified the pretrained model configuration to control generation. This is a"
143
+ " deprecated strategy to control generation and will be removed soon, in a future version."
144
+ " Please use a generation configuration file (see"
145
+ " https://huggingface.co/docs/transformers/main_classes/text_generation)"
146
+ )
147
+ self.generation_config = new_generation_config
148
+ generation_config = self.generation_config
149
+
150
+ generation_config = copy.deepcopy(generation_config)
151
+ model_kwargs = generation_config.update(
152
+ **kwargs
153
+ ) # All unused kwargs must be model kwargs
154
+ # self._validate_model_kwargs(model_kwargs.copy())
155
+
156
+ # 2. Set generation parameters if not already defined
157
+ logits_processor = (
158
+ logits_processor if logits_processor is not None else LogitsProcessorList()
159
+ )
160
+ stopping_criteria = (
161
+ stopping_criteria
162
+ if stopping_criteria is not None
163
+ else StoppingCriteriaList()
164
+ )
165
+
166
+ if (
167
+ generation_config.pad_token_id is None
168
+ and generation_config.eos_token_id is not None
169
+ ):
170
+ if model_kwargs.get("attention_mask", None) is None:
171
+ logger.warning(
172
+ "The attention mask and the pad token id were not set. As a consequence, you may observe "
173
+ "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
174
+ )
175
+ eos_token_id = generation_config.eos_token_id
176
+ if isinstance(eos_token_id, list):
177
+ eos_token_id = eos_token_id[0]
178
+ logger.warning(
179
+ f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation."
180
+ )
181
+ generation_config.pad_token_id = eos_token_id
182
+
183
+ # 3. Define model inputs
184
+ # inputs_tensor has to be defined
185
+ # model_input_name is defined if model-specific keyword input is passed
186
+ # otherwise model_input_name is None
187
+ # all model-specific keyword inputs are removed from `model_kwargs`
188
+ inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
189
+ inputs, generation_config.bos_token_id, model_kwargs
190
+ )
191
+ batch_size = inputs_tensor.shape[0]
192
+
193
+ # 4. Define other model kwargs
194
+ model_kwargs["output_attentions"] = generation_config.output_attentions
195
+ model_kwargs["output_hidden_states"] = generation_config.output_hidden_states
196
+ model_kwargs["use_cache"] = generation_config.use_cache
197
+
198
+ accepts_attention_mask = "attention_mask" in set(
199
+ inspect.signature(self.forward).parameters.keys()
200
+ )
201
+ requires_attention_mask = "encoder_outputs" not in model_kwargs
202
+
203
+ if (
204
+ model_kwargs.get("attention_mask", None) is None
205
+ and requires_attention_mask
206
+ and accepts_attention_mask
207
+ ):
208
+ model_kwargs[
209
+ "attention_mask"
210
+ ] = self._prepare_attention_mask_for_generation(
211
+ inputs_tensor,
212
+ generation_config.pad_token_id,
213
+ generation_config.eos_token_id,
214
+ )
215
+
216
+ # decoder-only models should use left-padding for generation
217
+ if not self.config.is_encoder_decoder:
218
+ if (
219
+ generation_config.pad_token_id is not None
220
+ and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id)
221
+ > 0
222
+ ):
223
+ logger.warning(
224
+ "A decoder-only architecture is being used, but right-padding was detected! For correct "
225
+ "generation results, please set `padding_side='left'` when initializing the tokenizer."
226
+ )
227
+
228
+ if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs:
229
+ # if model is encoder decoder encoder_outputs are created
230
+ # and added to `model_kwargs`
231
+ model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
232
+ inputs_tensor, model_kwargs, model_input_name
233
+ )
234
+
235
+ # 5. Prepare `input_ids` which will be used for auto-regressive generation
236
+ if self.config.is_encoder_decoder:
237
+ input_ids = self._prepare_decoder_input_ids_for_generation(
238
+ batch_size,
239
+ decoder_start_token_id=generation_config.decoder_start_token_id,
240
+ bos_token_id=generation_config.bos_token_id,
241
+ model_kwargs=model_kwargs,
242
+ device=inputs_tensor.device,
243
+ )
244
+ else:
245
+ # if decoder-only then inputs_tensor has to be `input_ids`
246
+ input_ids = inputs_tensor
247
+
248
+ # 6. Prepare `max_length` depending on other stopping criteria.
249
+ input_ids_seq_length = input_ids.shape[-1]
250
+ has_default_max_length = (
251
+ kwargs.get("max_length") is None
252
+ and generation_config.max_length is not None
253
+ )
254
+ if has_default_max_length and generation_config.max_new_tokens is None:
255
+ warnings.warn(
256
+ "Neither `max_length` nor `max_new_tokens` has been set, `max_length` will default to"
257
+ f" {generation_config.max_length} (`generation_config.max_length`). Controlling `max_length` via the"
258
+ " config is deprecated and `max_length` will be removed from the config in v5 of Transformers -- we"
259
+ " recommend using `max_new_tokens` to control the maximum length of the generation.",
260
+ UserWarning,
261
+ )
262
+ elif has_default_max_length and generation_config.max_new_tokens is not None:
263
+ generation_config.max_length = (
264
+ generation_config.max_new_tokens + input_ids_seq_length
265
+ )
266
+ elif (
267
+ not has_default_max_length and generation_config.max_new_tokens is not None
268
+ ):
269
+ raise ValueError(
270
+ "Both `max_new_tokens` and `max_length` have been set but they serve the same purpose -- setting a"
271
+ " limit to the generated output length. Remove one of those arguments. Please refer to the"
272
+ " documentation for more information. "
273
+ "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
274
+ )
275
+
276
+ if (
277
+ generation_config.min_length is not None
278
+ and generation_config.min_length > generation_config.max_length
279
+ ):
280
+ raise ValueError(
281
+ f"Unfeasible length constraints: the minimum length ({generation_config.min_length}) is larger than"
282
+ f" the maximum length ({generation_config.max_length})"
283
+ )
284
+ if input_ids_seq_length >= generation_config.max_length:
285
+ input_ids_string = (
286
+ "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
287
+ )
288
+ logger.warning(
289
+ f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
290
+ f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
291
+ " increasing `max_new_tokens`."
292
+ )
293
+
294
+ # 7. determine generation mode
295
+ is_constraint_gen_mode = (
296
+ generation_config.constraints is not None
297
+ or generation_config.force_words_ids is not None
298
+ )
299
+
300
+ is_contrastive_search_gen_mode = (
301
+ generation_config.top_k is not None
302
+ and generation_config.top_k > 1
303
+ and generation_config.do_sample is False
304
+ and generation_config.penalty_alpha is not None
305
+ and generation_config.penalty_alpha > 0
306
+ )
307
+
308
+ is_greedy_gen_mode = (
309
+ (generation_config.num_beams == 1)
310
+ and (generation_config.num_beam_groups == 1)
311
+ and generation_config.do_sample is False
312
+ and not is_constraint_gen_mode
313
+ and not is_contrastive_search_gen_mode
314
+ )
315
+ is_sample_gen_mode = (
316
+ (generation_config.num_beams == 1)
317
+ and (generation_config.num_beam_groups == 1)
318
+ and generation_config.do_sample is True
319
+ and generation_config.do_stream is False
320
+ and not is_constraint_gen_mode
321
+ and not is_contrastive_search_gen_mode
322
+ )
323
+ is_sample_gen_stream_mode = (
324
+ (generation_config.num_beams == 1)
325
+ and (generation_config.num_beam_groups == 1)
326
+ and generation_config.do_stream is True
327
+ and not is_constraint_gen_mode
328
+ and not is_contrastive_search_gen_mode
329
+ )
330
+ is_beam_gen_mode = (
331
+ (generation_config.num_beams > 1)
332
+ and (generation_config.num_beam_groups == 1)
333
+ and generation_config.do_sample is False
334
+ and not is_constraint_gen_mode
335
+ and not is_contrastive_search_gen_mode
336
+ )
337
+ is_beam_sample_gen_mode = (
338
+ (generation_config.num_beams > 1)
339
+ and (generation_config.num_beam_groups == 1)
340
+ and generation_config.do_sample is True
341
+ and not is_constraint_gen_mode
342
+ and not is_contrastive_search_gen_mode
343
+ )
344
+ is_group_beam_gen_mode = (
345
+ (generation_config.num_beams > 1)
346
+ and (generation_config.num_beam_groups > 1)
347
+ and not is_constraint_gen_mode
348
+ and not is_contrastive_search_gen_mode
349
+ )
350
+
351
+ if generation_config.num_beam_groups > generation_config.num_beams:
352
+ raise ValueError(
353
+ "`num_beam_groups` has to be smaller or equal to `num_beams`"
354
+ )
355
+ if is_group_beam_gen_mode and generation_config.do_sample is True:
356
+ raise ValueError(
357
+ "Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`."
358
+ )
359
+
360
+ if self.device.type != input_ids.device.type:
361
+ warnings.warn(
362
+ "You are calling .generate() with the `input_ids` being on a device type different"
363
+ f" than your model's device. `input_ids` is on {input_ids.device.type}, whereas the model"
364
+ f" is on {self.device.type}. You may experience unexpected behaviors or slower generation."
365
+ " Please make sure that you have put `input_ids` to the"
366
+ f" correct device by calling for example input_ids = input_ids.to('{self.device.type}') before"
367
+ " running `.generate()`.",
368
+ UserWarning,
369
+ )
370
+ # 8. prepare distribution pre_processing samplers
371
+ logits_processor = self._get_logits_processor(
372
+ generation_config=generation_config,
373
+ input_ids_seq_length=input_ids_seq_length,
374
+ encoder_input_ids=inputs_tensor,
375
+ prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
376
+ logits_processor=logits_processor,
377
+ )
378
+
379
+ # 9. prepare stopping criteria
380
+ stopping_criteria = self._get_stopping_criteria(
381
+ generation_config=generation_config, stopping_criteria=stopping_criteria
382
+ )
383
+ # 10. go into different generation modes
384
+ if is_greedy_gen_mode:
385
+ if generation_config.num_return_sequences > 1:
386
+ raise ValueError(
387
+ f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing"
388
+ " greedy search."
389
+ )
390
+
391
+ # 11. run greedy search
392
+ return self.greedy_search(
393
+ input_ids,
394
+ logits_processor=logits_processor,
395
+ stopping_criteria=stopping_criteria,
396
+ pad_token_id=generation_config.pad_token_id,
397
+ eos_token_id=generation_config.eos_token_id,
398
+ output_scores=generation_config.output_scores,
399
+ return_dict_in_generate=generation_config.return_dict_in_generate,
400
+ synced_gpus=synced_gpus,
401
+ **model_kwargs,
402
+ )
403
+
404
+ elif is_contrastive_search_gen_mode:
405
+ if generation_config.num_return_sequences > 1:
406
+ raise ValueError(
407
+ f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing"
408
+ " contrastive search."
409
+ )
410
+
411
+ return self.contrastive_search(
412
+ input_ids,
413
+ top_k=generation_config.top_k,
414
+ penalty_alpha=generation_config.penalty_alpha,
415
+ logits_processor=logits_processor,
416
+ stopping_criteria=stopping_criteria,
417
+ pad_token_id=generation_config.pad_token_id,
418
+ eos_token_id=generation_config.eos_token_id,
419
+ output_scores=generation_config.output_scores,
420
+ return_dict_in_generate=generation_config.return_dict_in_generate,
421
+ synced_gpus=synced_gpus,
422
+ **model_kwargs,
423
+ )
424
+
425
+ elif is_sample_gen_mode:
426
+ # 11. prepare logits warper
427
+ logits_warper = self._get_logits_warper(generation_config)
428
+
429
+ # 12. expand input_ids with `num_return_sequences` additional sequences per batch
430
+ input_ids, model_kwargs = self._expand_inputs_for_generation(
431
+ input_ids=input_ids,
432
+ expand_size=generation_config.num_return_sequences,
433
+ is_encoder_decoder=self.config.is_encoder_decoder,
434
+ **model_kwargs,
435
+ )
436
+
437
+ # 13. run sample
438
+ return self.sample(
439
+ input_ids,
440
+ logits_processor=logits_processor,
441
+ logits_warper=logits_warper,
442
+ stopping_criteria=stopping_criteria,
443
+ pad_token_id=generation_config.pad_token_id,
444
+ eos_token_id=generation_config.eos_token_id,
445
+ output_scores=generation_config.output_scores,
446
+ return_dict_in_generate=generation_config.return_dict_in_generate,
447
+ synced_gpus=synced_gpus,
448
+ **model_kwargs,
449
+ )
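+ # Streaming counterpart of the sampling branch above: instead of returning the
+ # finished sequence, sample_stream (defined further down) yields tokens and
+ # hidden states one decoding step at a time.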
450
+ elif is_sample_gen_stream_mode:
451
+ # 11. prepare logits warper
452
+ logits_warper = self._get_logits_warper(generation_config)
453
+
454
+ # 12. expand input_ids with `num_return_sequences` additional sequences per batch
455
+ input_ids, model_kwargs = self._expand_inputs_for_generation(
456
+ input_ids=input_ids,
457
+ expand_size=generation_config.num_return_sequences,
458
+ is_encoder_decoder=self.config.is_encoder_decoder,
459
+ **model_kwargs,
460
+ )
461
+
462
+ # 13. run sample
463
+ return self.sample_stream(
464
+ input_ids,
465
+ logits_processor=logits_processor,
466
+ logits_warper=logits_warper,
467
+ stopping_criteria=stopping_criteria,
468
+ pad_token_id=generation_config.pad_token_id,
469
+ eos_token_id=generation_config.eos_token_id,
470
+ output_scores=generation_config.output_scores,
471
+ return_dict_in_generate=generation_config.return_dict_in_generate,
472
+ synced_gpus=synced_gpus,
473
+ **model_kwargs,
474
+ )
475
+ elif is_beam_gen_mode:
476
+ if generation_config.num_return_sequences > generation_config.num_beams:
477
+ raise ValueError(
478
+ "`num_return_sequences` has to be smaller or equal to `num_beams`."
479
+ )
480
+
481
+ if stopping_criteria.max_length is None:
482
+ raise ValueError(
483
+ "`max_length` needs to be a stopping_criteria for now."
484
+ )
485
+
486
+ # 11. prepare beam search scorer
487
+ beam_scorer = BeamSearchScorer(
488
+ batch_size=batch_size,
489
+ num_beams=generation_config.num_beams,
490
+ device=inputs_tensor.device,
491
+ length_penalty=generation_config.length_penalty,
492
+ do_early_stopping=generation_config.early_stopping,
493
+ num_beam_hyps_to_keep=generation_config.num_return_sequences,
494
+ )
495
+ # 12. interleave input_ids with `num_beams` additional sequences per batch
496
+ input_ids, model_kwargs = self._expand_inputs_for_generation(
497
+ input_ids=input_ids,
498
+ expand_size=generation_config.num_beams,
499
+ is_encoder_decoder=self.config.is_encoder_decoder,
500
+ **model_kwargs,
501
+ )
502
+ # 13. run beam search
503
+ return self.beam_search(
504
+ input_ids,
505
+ beam_scorer,
506
+ logits_processor=logits_processor,
507
+ stopping_criteria=stopping_criteria,
508
+ pad_token_id=generation_config.pad_token_id,
509
+ eos_token_id=generation_config.eos_token_id,
510
+ output_scores=generation_config.output_scores,
511
+ return_dict_in_generate=generation_config.return_dict_in_generate,
512
+ synced_gpus=synced_gpus,
513
+ **model_kwargs,
514
+ )
515
+
516
+ elif is_beam_sample_gen_mode:
517
+ # 11. prepare logits warper
518
+ logits_warper = self._get_logits_warper(generation_config)
519
+
520
+ if stopping_criteria.max_length is None:
521
+ raise ValueError(
522
+ "`max_length` needs to be a stopping_criteria for now."
523
+ )
524
+ # 12. prepare beam search scorer
525
+ beam_scorer = BeamSearchScorer(
526
+ batch_size=batch_size * generation_config.num_return_sequences,
527
+ num_beams=generation_config.num_beams,
528
+ device=inputs_tensor.device,
529
+ length_penalty=generation_config.length_penalty,
530
+ do_early_stopping=generation_config.early_stopping,
531
+ )
532
+
533
+ # 13. interleave input_ids with `num_beams` additional sequences per batch
534
+ input_ids, model_kwargs = self._expand_inputs_for_generation(
535
+ input_ids=input_ids,
536
+ expand_size=generation_config.num_beams
537
+ * generation_config.num_return_sequences,
538
+ is_encoder_decoder=self.config.is_encoder_decoder,
539
+ **model_kwargs,
540
+ )
541
+
542
+ # 14. run beam sample
543
+ return self.beam_sample(
544
+ input_ids,
545
+ beam_scorer,
546
+ logits_processor=logits_processor,
547
+ logits_warper=logits_warper,
548
+ stopping_criteria=stopping_criteria,
549
+ pad_token_id=generation_config.pad_token_id,
550
+ eos_token_id=generation_config.eos_token_id,
551
+ output_scores=generation_config.output_scores,
552
+ return_dict_in_generate=generation_config.return_dict_in_generate,
553
+ synced_gpus=synced_gpus,
554
+ **model_kwargs,
555
+ )
556
+
557
+ elif is_group_beam_gen_mode:
558
+ if generation_config.num_return_sequences > generation_config.num_beams:
559
+ raise ValueError(
560
+ "`num_return_sequences` has to be smaller or equal to `num_beams`."
561
+ )
562
+
563
+ if generation_config.num_beams % generation_config.num_beam_groups != 0:
564
+ raise ValueError(
565
+ "`num_beams` should be divisible by `num_beam_groups` for group beam search."
566
+ )
567
+
568
+ if stopping_criteria.max_length is None:
569
+ raise ValueError(
570
+ "`max_length` needs to be a stopping_criteria for now."
571
+ )
572
+
573
+ has_default_typical_p = (
574
+ kwargs.get("typical_p") is None and generation_config.typical_p == 1.0
575
+ )
576
+ if not has_default_typical_p:
577
+ raise ValueError(
578
+ "Decoder argument `typical_p` is not supported with beam groups."
579
+ )
580
+
581
+ # 11. prepare beam search scorer
582
+ beam_scorer = BeamSearchScorer(
583
+ batch_size=batch_size,
584
+ num_beams=generation_config.num_beams,
585
+ max_length=stopping_criteria.max_length,
586
+ device=inputs_tensor.device,
587
+ length_penalty=generation_config.length_penalty,
588
+ do_early_stopping=generation_config.early_stopping,
589
+ num_beam_hyps_to_keep=generation_config.num_return_sequences,
590
+ num_beam_groups=generation_config.num_beam_groups,
591
+ )
592
+ # 12. interleave input_ids with `num_beams` additional sequences per batch
593
+ input_ids, model_kwargs = self._expand_inputs_for_generation(
594
+ input_ids=input_ids,
595
+ expand_size=generation_config.num_beams,
596
+ is_encoder_decoder=self.config.is_encoder_decoder,
597
+ **model_kwargs,
598
+ )
599
+ # 13. run beam search
600
+ return self.group_beam_search(
601
+ input_ids,
602
+ beam_scorer,
603
+ logits_processor=logits_processor,
604
+ stopping_criteria=stopping_criteria,
605
+ pad_token_id=generation_config.pad_token_id,
606
+ eos_token_id=generation_config.eos_token_id,
607
+ output_scores=generation_config.output_scores,
608
+ return_dict_in_generate=generation_config.return_dict_in_generate,
609
+ synced_gpus=synced_gpus,
610
+ **model_kwargs,
611
+ )
612
+
613
+ elif is_constraint_gen_mode:
614
+ if generation_config.num_return_sequences > generation_config.num_beams:
615
+ raise ValueError(
616
+ "`num_return_sequences` has to be smaller or equal to `num_beams`."
617
+ )
618
+
619
+ if stopping_criteria.max_length is None:
620
+ raise ValueError(
621
+ "`max_length` needs to be a stopping_criteria for now."
622
+ )
623
+
624
+ if generation_config.num_beams <= 1:
625
+ raise ValueError(
626
+ "`num_beams` needs to be greater than 1 for constrained generation."
627
+ )
628
+
629
+ if generation_config.do_sample:
630
+ raise ValueError(
631
+ "`do_sample` needs to be false for constrained generation."
632
+ )
633
+
634
+ if (
635
+ generation_config.num_beam_groups is not None
636
+ and generation_config.num_beam_groups > 1
637
+ ):
638
+ raise ValueError(
639
+ "`num_beam_groups` not supported yet for constrained generation."
640
+ )
641
+
642
+ final_constraints = []
643
+ if generation_config.constraints is not None:
644
+ final_constraints = generation_config.constraints
645
+
646
+ if generation_config.force_words_ids is not None:
647
+
648
+ def typeerror():
649
+ raise ValueError(
650
+ "`force_words_ids` has to either be a `List[List[List[int]]]` or `List[List[int]]`"
651
+ f"of positive integers, but is {generation_config.force_words_ids}."
652
+ )
653
+
654
+ if (
655
+ not isinstance(generation_config.force_words_ids, list)
656
+ or len(generation_config.force_words_ids) == 0
657
+ ):
658
+ typeerror()
659
+
660
+ for word_ids in generation_config.force_words_ids:
661
+ if isinstance(word_ids[0], list):
662
+ if not isinstance(word_ids, list) or len(word_ids) == 0:
663
+ typeerror()
664
+ if any(
665
+ not isinstance(token_ids, list) for token_ids in word_ids
666
+ ):
667
+ typeerror()
668
+ if any(
669
+ any(
670
+ (not isinstance(token_id, int) or token_id < 0)
671
+ for token_id in token_ids
672
+ )
673
+ for token_ids in word_ids
674
+ ):
675
+ typeerror()
676
+
677
+ constraint = DisjunctiveConstraint(word_ids)
678
+ else:
679
+ if not isinstance(word_ids, list) or len(word_ids) == 0:
680
+ typeerror()
681
+ if any(
682
+ (not isinstance(token_id, int) or token_id < 0)
683
+ for token_id in word_ids
684
+ ):
685
+ typeerror()
686
+
687
+ constraint = PhrasalConstraint(word_ids)
688
+ final_constraints.append(constraint)
689
+
690
+ # 11. prepare beam search scorer
691
+ constrained_beam_scorer = ConstrainedBeamSearchScorer(
692
+ constraints=final_constraints,
693
+ batch_size=batch_size,
694
+ num_beams=generation_config.num_beams,
695
+ device=inputs_tensor.device,
696
+ length_penalty=generation_config.length_penalty,
697
+ do_early_stopping=generation_config.early_stopping,
698
+ num_beam_hyps_to_keep=generation_config.num_return_sequences,
699
+ )
700
+ # 12. interleave input_ids with `num_beams` additional sequences per batch
701
+ input_ids, model_kwargs = self._expand_inputs_for_generation(
702
+ input_ids=input_ids,
703
+ expand_size=generation_config.num_beams,
704
+ is_encoder_decoder=self.config.is_encoder_decoder,
705
+ **model_kwargs,
706
+ )
707
+ # 13. run beam search
708
+ return self.constrained_beam_search(
709
+ input_ids,
710
+ constrained_beam_scorer=constrained_beam_scorer,
711
+ logits_processor=logits_processor,
712
+ stopping_criteria=stopping_criteria,
713
+ pad_token_id=generation_config.pad_token_id,
714
+ eos_token_id=generation_config.eos_token_id,
715
+ output_scores=generation_config.output_scores,
716
+ return_dict_in_generate=generation_config.return_dict_in_generate,
717
+ synced_gpus=synced_gpus,
718
+ **model_kwargs,
719
+ )
720
+
721
+ @torch.no_grad()
722
+ def sample_stream(
723
+ self,
724
+ input_ids: torch.LongTensor,
725
+ logits_processor: Optional[LogitsProcessorList] = None,
726
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
727
+ logits_warper: Optional[LogitsProcessorList] = None,
728
+ max_length: Optional[int] = None,
729
+ pad_token_id: Optional[int] = None,
730
+ eos_token_id: Optional[Union[int, List[int]]] = None,
731
+ output_attentions: Optional[bool] = None,
732
+ output_hidden_states: Optional[bool] = None,
733
+ output_scores: Optional[bool] = None,
734
+ return_dict_in_generate: Optional[bool] = None,
735
+ synced_gpus: Optional[bool] = False,
736
+ **model_kwargs,
737
+ ) -> Union[SampleOutput, torch.LongTensor]:
738
+ r"""
739
+ Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and
740
+ can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
741
+
742
+ <Tip warning={true}>
743
+
744
+ In most cases, you do not need to call [`~generation.GenerationMixin.sample`] directly. Use generate() instead.
745
+ For an overview of generation strategies and code examples, check the [following
746
+ guide](./generation_strategies).
747
+
748
+ </Tip>
749
+
750
+ Parameters:
751
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
752
+ The sequence used as a prompt for the generation.
753
+ logits_processor (`LogitsProcessorList`, *optional*):
754
+ An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
755
+ used to modify the prediction scores of the language modeling head applied at each generation step.
756
+ stopping_criteria (`StoppingCriteriaList`, *optional*):
757
+ An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
758
+ used to tell if the generation loop should stop.
759
+ logits_warper (`LogitsProcessorList`, *optional*):
760
+ An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
761
+ to warp the prediction score distribution of the language modeling head applied before multinomial
762
+ sampling at each generation step.
763
+ max_length (`int`, *optional*, defaults to 20):
764
+ **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated
765
+ tokens. The maximum length of the sequence to be generated.
766
+ pad_token_id (`int`, *optional*):
767
+ The id of the *padding* token.
768
+ eos_token_id (`int`, *optional*):
769
+ The id of the *end-of-sequence* token.
770
+ output_attentions (`bool`, *optional*, defaults to `False`):
771
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
772
+ returned tensors for more details.
773
+ output_hidden_states (`bool`, *optional*, defaults to `False`):
774
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
775
+ for more details.
776
+ output_scores (`bool`, *optional*, defaults to `False`):
777
+ Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
778
+ return_dict_in_generate (`bool`, *optional*, defaults to `False`):
779
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
780
+ synced_gpus (`bool`, *optional*, defaults to `False`):
781
+ Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
782
+ model_kwargs:
783
+ Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
784
+ an encoder-decoder model the kwargs should include `encoder_outputs`.
785
+
786
+ Return:
787
+ [`~generation.SampleDecoderOnlyOutput`], [`~generation.SampleEncoderDecoderOutput`] or `torch.LongTensor`:
788
+ A `torch.LongTensor` containing the generated tokens (default behaviour) or a
789
+ [`~generation.SampleDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
790
+ `return_dict_in_generate=True` or a [`~generation.SampleEncoderDecoderOutput`] if
791
+ `model.config.is_encoder_decoder=True`.
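+ Unlike `sample`, this streaming variant is a generator: it yields a
+ `(next_tokens, hidden_state)` tuple at every decoding step (see the `yield`
+ inside the sampling loop below) rather than returning a single tensor.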
792
+
793
+ Examples:
794
+
795
+ ```python
796
+ >>> from transformers import (
797
+ ... AutoTokenizer,
798
+ ... AutoModelForCausalLM,
799
+ ... LogitsProcessorList,
800
+ ... MinLengthLogitsProcessor,
801
+ ... TopKLogitsWarper,
802
+ ... TemperatureLogitsWarper,
803
+ ... StoppingCriteriaList,
804
+ ... MaxLengthCriteria,
805
+ ... )
806
+ >>> import torch
807
+
808
+ >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
809
+ >>> model = AutoModelForCausalLM.from_pretrained("gpt2")
810
+
811
+ >>> # set pad_token_id to eos_token_id because GPT2 does not have a PAD token
812
+ >>> model.config.pad_token_id = model.config.eos_token_id
813
+ >>> model.generation_config.pad_token_id = model.config.eos_token_id
814
+
815
+ >>> input_prompt = "Today is a beautiful day, and"
816
+ >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids
817
+
818
+ >>> # instantiate logits processors
819
+ >>> logits_processor = LogitsProcessorList(
820
+ ... [
821
+ ... MinLengthLogitsProcessor(15, eos_token_id=model.generation_config.eos_token_id),
822
+ ... ]
823
+ ... )
824
+ >>> # instantiate logits processors
825
+ >>> logits_warper = LogitsProcessorList(
826
+ ... [
827
+ ... TopKLogitsWarper(50),
828
+ ... TemperatureLogitsWarper(0.7),
829
+ ... ]
830
+ ... )
831
+
832
+ >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])
833
+
834
+ >>> torch.manual_seed(0) # doctest: +IGNORE_RESULT
835
+ >>> outputs = model.sample(
836
+ ... input_ids,
837
+ ... logits_processor=logits_processor,
838
+ ... logits_warper=logits_warper,
839
+ ... stopping_criteria=stopping_criteria,
840
+ ... )
841
+
842
+ >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
843
+ ['Today is a beautiful day, and a wonderful day.\n\nI was lucky enough to meet the']
844
+ ```"""
845
+ # init values
846
+ logits_processor = (
847
+ logits_processor if logits_processor is not None else LogitsProcessorList()
848
+ )
849
+ stopping_criteria = (
850
+ stopping_criteria
851
+ if stopping_criteria is not None
852
+ else StoppingCriteriaList()
853
+ )
854
+ if max_length is not None:
855
+ warnings.warn(
856
+ "`max_length` is deprecated in this function, use"
857
+ " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
858
+ UserWarning,
859
+ )
860
+ stopping_criteria = validate_stopping_criteria(
861
+ stopping_criteria, max_length
862
+ )
863
+ logits_warper = (
864
+ logits_warper if logits_warper is not None else LogitsProcessorList()
865
+ )
866
+ pad_token_id = (
867
+ pad_token_id
868
+ if pad_token_id is not None
869
+ else self.generation_config.pad_token_id
870
+ )
871
+ eos_token_id = (
872
+ eos_token_id
873
+ if eos_token_id is not None
874
+ else self.generation_config.eos_token_id
875
+ )
876
+ if isinstance(eos_token_id, int):
877
+ eos_token_id = [eos_token_id]
878
+ output_scores = (
879
+ output_scores
880
+ if output_scores is not None
881
+ else self.generation_config.output_scores
882
+ )
883
+ output_attentions = (
884
+ output_attentions
885
+ if output_attentions is not None
886
+ else self.generation_config.output_attentions
887
+ )
888
+ output_hidden_states = (
889
+ output_hidden_states
890
+ if output_hidden_states is not None
891
+ else self.generation_config.output_hidden_states
892
+ )
893
+ return_dict_in_generate = (
894
+ return_dict_in_generate
895
+ if return_dict_in_generate is not None
896
+ else self.generation_config.return_dict_in_generate
897
+ )
898
+
899
+ # init attention / hidden states / scores tuples
900
+ scores = () if (return_dict_in_generate and output_scores) else None
901
+ decoder_attentions = (
902
+ () if (return_dict_in_generate and output_attentions) else None
903
+ )
904
+ cross_attentions = (
905
+ () if (return_dict_in_generate and output_attentions) else None
906
+ )
907
+ decoder_hidden_states = (
908
+ () if (return_dict_in_generate and output_hidden_states) else None
909
+ )
910
+
911
+ # keep track of which sequences are already finished
912
+ unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
913
+
914
+ this_peer_finished = False # used by synced_gpus only
915
+ # auto-regressive generation
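+ # Each pass through this loop runs one forward step, applies the logits
+ # processors/warpers, samples a single token, yields it, then appends it to
+ # input_ids and updates the cached model kwargs for the next step.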
916
+ while True:
917
+ if synced_gpus:
918
+ # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
919
+ # The following logic allows an early break if all peers finished generating their sequence
920
+ this_peer_finished_flag = torch.tensor(
921
+ 0.0 if this_peer_finished else 1.0
922
+ ).to(input_ids.device)
923
+ # send 0.0 if we finished, 1.0 otherwise
924
+ dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
925
+ # did all peers finish? the reduced sum will be 0.0 then
926
+ if this_peer_finished_flag.item() == 0.0:
927
+ break
928
+
929
+ # prepare model inputs
930
+ model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
931
+
932
+ # forward pass to get next token
933
+ outputs = self(
934
+ **model_inputs,
935
+ return_dict=True,
936
+ output_attentions=output_attentions,
937
+ output_hidden_states=output_hidden_states,
938
+ )
939
+
940
+ if synced_gpus and this_peer_finished:
941
+ continue # don't waste resources running the code we don't need
942
+
943
+ next_token_logits = outputs.logits[:, -1, :]
944
+
945
+ # pre-process distribution
946
+ next_token_scores = logits_processor(input_ids, next_token_logits)
947
+ next_token_scores = logits_warper(input_ids, next_token_scores)
948
+
949
+ # Store scores, attentions and hidden_states when required
950
+ if return_dict_in_generate:
951
+ if output_scores:
952
+ scores += (next_token_scores,)
953
+ if output_attentions:
954
+ decoder_attentions += (
955
+ (outputs.decoder_attentions,)
956
+ if self.config.is_encoder_decoder
957
+ else (outputs.attentions,)
958
+ )
959
+ if self.config.is_encoder_decoder:
960
+ cross_attentions += (outputs.cross_attentions,)
961
+
962
+ if output_hidden_states:
963
+ decoder_hidden_states += (
964
+ (outputs.decoder_hidden_states,)
965
+ if self.config.is_encoder_decoder
966
+ else (outputs.hidden_states,)
967
+ )
968
+
969
+ # sample
970
+ probs = nn.functional.softmax(next_token_scores, dim=-1)
971
+ next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
972
+
973
+ # finished sentences should have their next token be a padding token
974
+ if eos_token_id is not None:
975
+ if pad_token_id is None:
976
+ raise ValueError(
977
+ "If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
978
+ )
979
+ next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
980
+ 1 - unfinished_sequences
981
+ )
982
+ yield next_tokens, self.final_norm(outputs.hidden_states[-1][:, -1])
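+ # Stream the freshly sampled token ids together with the final-normed last
+ # hidden state; downstream, Tortoise presumably uses this latent to decode
+ # audio incrementally (the streaming HiFi-decoder path).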
983
+ # update generated ids, model inputs, and length for next step
984
+ input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
985
+ model_kwargs = self._update_model_kwargs_for_generation(
986
+ outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
987
+ )
988
+
989
+ # if eos_token was found in one sentence, set sentence to finished
990
+ if eos_token_id is not None:
991
+ unfinished_sequences = unfinished_sequences.mul(
992
+ (sum(next_tokens != i for i in eos_token_id)).long()
993
+ )
994
+
995
+ # stop when each sentence is finished, or if we exceed the maximum length
996
+ if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
997
+ if not synced_gpus:
998
+ break
999
+ else:
1000
+ this_peer_finished = True
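+ # When every sequence is finished (or a stopping criterion fires) the loop
+ # breaks and the generator is simply exhausted; there is no final return value.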
1001
+
1002
+
1003
+ def init_stream_support():
1004
+ """Overload PreTrainedModel for streaming."""
1005
+ PreTrainedModel.generate_stream = NewGenerationMixin.generate
1006
+ PreTrainedModel.sample_stream = NewGenerationMixin.sample_stream
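+ # Typical usage (a sketch): call init_stream_support() once, then iterate over
+ # model.generate_stream(..., do_stream=True) to consume the
+ # (next_tokens, hidden_state) tuples as they are produced.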
1007
+
1008
+
1009
+ if __name__ == "__main__":
1010
+ from transformers import PreTrainedModel
1011
+ from transformers import AutoTokenizer, AutoModelForCausalLM
1012
+
1013
+ PreTrainedModel.generate = NewGenerationMixin.generate
1014
+ PreTrainedModel.sample_stream = NewGenerationMixin.sample_stream
1015
+ model = AutoModelForCausalLM.from_pretrained(
1016
+ "bigscience/bloom-560m", torch_dtype=torch.float16
1017
+ )
1018
+
1019
+ tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
1020
+ model = model.to("cuda:0")
1021
+ model = model.eval()
1022
+ prompt_text = "hello? \n"
1023
+ input_ids = tokenizer(
1024
+ prompt_text, return_tensors="pt", add_special_tokens=False
1025
+ ).input_ids
1026
+ input_ids = input_ids.to("cuda:0")
1027
+
1028
+ with torch.no_grad():
1029
+ result = model.generate(
1030
+ input_ids,
1031
+ max_new_tokens=200,
1032
+ do_sample=True,
1033
+ top_k=30,
1034
+ top_p=0.85,
1035
+ temperature=0.35,
1036
+ repetition_penalty=1.2,
1037
+ early_stopping=True,
1038
+ seed=0,
1039
+ )
1040
+ print(tokenizer.decode(result[0], skip_special_tokens=True))  # generate returns a (batch, seq_len) tensor; decode the first sequence
1041
+ generator = model.generate(
1042
+ input_ids,
1043
+ max_new_tokens=200,
1044
+ do_sample=True,
1045
+ top_k=30,
1046
+ top_p=0.85,
1047
+ temperature=0.35,
1048
+ repetition_penalty=1.2,
1049
+ early_stopping=True,
1050
+ seed=0,
1051
+ do_stream=True,
1052
+ )
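+ # With do_stream=True, the patched generate should dispatch to sample_stream
+ # above and return a generator rather than a tensor.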
1053
+ stream_result = ""
1054
+ for x in generator:
1055
+ chunk = tokenizer.decode(x[0], skip_special_tokens=True)  # x is a (token_ids, hidden_state) tuple yielded by sample_stream
1056
+ stream_result += chunk
1057
+ print(stream_result)