# d-RAG/llms/tiny_llama.py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
class TinyLlama:
    """Thin wrapper around the TinyLlama-1.1B chat model, loaded in 4-bit."""

    def __init__(self) -> None:
        self.tokenizer = AutoTokenizer.from_pretrained(
            "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
        )
        # Quantize to 4-bit via bitsandbytes so the model fits on small GPUs;
        # compute in fp16 for speed. Passing load_in_4bit directly to
        # from_pretrained is deprecated, so use a BitsAndBytesConfig instead.
        self.model = AutoModelForCausalLM.from_pretrained(
            "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
            quantization_config=BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
            ),
            device_map="auto",
        )
        print(f"LLM loaded to {self.model.device}")
        self._messages = []  # reserved for conversation history (currently unused)
    def __call__(self, messages, *args, **kwds):
        # Render the chat messages into the model's prompt format
        # (with tokenize=False, apply_chat_template returns a plain string).
        prompt = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        outputs = self.model.generate(
            **inputs,
            use_cache=True,
            max_length=1000,  # counts prompt tokens plus generated tokens
            min_length=10,
            temperature=0.7,
            num_return_sequences=1,
            do_sample=True,
            pad_token_id=self.tokenizer.eos_token_id,  # Llama tokenizers have no pad token
        )
        # Decode only the newly generated tokens, not the echoed prompt.
        new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
        return self.tokenizer.decode(new_tokens, skip_special_tokens=True)
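

# A minimal usage sketch (not part of the original file): it assumes a machine
# with bitsandbytes installed and enough GPU/CPU memory for the 4-bit model;
# the messages below are an arbitrary illustration.
if __name__ == "__main__":
    llm = TinyLlama()
    reply = llm(
        [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "What is retrieval-augmented generation?"},
        ]
    )
    print(reply)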