Chris-lab / utils /model.py
kz209
update
162b68f
raw
history blame
3.98 kB
import logging
import os

import torch
from huggingface_hub import login
from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer
# Authenticate with the Hugging Face Hub only when HF_TOKEN is set.
# login(token=None) falls back to an interactive prompt, which cannot be
# answered when this module is imported on a headless server.
_hf_token = os.getenv('HF_TOKEN')
if _hf_token:
    login(token=_hf_token)
class Model(torch.nn.Module):
    """Wrap a Hugging Face tokenizer and generative model for text generation.

    The checkpoint is loaded in bfloat16 with ``device_map="auto"`` and put in
    eval mode. Encoder-decoder checkpoints (e.g. ``google-t5/t5-large``) are
    loaded with ``AutoModelForSeq2SeqLM``; all others with
    ``AutoModelForCausalLM``.

    Class attributes:
        number_of_models: count of ``Model`` instances constructed so far.
        __model_list__: checkpoint names listed as known to this module; it is
            informational only and is not checked by ``__init__``.
    """

    number_of_models = 0
    __model_list__ = [
        "Qwen/Qwen2-1.5B-Instruct",
        "lmsys/vicuna-7b-v1.5",
        "google-t5/t5-large",
        "mistralai/Mistral-7B-Instruct-v0.1",
        "meta-llama/Meta-Llama-3.1-8B-Instruct",
    ]

    def __init__(self, model_name="Qwen/Qwen2-1.5B-Instruct") -> None:
        """Load the tokenizer and model for ``model_name``.

        Args:
            model_name: Hugging Face Hub id (or local path) of the checkpoint.
        """
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.name = model_name
        logging.info('start loading model %s', self.name)
        # Pick the head from the checkpoint's config rather than a hard-coded
        # name, so any encoder-decoder checkpoint gets the Seq2Seq class.
        config = AutoConfig.from_pretrained(model_name)
        if getattr(config, "is_encoder_decoder", False):
            # For T5 or any other Seq2Seq model
            model_cls = AutoModelForSeq2SeqLM
        else:
            # For GPT-like models or other causal language models.
            # Decoder-only models must be left-padded for batched generation;
            # with right padding new tokens would follow the pad tokens.
            model_cls = AutoModelForCausalLM
            self.tokenizer.padding_side = "left"
        # A tokenizer without a pad token makes ``padding=True`` in gen()
        # raise a ValueError; reuse the EOS token for padding in that case.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model = model_cls.from_pretrained(
            model_name, torch_dtype=torch.bfloat16, device_map="auto"
        )
        logging.info('Loaded model %s', self.name)
        self.model.eval()
        self.update()

    @classmethod
    def update(cls):
        """Increment the class-wide count of constructed models."""
        cls.number_of_models += 1

    def return_mode_name(self):
        """Return the checkpoint name this instance was created with."""
        return self.name

    # Correctly spelled alias; ``return_mode_name`` is kept for existing callers.
    return_model_name = return_mode_name

    def return_tokenizer(self):
        """Return the underlying Hugging Face tokenizer."""
        return self.tokenizer

    def return_model(self):
        """Return the loaded Hugging Face model (``self.model``)."""
        return self.model

    def _encode(self, content_list):
        """Tokenize prompts; return ``(input_ids, attention_mask)`` on the model's device."""
        enc = self.tokenizer(content_list, return_tensors="pt", padding=True, truncation=True)
        device = self.model.device
        input_ids = enc["input_ids"].to(device)
        attention_mask = enc.get("attention_mask")
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        else:
            attention_mask = attention_mask.to(device)
        return input_ids, attention_mask

    def gen(self, content_list, temp=0.001, max_length=500, streaming=False):
        """Generate text for a prompt or a batch of prompts.

        Args:
            content_list: prompt string or list of prompt strings, passed
                directly to the tokenizer.
            temp: sampling temperature (sampling is always enabled).
            max_length: maximum number of *new* tokens to generate.
            streaming: if True, return a generator instead of a list.

        Returns:
            If ``streaming`` is False: a list with one decoded string per
            prompt. For causal models ``generate`` returns prompt plus
            continuation, so the prompt text is included.
            If ``streaming`` is True: a generator of ``(index, text)`` pairs,
            one per new token. ``index`` is a 0-d tensor holding the prompt's
            position in the batch. A prompt stops producing pairs once it
            emits the EOS token.
        """
        if streaming:
            # Delegate to a separate generator: a ``yield`` in this body would
            # turn gen() itself into a generator and swallow the list below.
            return self._stream_gen(content_list, temp, max_length)
        input_ids, attention_mask = self._encode(content_list)
        outputs = self.model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_length,
            do_sample=True,
            temperature=temp,
            eos_token_id=self.tokenizer.eos_token_id,
            pad_token_id=self.tokenizer.pad_token_id,
        )
        return self.tokenizer.batch_decode(outputs, skip_special_tokens=True)

    def _stream_gen(self, content_list, temp, max_length):
        """Yield ``(index, text)`` for each new token, one decoding step at a time."""
        input_ids, attention_mask = self._encode(content_list)
        eos_id = self.tokenizer.eos_token_id
        is_seq2seq = getattr(self.model.config, "is_encoder_decoder", False)
        # Original batch positions of the prompts still generating. The rows of
        # input_ids / attention_mask / decoder_input_ids stay aligned with it,
        # so a finished prompt is removed from all of them with one mask.
        active = torch.arange(input_ids.shape[0])
        decoder_input_ids = None
        generated_tokens = 0
        while generated_tokens < max_length and len(active) > 0:
            step_kwargs = {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "do_sample": True,
                "temperature": temp,
                "eos_token_id": eos_id,
                "pad_token_id": self.tokenizer.pad_token_id,
                "max_new_tokens": 1,  # Generate one token at a time
            }
            if decoder_input_ids is not None:
                step_kwargs["decoder_input_ids"] = decoder_input_ids
            # NOTE: no past_key_values are carried between steps, so each step
            # reprocesses the whole prefix (cost grows quadratically with length).
            with torch.no_grad():
                sequences = self.model.generate(**step_kwargs)
            next_tokens = sequences[:, -1:]
            # NOTE(review): decoding one token at a time may drop leading spaces
            # or split multi-byte characters for some tokenizers - confirm.
            for idx, token in zip(active, next_tokens):
                yield idx, self.tokenizer.decode(token[0], skip_special_tokens=True)
            generated_tokens += 1
            if is_seq2seq:
                # The encoder input stays fixed; only the decoder prefix grows.
                decoder_input_ids = sequences
            else:
                input_ids = sequences
                attention_mask = torch.cat(
                    [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))],
                    dim=-1,
                )
            # Drop prompts that just produced EOS from every aligned tensor.
            keep = next_tokens[:, 0] != eos_id
            if not bool(keep.all()):
                active = active[keep.cpu()]
                input_ids = input_ids[keep]
                attention_mask = attention_mask[keep]
                if decoder_input_ids is not None:
                    decoder_input_ids = decoder_input_ids[keep]