Spaces:
Runtime error
Runtime error
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. | |
from typing import Any, Literal, Optional | |
import torch | |
# import torch._dynamo.config | |
# import torch._inductor.config | |
from litgpt.model import GPT | |
from utils.snac_utils import layershift, snac_config | |
from tqdm import tqdm | |
def multinomial_num_samples_1(probs: torch.Tensor) -> torch.Tensor:
    """Draw a single index from each probability distribution in ``probs``.

    In eager mode this is a plain ``torch.multinomial(probs, num_samples=1)``.
    While ``torch._dynamo`` is compiling, the probabilities are instead divided
    by Exponential(1) noise and the argmax is taken; the original author notes
    this as a faster, CUDAGraph-friendly alternative to ``torch.multinomial``.

    Args:
        probs: Probabilities to sample from (last dimension is the category axis).

    Returns:
        The sampled indices, with a trailing dimension of size 1.
    """
    if not torch._dynamo.is_compiling():
        return torch.multinomial(probs, num_samples=1)

    # Fresh Exponential(1) noise with the same shape/dtype/device as `probs`.
    noise = torch.empty_like(probs).exponential_(1)
    return torch.argmax(probs / noise, dim=-1, keepdim=True)
def sample_top_p(logits_A: torch.Tensor, top_p: float) -> torch.Tensor:
    """Mask the low-probability tail of ``logits_A`` for top-p (nucleus) sampling.

    The logits are sorted in ascending order and converted to cumulative
    probabilities. Every entry whose cumulative probability is at most
    ``1 - top_p`` is replaced by ``-inf``; the most probable entry is always
    kept so that at least one token remains selectable.

    Args:
        logits_A: Logits over the vocabulary. The vocabulary is taken to be the
            last dimension, so both 1-D and batched inputs are handled.
        top_p: Cumulative probability mass to keep, expected in ``[0, 1]``.

    Returns:
        A new tensor shaped like ``logits_A`` with the removed entries set to
        ``-inf``. The input tensor is not modified.
    """
    sorted_logits, sorted_indices = torch.sort(logits_A, descending=False)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
    # Example:
    # sorted_probs=[0.1, 0.15, 0.2, 0.25, 0.3] -> sorted_cumprobs=[0.1, 0.25, 0.45, 0.7, 1.0]
    # sorted_indices_to_remove = [1, 1, 0, 0, 0] if top_p=0.7
    sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
    # Keep at least 1 token always to prevent the case where no token is selected.
    # With ascending order the last position holds the most probable token.
    sorted_indices_to_remove[..., -1:] = False
    # Map the mask from sorted order back to the original vocabulary order.
    indices_to_remove = sorted_indices_to_remove.scatter(
        -1, sorted_indices, sorted_indices_to_remove
    )
    return logits_A.masked_fill(indices_to_remove, float("-inf"))
def sample(
    logits: torch.Tensor,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    top_p: float = 1.0,
) -> torch.Tensor:
    """Pick the next token id from the logits found at index ``[0, -1]``.

    Processing order: optional top-k cropping, optional temperature scaling,
    optional top-p cropping, then a multinomial draw. When both ``temperature``
    and ``top_p`` are not positive the arg-max is returned instead of sampling.

    Args:
        logits: Model output; only ``logits[0, -1]`` is used.
        temperature: Divisor applied to the logits when greater than zero.
        top_k: If given, only the ``top_k`` largest logits stay eligible.
        top_p: Cumulative-probability threshold for nucleus sampling.

    Returns:
        A tensor holding the chosen index (trailing dimension of size 1).

    Raises:
        ValueError: If ``top_p`` lies outside ``[0, 1]``.
    """
    if top_p > 1.0 or top_p < 0.0:
        raise ValueError(f"top_p must be in [0, 1], got {top_p}")

    candidates = logits[0, -1]

    if top_k is not None:
        k = min(top_k, candidates.size(-1))
        top_values, top_indices = torch.topk(candidates, k)
        # Scatter into a -inf canvas rather than using `torch.where`, which
        # would keep every entry tied with the k-th value.
        canvas = torch.full_like(candidates, float("-inf"))
        candidates = canvas.scatter_(-1, top_indices, top_values)

    # Greedy decoding only when neither temperature nor top_p is positive.
    if not (temperature > 0.0 or top_p > 0.0):
        return torch.argmax(candidates, dim=-1, keepdim=True)

    if temperature > 0.0:
        candidates = candidates / temperature
    if top_p < 1.0:
        # Restrict to the smallest set of logits whose cumulative probability exceeds top_p.
        candidates = sample_top_p(candidates, top_p)
    probabilities = torch.nn.functional.softmax(candidates, dim=-1)
    return multinomial_num_samples_1(probabilities)
def next_token(
    model: GPT, input_pos: torch.Tensor, x: list, **kwargs: Any
) -> torch.Tensor:
    """Run the model once and sample one token from every set of logits it returns.

    ``model(x, input_pos)`` must return a pair ``(logits_a, logit_t)`` with
    ``logits_a`` iterable. Each element of ``logits_a`` and then ``logit_t``
    is passed to ``sample`` together with ``kwargs``.

    Args:
        model: Callable model exposing a ``device`` attribute.
        input_pos: Positions for this step; moved to ``model.device``.
        x: List of input tensors; ``x[0].dtype`` is the dtype of the results.
        **kwargs: Sampling options forwarded to ``sample``.

    Returns:
        A tuple ``(next_audio_tokens, next_t)``: the list of tokens sampled from
        ``logits_a`` (in order) and the token sampled from ``logit_t``.
        NOTE(review): the ``torch.Tensor`` return annotation does not match this
        tuple; it is left unchanged here.
    """
    logits_a, logit_t = model(x, input_pos.to(model.device))
    token_dtype = x[0].dtype
    next_audio_tokens = [
        sample(head_logits, **kwargs).to(dtype=token_dtype) for head_logits in logits_a
    ]
    next_t = sample(logit_t, **kwargs).to(dtype=token_dtype)
    return next_audio_tokens, next_t
def next_token_asr(
    model: GPT,
    input_pos: torch.tensor,
    audio_features: torch.tensor,
    lens: int,
    input_ids: list,
    **kwargs: Any,
) -> torch.Tensor:
    """Run one forward pass with audio features and sample a token per returned logits set.

    Calls ``model(audio_features, input_ids, input_pos, whisper_lens=lens)``
    after moving ``input_pos`` and every tensor in ``input_ids`` to
    ``model.device``, then samples from each element of the returned
    ``logits_a`` and from ``logit_t``.

    Args:
        model: Callable model exposing a ``device`` attribute.
        input_pos: Positions for this step.
        audio_features: Passed through to the model unchanged.
        lens: Forwarded to the model as ``whisper_lens``.
        input_ids: List of input tensors; the first one's dtype is used for the results.
        **kwargs: Sampling options forwarded to ``sample``.

    Returns:
        A tuple ``(next_audio_tokens, next_t)`` (list of sampled tokens, and the
        token sampled from ``logit_t``).
    """
    target_device = model.device
    positions = input_pos.to(target_device)
    input_ids = [ids.to(target_device) for ids in input_ids]
    logits_a, logit_t = model(audio_features, input_ids, positions, whisper_lens=lens)

    token_dtype = input_ids[0].dtype
    next_audio_tokens = [
        sample(head_logits, **kwargs).to(dtype=token_dtype) for head_logits in logits_a
    ]
    next_t = sample(logit_t, **kwargs).to(dtype=token_dtype)
    return next_audio_tokens, next_t
def next_token_A1T2(
    model: GPT,
    audio_features: torch.tensor,
    input_ids: list,
    whisper_lens: int,
    task: list,
    input_pos: torch.Tensor,
    **kwargs: Any,
) -> torch.Tensor:
    """Run one forward pass and sample a token from every returned logits set.

    ``input_pos`` and every tensor in ``input_ids`` are moved to
    ``model.device``; the model is then called as
    ``model(audio_features, input_ids, input_pos, whisper_lens=..., task=...)``.

    Args:
        model: Callable model exposing a ``device`` attribute.
        audio_features: Passed through to the model unchanged (callers in this
            file pass ``None`` after the first step).
        input_ids: List of input tensors; the first one's dtype is used for the results.
        whisper_lens: Forwarded to the model unchanged.
        task: Forwarded to the model unchanged.
        input_pos: Positions for this step.
        **kwargs: Sampling options forwarded to ``sample``.

    Returns:
        A tuple ``(next_audio_tokens, next_t)``: tokens sampled from each
        element of ``logits_a`` (in order) and the token sampled from ``logit_t``.
    """
    target_device = model.device
    positions = input_pos.to(target_device)
    input_ids = [ids.to(target_device) for ids in input_ids]
    logits_a, logit_t = model(
        audio_features, input_ids, positions, whisper_lens=whisper_lens, task=task
    )

    token_dtype = input_ids[0].dtype
    next_audio_tokens = [
        sample(head_logits, **kwargs).to(dtype=token_dtype) for head_logits in logits_a
    ]
    next_t = sample(logit_t, **kwargs).to(dtype=token_dtype)
    return next_audio_tokens, next_t
def next_token_A1T1(
    model: GPT,
    audio_features: torch.tensor,
    input_ids: list,
    whisper_lens: int,
    task: list,
    input_pos: torch.Tensor,
    **kwargs: Any,
) -> torch.Tensor:
    """Run one forward pass and sample only from the second returned output.

    The model is called exactly as in ``next_token_A1T2``; the first element of
    its result is discarded and a single token is sampled from the second
    (``logit_t``).

    Args:
        model: Callable model exposing a ``device`` attribute.
        audio_features: Passed through to the model unchanged.
        input_ids: List of input tensors, moved to ``model.device``; the first
            one's dtype is used for the result.
        whisper_lens: Forwarded to the model unchanged.
        task: Forwarded to the model unchanged.
        input_pos: Positions for this step, moved to ``model.device``.
        **kwargs: Sampling options forwarded to ``sample``.

    Returns:
        The token sampled from ``logit_t``, cast to ``input_ids[0].dtype``.
    """
    target_device = model.device
    positions = input_pos.to(target_device)
    input_ids = [ids.to(target_device) for ids in input_ids]
    _, logit_t = model(
        audio_features, input_ids, positions, whisper_lens=whisper_lens, task=task
    )
    return sample(logit_t, **kwargs).to(dtype=input_ids[0].dtype)
def next_token_batch(
    model: GPT,
    audio_features: torch.tensor,
    input_ids: list,
    whisper_lens: int,
    task: list,
    input_pos: torch.Tensor,
    **kwargs: Any,
) -> torch.Tensor:
    """Run one forward pass on a two-row batch and sample from selected rows.

    After the forward pass, row 0 of each of the first seven entries of
    ``logits_a`` and row 1 of ``logit_t`` are kept (each re-wrapped with a
    leading dimension of size 1) and used for sampling.
    NOTE(review): presumably row 0 carries the audio stream and row 1 the text
    stream of the batch — confirm against the model.

    Args:
        model: Callable model exposing a ``device`` attribute.
        audio_features: Passed through to the model unchanged.
        input_ids: List of input tensors, moved to ``model.device``; the first
            one's dtype is used for the results.
        whisper_lens: Forwarded to the model unchanged.
        task: Forwarded to the model unchanged.
        input_pos: Positions for this step, moved to ``model.device``.
        **kwargs: Sampling options forwarded to ``sample``.

    Returns:
        A tuple ``(next_audio_tokens, next_t)``: tokens sampled from every
        element of ``logits_a`` (in order) and the token sampled from ``logit_t``.
    """
    target_device = model.device
    positions = input_pos.to(target_device)
    input_ids = [ids.to(target_device) for ids in input_ids]
    logits_a, logit_t = model(
        audio_features, input_ids, positions, whisper_lens=whisper_lens, task=task
    )

    # Keep only batch row 0 for the first seven entries (updated in place).
    for layer in range(7):
        logits_a[layer] = logits_a[layer][0].unsqueeze(0)
    # Keep only batch row 1 for the remaining output.
    logit_t = logit_t[1].unsqueeze(0)

    token_dtype = input_ids[0].dtype
    next_audio_tokens = [
        sample(head_logits, **kwargs).to(dtype=token_dtype) for head_logits in logits_a
    ]
    next_t = sample(logit_t, **kwargs).to(dtype=token_dtype)
    return next_audio_tokens, next_t
# torch._dynamo.config.automatic_dynamic_shapes = True | |
# torch._inductor.config.triton.unique_kernel_names = True | |
# torch._inductor.config.coordinate_descent_tuning = True | |
# next_token = torch.compile(next_token, mode="reduce-overhead") | |
def generate(
    model: GPT,
    input_ids: list,
    max_returned_tokens: int,
    *,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    top_p: float = 1.0,
    eos_id_a: Optional[int] = None,
    eos_id_t: Optional[int] = None,
    pad_id: Optional[int] = None,
    shift: Optional[int] = None,
    include_prompt: bool = True,
    generate_text=False,
) -> torch.Tensor:
    """
    Takes eight conditioning sequences (prompt) as input and continues to generate tokens for each of them.

    The implementation of this function is modified from A. Karpathy's nanoGPT.

    Args:
        model: The model to use.
        input_ids: List of exactly eight 1-D prompt tensors of equal length: seven streams that are
            handled as audio layers followed by one text stream.
        max_returned_tokens: The maximum number of tokens to return per stream (given plus generated).
            Must be greater than the prompt length and at most ``model.max_seq_length + 1``.
        temperature: Scales the predicted logits by 1 / temperature.
        top_k: If specified, only sample among the tokens with the k highest probabilities.
        top_p: If specified, it represents the cumulative probability threshold to consider in the sampling process.
            In top-p sampling, the next token is sampled from the highest probability tokens
            whose cumulative probability exceeds the threshold `top_p`. When specified,
            it must be `0 <= top_p <= 1`. Here, `top_p=0` is equivalent
            to sampling the most probable token, while `top_p=1` samples from the whole distribution.
            It can be used in conjunction with `top_k` and `temperature` with the following order
            of application:

            1. `top_k` sampling
            2. `temperature` scaling
            3. `top_p` sampling

            For more details, see https://arxiv.org/abs/1904.09751
            or https://huyenchip.com/2024/01/16/sampling.html#top_p
        eos_id_a: If specified, stop generating once the last audio stream produces this token.
        eos_id_t: If specified, marks the end of the text stream. Generation stops there when
            ``generate_text`` is true; otherwise later text tokens are replaced by ``pad_id``.
        pad_id: Token written to the text stream after ``eos_id_t`` has been produced.
        shift: Offset added to every sampled audio token before it is fed back to the model
            (together with ``i * snac_config.padded_vocab_size`` for stream ``i``). Must be an
            integer whenever generation runs, because it is added to the sampled tokens.
        include_prompt: Accepted for interface compatibility; not used by this function
            (the prompt is always part of the returned tensors).
        generate_text: If true, stop as soon as the text stream produces ``eos_id_t``.

    Returns:
        A list of eight 1-D tensors (seven audio streams, then text), each being the prompt
        followed by the generated tokens. Audio tokens are stored as sampled, i.e. before
        ``shift`` is applied.
    """
    T = input_ids[0].size(0)
    device = input_ids[0].device
    assert max_returned_tokens > T
    if model.max_seq_length < max_returned_tokens - 1:
        # rolling the kv cache based on the `input_pos` value would be necessary. However, doing so would introduce a
        # data dependency on the `input_pos` tensor and impact model compilation. Since this setting is uncommon, we do
        # not support it to avoid negatively impacting the overall speed
        raise NotImplementedError(
            f"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}"
        )
    if len(input_ids) != 8:
        raise ValueError(f"expected 8 input streams (7 audio + 1 text), got {len(input_ids)}")

    # One output list per stream, each starting with its prompt.
    list_output = [[stream] for stream in input_ids]
    model_input_ids = [stream.view(1, -1) for stream in input_ids]

    # First step: feed the whole prompt at positions 0..T-1.
    tokens_A, token_T = next_token(
        model,
        torch.arange(0, T, device=device),
        model_input_ids,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
    )
    for i in range(7):
        list_output[i].append(tokens_A[i].clone())
    list_output[7].append(token_T.clone())

    # prepare the input for the next iteration
    for i in range(7):
        tokens_A[i] = tokens_A[i].clone() + shift + i * snac_config.padded_vocab_size
    token_T = token_T.clone()

    input_pos = torch.tensor([T], device=device)
    text_end = False
    # The caller's `max_returned_tokens` bounds the loop, so the max_seq_length
    # check above stays valid for every position that is generated.
    for _ in tqdm(range(2, max_returned_tokens - T + 1)):
        model_input_ids = [
            token_a.view(1, -1).to(torch.int32) for token_a in tokens_A
        ] + [token_T.view(1, -1).to(torch.int32)]
        tokens_A, token_T = next_token(
            model,
            input_pos,
            model_input_ids,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
        )
        if text_end:
            # The text stream already ended: keep feeding/recording padding.
            token_T = torch.tensor([pad_id], device=device)

        for i in range(7):
            list_output[i].append(tokens_A[i].clone())
        list_output[7].append(token_T.clone())

        if tokens_A[-1] == eos_id_a:
            break
        if token_T == eos_id_t:
            if generate_text:
                break
            text_end = True

        for i in range(7):
            tokens_A[i] = tokens_A[i].clone() + shift + i * snac_config.padded_vocab_size
        token_T = token_T.clone()
        input_pos = input_pos.add_(1)

    return [torch.cat(stream_tokens) for stream_tokens in list_output]
def generate_TA_BATCH(
    model: GPT,
    audio_features: torch.Tensor,
    input_ids: list,
    leng,
    task,
    max_returned_tokens: int = 1000,
    *,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    top_p: float = 1.0,
    eos_id_a: Optional[int] = None,
    eos_id_t: Optional[int] = None,
    pad_id_t: Optional[int] = None,
    shift: Optional[int] = None,
    include_prompt: bool = True,
    generate_text=False,
) -> torch.Tensor:
    """Generate seven audio streams and one text stream with a two-row batch.

    The first forward pass uses ``audio_features`` with whisper lengths
    ``[T - 3, T - 3]`` and task tags ``["A1T2", "A1T2"]``. Every later step
    feeds a two-row input per stream: for audio stream ``i`` the shifted
    sampled token stacked with ``layershift(snac_config.end_of_audio, i)``,
    and for the text stream the sampled text token stacked with itself.

    Args:
        model: The model to use; must expose ``device`` and ``max_seq_length``.
        audio_features: Audio features for the first step (cast to float32).
        input_ids: List of prompt tensors; ``input_ids[0].size(1)`` is the prompt length ``T``.
        leng: Not used by this function.
        task: Not used by this function.
        max_returned_tokens: Upper bound on prompt plus generated length.
        temperature: Sampling temperature, forwarded to the sampler.
        top_k: Top-k cutoff, forwarded to the sampler.
        top_p: Top-p cutoff, forwarded to the sampler.
        eos_id_a: Generation stops when the last audio stream samples this token.
        eos_id_t: Once the text stream samples this token, later text tokens
            are replaced by ``pad_id_t``.
        pad_id_t: Replacement text token used after ``eos_id_t``.
        shift: Offset added to sampled audio tokens before feeding them back.
        include_prompt: Not used by this function.
        generate_text: Not used by this function.

    Returns:
        A list of eight Python lists (seven audio streams, then text) with the
        generated token ids; the prompt is not included.

    Raises:
        NotImplementedError: If ``model.max_seq_length < max_returned_tokens - 1``.
    """
    T = input_ids[0].size(1)
    device = input_ids[0].device
    assert max_returned_tokens > T
    if model.max_seq_length < max_returned_tokens - 1:
        raise NotImplementedError(
            f"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}"
        )

    sampling = dict(temperature=temperature, top_k=top_k, top_p=top_p)
    list_output = [[] for _ in range(8)]

    def record(audio_tokens, text_token):
        # Store the raw sampled ids (before any shift) as plain Python ints.
        for layer in range(7):
            list_output[layer].append(audio_tokens[layer].tolist()[0])
        list_output[7].append(text_token.tolist()[0])

    def build_step_inputs(audio_tokens, text_token):
        # Two rows per stream: [sampled token, end-of-audio] for audio streams,
        # and the text token repeated for the text stream.
        step_inputs = []
        for layer in range(7):
            shifted = (
                audio_tokens[layer].clone() + shift + layer * snac_config.padded_vocab_size
            )
            row_pair = [
                shifted.clone().to(device).to(torch.int32),
                torch.tensor([layershift(snac_config.end_of_audio, layer)], device=device),
            ]
            step_inputs.append(torch.stack(row_pair))
        text_pair = [
            text_token.clone().to(torch.int32),
            text_token.clone().to(torch.int32),
        ]
        step_inputs.append(torch.stack(text_pair))
        return step_inputs

    # First step: whole prompt plus audio features.
    tokens_A, token_T = next_token_batch(
        model,
        audio_features.to(torch.float32).to(model.device),
        input_ids,
        [T - 3, T - 3],
        ["A1T2", "A1T2"],
        input_pos=torch.arange(0, T, device=device),
        **sampling,
    )
    record(tokens_A, token_T)
    model_input_ids = build_step_inputs(tokens_A, token_T)

    input_pos = torch.tensor([T], device=device)
    text_end = False
    for _ in range(2, max_returned_tokens - T + 1):
        tokens_A, token_T = next_token_batch(
            model,
            None,
            model_input_ids,
            None,
            None,
            input_pos=input_pos,
            **sampling,
        )
        if text_end:
            token_T = torch.tensor([pad_id_t], device=device)
        if tokens_A[-1] == eos_id_a:
            break
        if token_T == eos_id_t:
            text_end = True

        record(tokens_A, token_T)
        model_input_ids = build_step_inputs(tokens_A, token_T)
        input_pos.add_(1)

    return list_output
def generate_TT(
    model: GPT,
    audio_features: torch.Tensor,
    input_ids: list,
    leng,
    task,
    max_returned_tokens: int = 2048,
    *,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    top_p: float = 1.0,
    eos_id_a: Optional[int] = None,
    eos_id_t: Optional[int] = None,
    pad_id_t: Optional[int] = None,
    shift: Optional[int] = None,
    include_prompt: bool = True,
    generate_text=False,
) -> torch.Tensor:
    """Generate a text-only continuation without passing audio features to the model.

    Every forward pass is made through ``next_token_A1T1`` with ``None`` for
    the audio features, whisper lengths and task. After the prompt step, the
    seven audio inputs are filled with ``layershift(snac_config.end_of_audio, i)``
    and only the text token is fed back.

    Args:
        model: The model to use.
        audio_features: Not used by this function.
        input_ids: List of prompt tensors; ``input_ids[0].size(1)`` is the prompt length ``T``.
        leng: Not used by this function.
        task: Not used by this function.
        max_returned_tokens: Upper bound on prompt plus generated length.
        temperature: Sampling temperature, forwarded to the sampler.
        top_k: Top-k cutoff, forwarded to the sampler.
        top_p: Top-p cutoff, forwarded to the sampler.
        eos_id_a: Not used by this function.
        eos_id_t: Generation stops (without recording it) when a token sampled
            after the first step equals this id.
        pad_id_t: Not used by this function.
        shift: Not used by this function.
        include_prompt: Not used by this function.
        generate_text: Not used by this function.

    Returns:
        A Python list with the generated text token ids (prompt excluded). The
        first sampled token is always recorded, even if it equals ``eos_id_t``.
    """
    T = input_ids[0].size(1)
    device = input_ids[0].device
    sampling = dict(temperature=temperature, top_k=top_k, top_p=top_p)

    def as_step_input(value):
        # Shape (1, n), int32, on the prompt's device.
        return value.view(1, -1).to(torch.int32).to(device)

    # Prompt step.
    token_T = next_token_A1T1(
        model,
        None,
        input_ids,
        None,
        None,
        input_pos=torch.arange(0, T, device=device),
        **sampling,
    )
    output = [token_T.clone().tolist()[0]]

    input_pos = torch.tensor([T], device=device)
    for _ in tqdm(range(2, max_returned_tokens - T + 1)):
        model_input_ids = [
            as_step_input(torch.tensor([layershift(snac_config.end_of_audio, layer)]))
            for layer in range(7)
        ]
        model_input_ids.append(as_step_input(token_T.clone()))
        token_T = next_token_A1T1(
            model,
            None,
            model_input_ids,
            None,
            None,
            input_pos=input_pos,
            **sampling,
        )
        if token_T == eos_id_t:
            break
        output.append(token_T.clone().tolist()[0])
        input_pos.add_(1)
    return output
def generate_AT(
    model: GPT,
    audio_features: torch.Tensor,
    input_ids: list,
    leng,
    task,
    max_returned_tokens: int = 2048,
    *,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    top_p: float = 1.0,
    eos_id_a: Optional[int] = None,
    eos_id_t: Optional[int] = None,
    pad_id_t: Optional[int] = None,
    shift: Optional[int] = None,
    include_prompt: bool = True,
    generate_text=False,
) -> torch.Tensor:
    """Generate text tokens conditioned on audio features, using task tag ``"AT"``.

    The first forward pass receives ``audio_features`` (cast to float32 and
    moved to ``model.device``), whisper length ``[T - 3]`` and task ``["AT"]``.
    Later steps pass ``None`` for those three arguments, fill the seven audio
    inputs with ``layershift(snac_config.end_of_audio, i)`` and feed back only
    the text token.

    Args:
        model: The model to use.
        audio_features: Audio features for the first step.
        input_ids: List of prompt tensors; ``input_ids[0].size(1)`` is the prompt length ``T``.
        leng: Not used by this function.
        task: Not used by this function (the tag is hard-coded).
        max_returned_tokens: Upper bound on prompt plus generated length.
        temperature: Sampling temperature, forwarded to the sampler.
        top_k: Top-k cutoff, forwarded to the sampler.
        top_p: Top-p cutoff, forwarded to the sampler.
        eos_id_a: Not used by this function.
        eos_id_t: Generation stops (without recording it) when a token sampled
            after the first step equals this id.
        pad_id_t: Not used by this function.
        shift: Not used by this function.
        include_prompt: Not used by this function.
        generate_text: Not used by this function.

    Returns:
        A Python list with the generated text token ids (prompt excluded).
    """
    T = input_ids[0].size(1)
    device = input_ids[0].device
    sampling = dict(temperature=temperature, top_k=top_k, top_p=top_p)

    def as_step_input(value):
        # Shape (1, n), int32, on the prompt's device.
        return value.view(1, -1).to(torch.int32).to(device)

    # Prompt step, the only one that sees the audio features.
    token_T = next_token_A1T1(
        model,
        audio_features.to(torch.float32).to(model.device),
        input_ids,
        [T - 3],
        ["AT"],
        input_pos=torch.arange(0, T, device=device),
        **sampling,
    )
    output = [token_T.clone().tolist()[0]]

    input_pos = torch.tensor([T], device=device)
    for _ in tqdm(range(2, max_returned_tokens - T + 1)):
        model_input_ids = [
            as_step_input(torch.tensor([layershift(snac_config.end_of_audio, layer)]))
            for layer in range(7)
        ]
        model_input_ids.append(as_step_input(token_T.clone()))
        token_T = next_token_A1T1(
            model,
            None,
            model_input_ids,
            None,
            None,
            input_pos=input_pos,
            **sampling,
        )
        if token_T == eos_id_t:
            break
        output.append(token_T.clone().tolist()[0])
        input_pos.add_(1)
    return output
def generate_TA(
    model: GPT,
    audio_features: torch.Tensor,
    input_ids: list,
    leng,
    task,
    max_returned_tokens: int = 2048,
    *,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    top_p: float = 1.0,
    eos_id_a: Optional[int] = None,
    eos_id_t: Optional[int] = None,
    pad_id_t: Optional[int] = None,
    shift: Optional[int] = None,
    include_prompt: bool = True,
    generate_text=False,
) -> torch.Tensor:
    """Generate seven audio streams plus a text stream without audio features.

    Every forward pass goes through ``next_token_A1T2`` with ``None`` for the
    audio features, whisper lengths and task. Sampled audio tokens are fed
    back after ``layershift(token, i)``; the text token is fed back as is.

    Args:
        model: The model to use.
        audio_features: Not used by this function.
        input_ids: List of prompt tensors; ``input_ids[0].size(1)`` is the prompt length ``T``.
        leng: Not used by this function.
        task: Not used by this function.
        max_returned_tokens: Upper bound on prompt plus generated length.
        temperature: Sampling temperature, forwarded to the sampler.
        top_k: Top-k cutoff, forwarded to the sampler.
        top_p: Top-p cutoff, forwarded to the sampler.
        eos_id_a: Generation stops (without recording the step) when the last
            audio stream samples this token.
        eos_id_t: Once the text stream samples this token, later text tokens
            are replaced by ``pad_id_t``.
        pad_id_t: Replacement text token used after ``eos_id_t``.
        shift: Not used by this function.
        include_prompt: Not used by this function.
        generate_text: Not used by this function.

    Returns:
        A list of eight Python lists (seven audio streams, then text) with the
        generated token ids; the prompt is not included.
    """
    T = input_ids[0].size(1)
    device = input_ids[0].device
    sampling = dict(temperature=temperature, top_k=top_k, top_p=top_p)
    output = [[] for _ in range(8)]

    def record(audio_tokens, text_token):
        # Store the raw sampled ids as plain Python ints.
        for layer in range(7):
            output[layer].append(audio_tokens[layer].clone().tolist()[0])
        output[7].append(text_token.clone().tolist()[0])

    def as_step_input(value):
        # Shape (1, n), int32, on the prompt's device.
        return value.view(1, -1).to(torch.int32).to(device)

    # Prompt step.
    tokens_A, token_T = next_token_A1T2(
        model,
        None,
        input_ids,
        None,
        None,
        input_pos=torch.arange(0, T, device=device),
        **sampling,
    )
    record(tokens_A, token_T)

    input_pos = torch.tensor([T], device=device)
    text_end = False
    for _ in tqdm(range(2, max_returned_tokens - T + 1)):
        model_input_ids = [
            as_step_input(layershift(tokens_A[layer].clone(), layer)) for layer in range(7)
        ]
        model_input_ids.append(as_step_input(token_T.clone()))
        tokens_A, token_T = next_token_A1T2(
            model,
            None,
            model_input_ids,
            None,
            None,
            input_pos=input_pos,
            **sampling,
        )
        if text_end:
            token_T = torch.tensor([pad_id_t], device=device)
        if tokens_A[-1] == eos_id_a:
            break
        if token_T == eos_id_t:
            text_end = True

        record(tokens_A, token_T)
        input_pos.add_(1)
    return output
def generate_AA(
    model: GPT,
    audio_features: torch.Tensor,
    input_ids: list,
    leng,
    task,
    max_returned_tokens: int = 2048,
    *,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    top_p: float = 1.0,
    eos_id_a: Optional[int] = None,
    eos_id_t: Optional[int] = None,
    pad_id_t: Optional[int] = None,
    shift: Optional[int] = None,
    include_prompt: bool = True,
    generate_text=False,
) -> torch.Tensor:
    """Generate seven audio streams plus a text stream conditioned on audio features.

    The first forward pass receives ``audio_features`` (cast to float32 and
    moved to ``model.device``), whisper length ``[T - 3]`` and task
    ``["A1T2"]``. Later steps pass ``None`` for those three arguments and feed
    back the sampled audio tokens after ``layershift(token, i)`` together with
    the text token.

    Args:
        model: The model to use.
        audio_features: Audio features for the first step.
        input_ids: List of prompt tensors; ``input_ids[0].size(1)`` is the prompt length ``T``.
        leng: Not used by this function.
        task: Not used by this function (the tag is hard-coded).
        max_returned_tokens: Upper bound on prompt plus generated length.
        temperature: Sampling temperature, forwarded to the sampler.
        top_k: Top-k cutoff, forwarded to the sampler.
        top_p: Top-p cutoff, forwarded to the sampler.
        eos_id_a: Generation stops (without recording the step) when the last
            audio stream samples this token.
        eos_id_t: Once the text stream samples this token, later text tokens
            are replaced by ``pad_id_t``.
        pad_id_t: Replacement text token used after ``eos_id_t``.
        shift: Not used by this function.
        include_prompt: Not used by this function.
        generate_text: Not used by this function.

    Returns:
        A list of eight Python lists (seven audio streams, then text) with the
        generated token ids; the prompt is not included.
    """
    T = input_ids[0].size(1)
    device = input_ids[0].device
    sampling = dict(temperature=temperature, top_k=top_k, top_p=top_p)
    output = [[] for _ in range(8)]

    def record(audio_tokens, text_token):
        # Store the raw sampled ids as plain Python ints.
        for layer in range(7):
            output[layer].append(audio_tokens[layer].clone().tolist()[0])
        output[7].append(text_token.clone().tolist()[0])

    def as_step_input(value):
        # Shape (1, n), int32, on the prompt's device.
        return value.view(1, -1).to(torch.int32).to(device)

    # Prompt step, the only one that sees the audio features.
    tokens_A, token_T = next_token_A1T2(
        model,
        audio_features.to(torch.float32).to(model.device),
        input_ids,
        [T - 3],
        ["A1T2"],
        input_pos=torch.arange(0, T, device=device),
        **sampling,
    )
    record(tokens_A, token_T)

    input_pos = torch.tensor([T], device=device)
    text_end = False
    for _ in tqdm(range(2, max_returned_tokens - T + 1)):
        model_input_ids = [
            as_step_input(layershift(tokens_A[layer].clone(), layer)) for layer in range(7)
        ]
        model_input_ids.append(as_step_input(token_T.clone()))
        tokens_A, token_T = next_token_A1T2(
            model,
            None,
            model_input_ids,
            None,
            None,
            input_pos=input_pos,
            **sampling,
        )
        if text_end:
            token_T = torch.tensor([pad_id_t], device=device)
        if tokens_A[-1] == eos_id_a:
            break
        if token_T == eos_id_t:
            text_end = True

        record(tokens_A, token_T)
        input_pos.add_(1)
    return output
def generate_ASR(
    model: GPT,
    audio_features: torch.Tensor,
    input_ids: list,
    leng,
    task,
    max_returned_tokens: int = 1200,
    *,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    top_p: float = 1.0,
    eos_id_a: Optional[int] = None,
    eos_id_t: Optional[int] = None,
    pad_id_t: Optional[int] = None,
    shift: Optional[int] = None,
    include_prompt: bool = True,
    generate_text=False,
) -> torch.Tensor:
    """Generate text tokens conditioned on audio features, using task tag ``"asr"``.

    The first forward pass receives ``audio_features`` (cast to float32 and
    moved to ``model.device``), whisper length ``[T - 3]`` and task
    ``["asr"]``. Later steps pass ``None`` for those three arguments, fill the
    seven audio inputs with ``layershift(snac_config.end_of_audio, i)`` and
    feed back only the text token.

    Args:
        model: The model to use.
        audio_features: Audio features for the first step.
        input_ids: List of prompt tensors; ``input_ids[0].size(1)`` is the prompt length ``T``.
        leng: Not used by this function.
        task: Not used by this function (the tag is hard-coded).
        max_returned_tokens: Upper bound on prompt plus generated length.
        temperature: Sampling temperature, forwarded to the sampler.
        top_k: Top-k cutoff, forwarded to the sampler.
        top_p: Top-p cutoff, forwarded to the sampler.
        eos_id_a: Not used by this function.
        eos_id_t: Generation stops (without recording it) when a token sampled
            after the first step equals this id.
        pad_id_t: Not used by this function.
        shift: Not used by this function.
        include_prompt: Not used by this function.
        generate_text: Not used by this function.

    Returns:
        A Python list with the generated text token ids (prompt excluded).
    """
    T = input_ids[0].size(1)
    device = input_ids[0].device
    sampling = dict(temperature=temperature, top_k=top_k, top_p=top_p)

    def as_step_input(value):
        # Shape (1, n), int32, on the prompt's device.
        return value.view(1, -1).to(torch.int32).to(device)

    # Prompt step, the only one that sees the audio features.
    token_T = next_token_A1T1(
        model,
        audio_features.to(torch.float32).to(model.device),
        input_ids,
        [T - 3],
        ["asr"],
        input_pos=torch.arange(0, T, device=device),
        **sampling,
    )
    output = [token_T.clone().tolist()[0]]

    input_pos = torch.tensor([T], device=device)
    for _ in tqdm(range(2, max_returned_tokens - T + 1)):
        model_input_ids = [
            as_step_input(torch.tensor([layershift(snac_config.end_of_audio, layer)]))
            for layer in range(7)
        ]
        model_input_ids.append(as_step_input(token_T.clone()))
        token_T = next_token_A1T1(
            model,
            None,
            model_input_ids,
            None,
            None,
            input_pos=input_pos,
            **sampling,
        )
        if token_T == eos_id_t:
            break
        output.append(token_T.clone().tolist()[0])
        input_pos.add_(1)
    return output