# Source captured from a Hugging Face Space page (Space status at capture time: "Runtime error").
from threading import Thread

import gradio as gr
import torch
from peft import PeftModel
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TextIteratorStreamer,
    TextStreamer,
)
# Base model and LoRA adapter repositories on the Hugging Face Hub.
model_name_or_path = "sarvamai/OpenHathi-7B-Hi-v0.1-Base"
peft_model_id = "shuvom/OpenHathi-7B-FT-v0.1_SI"

# Load the base model quantized to 4 bits and let accelerate place its layers
# on the available devices. Passing `load_in_4bit` straight to
# `from_pretrained` is deprecated; an explicit quantization config is the
# supported spelling of the same request.
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    device_map="auto",
)

# The tokenizer is taken from the adapter repo rather than the base model;
# presumably it carries the fine-tuning chat template and any tokens added
# during fine-tuning (the resize below implies its vocabulary can differ).
tokenizer = AutoTokenizer.from_pretrained(peft_model_id)

# Grow/shrink the embedding matrix to the adapter tokenizer's vocabulary
# (rounded up to a multiple of 8) before attaching the adapter weights.
# TODO: make embedding resizing configurable.
model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8)
model = PeftModel.from_pretrained(model, peft_model_id)
class ChatCompletion:
    """Chat-template-driven text generation on top of a causal language model.

    Every entry point renders a list of ``{"role": ..., "content": ...}``
    messages with ``tokenizer.apply_chat_template`` and samples a reply with
    ``model.generate``:

    * ``get_completion`` echoes tokens to stdout while generating and returns
      the raw output of ``generate`` (token ids, not text);
    * ``get_chat_completion`` is a generator for ``gr.ChatInterface`` that
      yields the reply accumulated so far;
    * ``get_completion_without_streaming`` blocks and returns decoded text.
    """

    # Lower bound for the sampling temperature of the single-prompt entry
    # points: they sample (do_sample=True), which needs a positive temperature.
    _MIN_TEMPERATURE = 1e-2

    def __init__(self, model, tokenizer, system_prompt=None):
        """Store the model/tokenizer and put the model in inference mode.

        Args:
            model: Causal LM exposing ``generate``, ``eval`` and ``device``.
            tokenizer: Matching tokenizer; must provide a chat template.
            system_prompt: Optional default system message for every chat.
        """
        self.model = model
        self.tokenizer = tokenizer
        # Kept for callers that access it; get_chat_completion creates a fresh
        # streamer per request so concurrent requests never share one queue.
        self.streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True)
        # Echoes generated tokens to stdout (used by get_completion).
        self.print_streamer = TextStreamer(self.tokenizer, skip_prompt=True)
        # Set the model in inference mode.
        self.model.eval()
        self.system_prompt = system_prompt

    def _build_messages(self, prompt, system_prompt=None, message_history=None):
        """Assemble the message list for a single-prompt completion.

        An explicit ``message_history`` takes precedence and is copied as-is
        (no system message is added); otherwise the per-call ``system_prompt``
        or, failing that, the default one is prepended when non-empty.
        """
        messages = []
        if message_history is not None:
            messages.extend(message_history)
        else:
            system_prompt = system_prompt or self.system_prompt
            if system_prompt:
                messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})
        return messages

    @staticmethod
    def _history_to_messages(history):
        """Convert Gradio chat history into chat-template messages.

        Accepts both the ``[user, assistant]`` pair format and the
        ``{"role": ..., "content": ...}`` dict format. ``None`` entries in a
        pair (e.g. a turn still awaiting its reply) are skipped.
        """
        messages = []
        for turn in history or []:
            if isinstance(turn, dict):
                messages.append({"role": turn["role"], "content": turn["content"]})
                continue
            user_message, assistant_message = turn
            if user_message is not None:
                messages.append({"role": "user", "content": user_message})
            if assistant_message is not None:
                # Model replies must be tagged "assistant", not "system".
                messages.append({"role": "assistant", "content": assistant_message})
        return messages

    def _encode(self, messages, add_special_tokens):
        """Render ``messages`` with the chat template and tokenize the result.

        The returned encoding is moved to ``self.model.device`` so it sits
        on the same device as the model's (first) parameters.
        """
        chat_prompt = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        inputs = self.tokenizer(
            chat_prompt, return_tensors="pt", add_special_tokens=add_special_tokens
        )
        return inputs.to(self.model.device)

    def get_completion(self, prompt, system_prompt=None, message_history=None, max_new_tokens=512, temperature=0.0):
        """Generate a reply to ``prompt``, printing it to stdout as it is produced.

        Args:
            prompt: The user message.
            system_prompt: Overrides the default system prompt for this call;
                ignored when ``message_history`` is given.
            message_history: Prior messages to send verbatim before ``prompt``.
            max_new_tokens: Generation length limit.
            temperature: Sampling temperature, raised to at least 0.01.

        Returns:
            The output of ``model.generate``: token ids (for a decoder-only
            model, the prompt ids followed by the generated ones), not text.
        """
        temperature = max(temperature, self._MIN_TEMPERATURE)
        messages = self._build_messages(prompt, system_prompt, message_history)
        inputs = self._encode(messages, add_special_tokens=False)
        generation_kwargs = dict(
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=0.95,
            do_sample=True,
            eos_token_id=self.tokenizer.eos_token_id,
            repetition_penalty=1.2,
        )
        return self.model.generate(**inputs, streamer=self.print_streamer, **generation_kwargs)

    def get_chat_completion(self, message, history):
        """Stream a chat reply for ``gr.ChatInterface``.

        Args:
            message: The new user message.
            history: Previous turns, as ``[user, assistant]`` pairs or
                role/content dicts.

        Yields:
            The reply text accumulated so far, with EOS markers removed.

        Raises:
            Whatever ``model.generate`` raised in the worker thread, after the
            partial output has been streamed.
        """
        messages = []
        if self.system_prompt:
            messages.append({"role": "system", "content": self.system_prompt})
        messages.extend(self._history_to_messages(history))
        messages.append({"role": "user", "content": message})
        # NOTE(review): unlike the other entry points this tokenizes with the
        # tokenizer's default special tokens; confirm which matches training.
        inputs = self._encode(messages, add_special_tokens=True)

        # A fresh streamer per request: a shared one would interleave tokens
        # of concurrent requests served from the Gradio queue.
        streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True)
        generation_kwargs = dict(
            inputs,
            streamer=streamer,
            max_new_tokens=2048,
            temperature=0.2,
            top_p=0.95,
            eos_token_id=self.tokenizer.eos_token_id,
            do_sample=True,
            repetition_penalty=1.2,
        )
        errors = []

        def _generate():
            # Runs generate() in a worker thread so this generator can yield
            # text as it arrives. On failure, close the stream so the loop
            # below does not wait forever, and re-raise in the caller's thread.
            try:
                self.model.generate(**generation_kwargs)
            except Exception as exc:
                errors.append(exc)
                streamer.end()

        thread = Thread(target=_generate)
        thread.start()
        eos_token = self.tokenizer.eos_token
        generated_text = ""
        for new_text in streamer:
            if eos_token:
                new_text = new_text.replace(eos_token, "")
            generated_text += new_text
            yield generated_text
        thread.join()
        if errors:
            raise errors[0]
        return generated_text

    def get_completion_without_streaming(self, prompt, system_prompt=None, message_history=None, max_new_tokens=512, temperature=0.0):
        """Generate a reply to ``prompt`` and return it as decoded text.

        Arguments are as for ``get_completion``. The returned string is the
        decoded first output sequence with special tokens skipped (for a
        decoder-only model it includes the prompt text).
        """
        temperature = max(temperature, self._MIN_TEMPERATURE)
        messages = self._build_messages(prompt, system_prompt, message_history)
        inputs = self._encode(messages, add_special_tokens=False)
        generation_kwargs = dict(
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=0.95,
            do_sample=True,
            repetition_penalty=1.1,
        )
        outputs = self.model.generate(**inputs, **generation_kwargs)
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
# Persona sent as the system message at the start of every chat.
SYSTEM_PROMPT = "You are a native Hindi speaker who can converse at expert level in both Hindi and colloquial Hinglish."

text_generator = ChatCompletion(model, tokenizer, system_prompt=SYSTEM_PROMPT)

# Serve the streaming chat generator through Gradio's chat UI with the
# request queue enabled, then start the server in debug mode.
demo = gr.ChatInterface(text_generator.get_chat_completion)
demo.queue()
demo.launch(debug=True)