import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig


class TinyLlama:
    def __init__(self) -> None:
        self.tokenizer = AutoTokenizer.from_pretrained(
            "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
        )
        # 4-bit settings go through BitsAndBytesConfig; bnb_4bit_compute_dtype
        # is not accepted as a direct keyword of from_pretrained.
        self.model = AutoModelForCausalLM.from_pretrained(
            "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
            quantization_config=BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
            ),
            device_map="auto",
        )
        print(f"LLM loaded to {self.model.device}")

        # Conversation history; kept for callers that accumulate turns
        # (not consumed by __call__ itself).
        self._messages = []

    def __call__(self, messages, *args, **kwds):
        # Render the message list with the model's chat template and append
        # the generation prompt that cues the assistant's turn.
        tokenized_chat = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        inputs = self.tokenizer(tokenized_chat, return_tensors="pt").to(
            self.model.device
        )

        outputs = self.model.generate(
            **inputs,
            use_cache=True,
            max_length=1000,  # counts prompt + completion; max_new_tokens would cap only the reply
            min_length=10,
            temperature=0.7,
            num_return_sequences=1,
            do_sample=True,
            pad_token_id=self.tokenizer.eos_token_id,  # Llama tokenizers define no pad token
        )

        # outputs[0] contains the prompt followed by the reply; slice off the
        # prompt tokens so only the newly generated text is returned.
        new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
        return self.tokenizer.decode(new_tokens, skip_special_tokens=True)
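

# A minimal usage sketch, assumed rather than taken from the original listing:
# instantiate the wrapper once (the weights load a single time), then call it
# with an OpenAI-style message list. The message contents are placeholders.
if __name__ == "__main__":
    llm = TinyLlama()
    reply = llm([
        {"role": "system", "content": "You are a friendly chatbot."},
        {"role": "user", "content": "Explain 4-bit quantization in one sentence."},
    ])
    print(reply)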