# Commit 11f5672 by shuvom: "Update app.py"
from threading import Thread

import gradio as gr
import torch
from peft import PeftModel
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TextIteratorStreamer,
    TextStreamer,
)
# Base checkpoint and the PEFT (adapter) fine-tune applied on top of it.
model_name_or_path = "sarvamai/OpenHathi-7B-Hi-v0.1-Base"
peft_model_id = "shuvom/OpenHathi-7B-FT-v0.1_SI"

# Load the base model quantized to 4 bits. Passing ``load_in_4bit`` directly to
# ``from_pretrained`` is deprecated; an explicit BitsAndBytesConfig is the
# supported equivalent.
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    device_map="auto",
)

# The tokenizer comes from the fine-tuned repo (it carries the chat template
# used by ChatCompletion via apply_chat_template).
tokenizer = AutoTokenizer.from_pretrained(peft_model_id)

# Grow/shrink the embedding matrix to the fine-tuned tokenizer's vocabulary
# (padded to a multiple of 8) before attaching the adapter.
# NOTE(review): the padded size presumably has to match what the adapter was
# trained with; consider making this configurable.
model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8)

# Wrap the base model with the fine-tuned adapter weights.
model = PeftModel.from_pretrained(model, peft_model_id)
class ChatCompletion:
    """Chat-style text generation around a causal LM and its tokenizer.

    Every prompt is rendered with ``tokenizer.apply_chat_template``, so the
    tokenizer must provide a chat template.
    """

    def __init__(self, model, tokenizer, system_prompt=None):
        """Store the model/tokenizer, switch the model to eval mode.

        Args:
            model: a model exposing ``generate`` and ``eval``.
            tokenizer: a tokenizer with a chat template.
            system_prompt: default system message, used when a call does not
                supply one.
        """
        self.model = model
        self.tokenizer = tokenizer
        # Kept for backward compatibility. get_chat_completion creates a fresh
        # iterator streamer per request, so an abandoned stream cannot leak
        # leftover text into the next request.
        self.streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True)
        # Prints generated text to stdout as it is produced (get_completion).
        self.print_streamer = TextStreamer(self.tokenizer, skip_prompt=True)
        # set the model in inference mode
        self.model.eval()
        self.system_prompt = system_prompt

    def _build_messages(self, prompt, system_prompt=None, message_history=None):
        """Return the message list for a single completion request.

        If ``message_history`` is given it is copied verbatim (no system
        message is added); otherwise ``system_prompt`` -- or the instance
        default -- is prepended when set. ``prompt`` is appended as a user turn.
        The caller's ``message_history`` list is not mutated.
        """
        messages = []
        if message_history is not None:
            messages.extend(message_history)
        else:
            system_prompt = system_prompt or self.system_prompt
            if system_prompt:
                messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})
        return messages

    def _history_to_messages(self, message, history):
        """Convert (user, assistant) history pairs plus a new message to chat messages."""
        messages = []
        if self.system_prompt:
            messages.append({"role": "system", "content": self.system_prompt})
        for user_message, assistant_message in history:
            messages.append({"role": "user", "content": user_message})
            # Previous model replies are assistant turns; tagging them "system"
            # garbles the dialogue and breaks templates that enforce
            # user/assistant alternation.
            messages.append({"role": "assistant", "content": assistant_message})
        messages.append({"role": "user", "content": message})
        return messages

    def _tokenize_chat(self, messages):
        """Render ``messages`` with the chat template and tokenize to tensors.

        ``add_special_tokens=False``: chat templates typically emit BOS
        themselves, so adding special tokens again would risk a duplicate BOS.
        Inputs are moved to the model's device when the model exposes one.
        """
        chat_prompt = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        inputs = self.tokenizer(chat_prompt, return_tensors="pt", add_special_tokens=False)
        device = getattr(self.model, "device", None)
        if device is not None:
            inputs = inputs.to(device)
        return inputs

    def get_completion(self, prompt, system_prompt=None, message_history=None, max_new_tokens=512, temperature=0.0):
        """Generate a reply to ``prompt``, printing it to stdout as it streams.

        Args:
            prompt: the user message.
            system_prompt: overrides the instance default; ignored when
                ``message_history`` is given.
            message_history: optional list of prior message dicts, used as-is.
            max_new_tokens: generation length limit.
            temperature: sampling temperature, clamped to at least 0.01
                because sampling is always enabled.

        Returns:
            The raw result of ``model.generate`` (not decoded text).
        """
        temperature = max(temperature, 1e-2)
        messages = self._build_messages(prompt, system_prompt, message_history)
        inputs = self._tokenize_chat(messages)
        generation_kwargs = dict(
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=0.95,
            do_sample=True,
            eos_token_id=self.tokenizer.eos_token_id,
            repetition_penalty=1.2,
        )
        # Synchronous call; print_streamer echoes text to stdout as it is generated.
        output_ids = self.model.generate(**inputs, streamer=self.print_streamer, **generation_kwargs)
        return output_ids

    def get_chat_completion(self, message, history):
        """Stream a reply to ``message`` given prior chat ``history``.

        Args:
            message: the new user message.
            history: iterable of (user_message, assistant_message) pairs.

        Yields:
            The accumulated reply text after each streamed chunk, with the
            EOS token string removed.
        """
        messages = self._history_to_messages(message, history)
        inputs = self._tokenize_chat(messages)
        # A per-request streamer: a shared one would carry leftover chunks from
        # an abandoned request into the next one.
        streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True)
        generation_kwargs = dict(
            inputs,
            streamer=streamer,
            max_new_tokens=2048,
            temperature=0.2,
            top_p=0.95,
            eos_token_id=self.tokenizer.eos_token_id,
            do_sample=True,
            repetition_penalty=1.2,
        )
        # Run generation in a background thread so chunks can be yielded
        # from the streamer as they arrive.
        thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
        thread.start()
        generated_text = ""
        for new_text in streamer:
            generated_text += new_text.replace(self.tokenizer.eos_token, "")
            yield generated_text
        thread.join()
        return generated_text

    def get_completion_without_streaming(self, prompt, system_prompt=None, message_history=None, max_new_tokens=512, temperature=0.0):
        """Generate a reply to ``prompt`` and return it as decoded text.

        Arguments mirror ``get_completion``. Only the newly generated tokens
        are decoded (the prompt is stripped), matching the ``skip_prompt``
        behavior of the streaming methods; special tokens are skipped.
        """
        temperature = max(temperature, 1e-2)
        messages = self._build_messages(prompt, system_prompt, message_history)
        inputs = self._tokenize_chat(messages)
        generation_kwargs = dict(
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=0.95,
            do_sample=True,
            eos_token_id=self.tokenizer.eos_token_id,
            repetition_penalty=1.1,
        )
        outputs = self.model.generate(**inputs, **generation_kwargs)
        # generate() returns prompt + continuation for decoder-only models;
        # slice off the prompt tokens.
        prompt_length = inputs["input_ids"].shape[-1]
        return self.tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
# Chat engine with a Hindi/Hinglish persona as its default system prompt.
text_generator = ChatCompletion(
    model,
    tokenizer,
    system_prompt="You are a native Hindi speaker who can converse at expert level in both Hindi and colloquial Hinglish.",
)

# Expose the streaming chat handler through Gradio's chat UI, enable the
# request queue, and start the server with debug output.
demo = gr.ChatInterface(text_generator.get_chat_completion)
demo.queue().launch(debug=True)