import gc
import copy
from tenacity import RetryError
from tenacity import retry, stop_after_attempt, wait_fixed
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    LogitsProcessorList,
    MinNewTokensLengthLogitsProcessor,
    TemperatureLogitsWarper,
    TopPLogitsWarper,
)

def get_output_batch(
    model, tokenizer, prompts, generation_config
):
    """Run non-streaming batch generation and return the decoded outputs."""
    if len(prompts) == 1:
        # Single prompt: no padding is needed.
        encoding = tokenizer(prompts, return_tensors="pt")
        input_ids = encoding["input_ids"].cuda()
        generated_id = model.generate(
            input_ids=input_ids,
            generation_config=generation_config,
            max_new_tokens=256
        )

        decoded = tokenizer.batch_decode(generated_id)
        del input_ids, generated_id
        torch.cuda.empty_cache()
        return decoded
    else:
        # Multiple prompts: pad to a common length before batching.
        encodings = tokenizer(prompts, padding=True, return_tensors="pt").to('cuda')
        generated_ids = model.generate(
            **encodings,
            generation_config=generation_config,
            max_new_tokens=256
        )

        decoded = tokenizer.batch_decode(generated_ids)
        del encodings, generated_ids
        torch.cuda.empty_cache()
        return decoded
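

# Example usage of get_output_batch (a minimal sketch; the model ID below is a
# placeholder and the LoRA / 8-bit loading details used by the app are omitted):
#
#   from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
#
#   tokenizer = AutoTokenizer.from_pretrained("your-base-model")
#   model = AutoModelForCausalLM.from_pretrained("your-base-model").eval().cuda()
#   config = GenerationConfig(temperature=0.7, top_p=0.9)
#
#   outputs = get_output_batch(model, tokenizer, ["Hello, my name is"], config)
#   print(outputs[0])
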
# StreamModel is borrowed from the basaran project.
# For more details, see https://github.com/hyperonym/basaran
class StreamModel:
"""StreamModel wraps around a language model to provide stream decoding."""
def __init__(self, model, tokenizer):
super().__init__()
self.model = model
self.tokenizer = tokenizer
self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def __call__(
        self,
        prompt,
        min_tokens=0,
        max_tokens=16,
        temperature=1.0,
        top_p=1.0,
        n=1,
        logprobs=0,
    ):
        """Create a completion stream for the provided prompt."""
        input_ids = self.tokenize(prompt)
        logprobs = max(logprobs, 0)

        # Yield the decoded text every (chunk_size - 1) newly generated tokens;
        # chunk_size must be bigger than 1.
        chunk_size = 2
        chunk_count = 0

        # Generate completion tokens, accumulated as integer token IDs.
        final_tokens = torch.empty(0, dtype=torch.long, device=self.device)
        try:
            for tokens in self.generate(
                input_ids[None, :].repeat(n, 1),
                logprobs=logprobs,
                min_new_tokens=min_tokens,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
            ):
                if chunk_count < chunk_size:
                    chunk_count = chunk_count + 1

                final_tokens = torch.cat((final_tokens, tokens))

                if chunk_count == chunk_size - 1:
                    chunk_count = 0
                    yield self.tokenizer.decode(final_tokens, skip_special_tokens=True)

            # Flush any tokens left over from the last partial chunk.
            if chunk_count > 0:
                yield self.tokenizer.decode(final_tokens, skip_special_tokens=True)
        except RetryError as e:
            print(e)

        del input_ids
        gc.collect()
        del final_tokens

        if self.device == "cuda":
            torch.cuda.empty_cache()

    @retry(stop=stop_after_attempt(5), wait=wait_fixed(1))
    def _infer(self, model_fn, **kwargs):
        """Call a model function in inference mode with auto retrying."""
        # This is a temporary workaround for bitsandbytes #162:
        # https://github.com/TimDettmers/bitsandbytes/issues/162
        with torch.inference_mode():
            return model_fn(**kwargs)

    def _logits_processor(self, config, input_length):
        """Set up logits processor based on the generation config."""
        processor = LogitsProcessorList()

        # Add processor for enforcing a min-length of new tokens.
        if (
            config.min_new_tokens is not None
            and config.min_new_tokens > 0
            and config.eos_token_id is not None
        ):
            processor.append(
                MinNewTokensLengthLogitsProcessor(
                    prompt_length_to_skip=input_length,
                    min_new_tokens=config.min_new_tokens,
                    eos_token_id=config.eos_token_id,
                )
            )

        # Add processor for scaling output probability distribution.
        if (
            config.temperature is not None
            and config.temperature > 0
            and config.temperature != 1.0
        ):
            processor.append(TemperatureLogitsWarper(config.temperature))

        # Add processor for nucleus sampling.
        if config.top_p is not None and config.top_p > 0 and config.top_p < 1:
            processor.append(TopPLogitsWarper(config.top_p))

        return processor

    def tokenize(self, text):
        """Tokenize a string into a tensor of token IDs."""
        batch = self.tokenizer.encode(text, return_tensors="pt")
        return batch[0].to(self.device)

    def generate(self, input_ids, logprobs=0, **kwargs):
        """Generate a stream of predicted tokens using the language model."""

        # Store the original batch size and input length.
        batch_size = input_ids.shape[0]
        input_length = input_ids.shape[-1]

        # Separate model arguments from generation config.
        config = self.model.generation_config
        config = copy.deepcopy(config)
        kwargs = config.update(**kwargs)
        kwargs["output_attentions"] = False
        kwargs["output_hidden_states"] = False
        kwargs["use_cache"] = True  # config.use_cache

        # Collect special token IDs.
        pad_token_id = config.pad_token_id
        bos_token_id = config.bos_token_id
        eos_token_id = config.eos_token_id
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        if pad_token_id is None and eos_token_id is not None:
            pad_token_id = eos_token_id[0]

        # Generate from eos if no input is specified.
        if input_length == 0:
            input_ids = input_ids.new_ones((batch_size, 1)).long()
            if eos_token_id is not None:
                input_ids = input_ids * eos_token_id[0]
            input_length = 1

        # Prepare inputs for encoder-decoder models.
        if self.model.config.is_encoder_decoder:
            # Get outputs from the encoder.
            encoder = self.model.get_encoder()
            encoder_kwargs = kwargs.copy()
            encoder_kwargs.pop("use_cache", None)
            encoder_kwargs["input_ids"] = input_ids
            encoder_kwargs["return_dict"] = True
            encoder_outputs = self._infer(encoder, **encoder_kwargs)
            kwargs["encoder_outputs"] = encoder_outputs

            # Reinitialize inputs for the decoder.
            decoder_start_token_id = config.decoder_start_token_id
            if decoder_start_token_id is None:
                decoder_start_token_id = bos_token_id
            input_ids = input_ids.new_ones((batch_size, 1))
            input_ids = input_ids * decoder_start_token_id
            input_length = 1

        # Set up logits processor.
        processor = self._logits_processor(config, input_length)

        # Keep track of which sequences are already finished.
        unfinished = input_ids.new_ones(batch_size)

        # Start auto-regressive generation.
        while True:
            inputs = self.model.prepare_inputs_for_generation(
                input_ids, **kwargs
            )  # noqa: E501
            outputs = self._infer(
                self.model,
                **inputs,
                return_dict=True,
                output_attentions=False,
                output_hidden_states=False,
            )

            # Pre-process the probability distribution of the next tokens.
            logits = outputs.logits[:, -1, :]
            with torch.inference_mode():
                logits = processor(input_ids, logits)
            probs = torch.nn.functional.softmax(logits, dim=-1)

            # Select deterministic or stochastic decoding strategy.
            if (config.top_p is not None and config.top_p <= 0) or (
                config.temperature is not None and config.temperature <= 0
            ):
                tokens = torch.argmax(probs, dim=-1)[:, None]
            else:
                tokens = torch.multinomial(probs, num_samples=1)

            tokens = tokens.squeeze(1)

            # Finished sequences should have their next token be a padding.
            if pad_token_id is not None:
                tokens = tokens * unfinished + pad_token_id * (1 - unfinished)

            # Append selected tokens to the inputs.
            input_ids = torch.cat([input_ids, tokens[:, None]], dim=-1)

            # Mark sequences with eos tokens as finished.
            if eos_token_id is not None:
                not_eos = sum(tokens != i for i in eos_token_id)
                unfinished = unfinished.mul(not_eos.long())

            # Set status to -1 if exceeded the max length.
            status = unfinished.clone()
            if input_ids.shape[-1] - input_length >= config.max_new_tokens:
                status = 0 - status

            # Yield the newly predicted tokens.
            yield tokens

            # Stop when finished or exceeded the max length.
            if status.max() <= 0:
                break
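

# Example usage of StreamModel (a minimal sketch; the model ID below is a
# placeholder, and 8-bit / LoRA loading is omitted). Each yielded string is the
# cumulative generated text decoded so far, not including the prompt:
#
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#
#   tokenizer = AutoTokenizer.from_pretrained("your-base-model")
#   model = AutoModelForCausalLM.from_pretrained("your-base-model").eval().cuda()
#
#   stream_model = StreamModel(model, tokenizer)
#   for partial_text in stream_model(
#       "Below is an instruction that describes a task.",
#       max_tokens=128,
#       temperature=0.7,
#       top_p=0.9,
#   ):
#       print(partial_text)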