import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    TextIteratorStreamer,
    TextStreamer,
)
from threading import Thread
import gradio as gr
from peft import PeftModel

model_name_or_path = "sarvamai/OpenHathi-7B-Hi-v0.1-Base"
peft_model_id = "shuvom/OpenHathi-7B-FT-v0.1_SI"

# Passing `load_in_4bit=True` straight to `from_pretrained` is deprecated in
# transformers; the supported form is an explicit BitsAndBytesConfig.
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    device_map="auto",
)
# The tokenizer (and therefore its chat template) is loaded from the adapter
# repo, not from the base model.
tokenizer = AutoTokenizer.from_pretrained(peft_model_id)
# Resize the base embeddings to the adapter tokenizer's vocabulary before
# attaching the adapter (presumably the fine-tune added tokens — confirm).
# TODO: make embedding resizing configurable?
model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8)
model = PeftModel.from_pretrained(model, peft_model_id)


class ChatCompletion:
    """Chat-style text generation around a causal LM and its tokenizer.

    Three entry points:
      * ``get_completion`` -- blocking; prints tokens to stdout as they are
        generated and returns the raw ``model.generate`` output.
      * ``get_chat_completion`` -- generator for ``gr.ChatInterface`` that
        yields the reply accumulated so far.
      * ``get_completion_without_streaming`` -- blocking; returns the decoded
        reply text.
    """

    def __init__(self, model, tokenizer, system_prompt=None):
        """
        Args:
            model: causal LM used for ``generate``.
            tokenizer: tokenizer providing ``apply_chat_template``.
            system_prompt: optional default system message placed at the
                start of each conversation.
        """
        self.model = model
        self.tokenizer = tokenizer
        # Kept for backward compatibility only: get_chat_completion builds a
        # fresh streamer per request so requests never share a token queue.
        self.streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True)
        self.print_streamer = TextStreamer(self.tokenizer, skip_prompt=True)
        # set the model in inference mode
        self.model.eval()
        self.system_prompt = system_prompt

    def _build_messages(self, prompt, system_prompt=None, message_history=None):
        """Assemble the message list for a single-turn request.

        If ``message_history`` is given it is used verbatim as the prefix and
        no system message is added; otherwise ``system_prompt`` (or the
        instance default) becomes the leading system message.
        """
        messages = []
        if message_history is not None:
            messages.extend(message_history)
        elif system_prompt or self.system_prompt:
            messages.append(
                {"role": "system", "content": system_prompt or self.system_prompt}
            )
        messages.append({"role": "user", "content": prompt})
        return messages

    def _encode_messages(self, messages):
        """Render ``messages`` through the chat template and tokenize them.

        Returns a ``BatchEncoding`` moved to the model's device.
        """
        chat_prompt = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        # The rendered template is assumed to already contain any special
        # tokens (e.g. BOS), so the tokenizer must not add them again.
        inputs = self.tokenizer(
            chat_prompt, return_tensors="pt", add_special_tokens=False
        )
        return inputs.to(self.model.device)

    def get_completion(self, prompt, system_prompt=None, message_history=None,
                       max_new_tokens=512, temperature=0.0):
        """Generate a reply synchronously, printing it to stdout as it streams.

        Returns:
            The object returned by ``model.generate`` (token ids), not text.
        """
        # Sampling requires a strictly positive temperature.
        temperature = max(temperature, 1e-2)
        inputs = self._encode_messages(
            self._build_messages(prompt, system_prompt, message_history)
        )
        # Runs in the calling thread; the TextStreamer prints as tokens arrive.
        generation_kwargs = dict(
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=0.95,
            do_sample=True,
            eos_token_id=self.tokenizer.eos_token_id,
            repetition_penalty=1.2,
        )
        return self.model.generate(
            **inputs, streamer=self.print_streamer, **generation_kwargs
        )

    def get_chat_completion(self, message, history):
        """Stream a reply to ``message`` for ``gr.ChatInterface``.

        Args:
            message: the latest user message.
            history: earlier turns, either as ``(user, assistant)`` pairs
                (Gradio "tuples" format) or as ``{"role", "content"}`` dicts
                (Gradio "messages" format).

        Yields:
            The reply text accumulated so far, after each streamed chunk.

        Raises:
            Whatever ``model.generate`` raised in the worker thread.
        """
        messages = []
        if self.system_prompt:
            messages.append({"role": "system", "content": self.system_prompt})
        for turn in history or []:
            if isinstance(turn, dict):
                messages.append({"role": turn["role"], "content": turn["content"]})
                continue
            user_message, assistant_message = turn
            if user_message is not None:
                messages.append({"role": "user", "content": user_message})
            if assistant_message is not None:
                # Earlier model replies belong to the "assistant" role.
                messages.append({"role": "assistant", "content": assistant_message})
        messages.append({"role": "user", "content": message})

        inputs = self._encode_messages(messages)
        # One streamer per request: a shared one would mix tokens from
        # concurrent requests, or leak leftovers of a cancelled request.
        streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True)
        generation_kwargs = dict(
            inputs,
            streamer=streamer,
            max_new_tokens=2048,
            temperature=0.2,
            top_p=0.95,
            eos_token_id=self.tokenizer.eos_token_id,
            do_sample=True,
            repetition_penalty=1.2,
        )
        errors = []

        def _generate():
            # Worker thread, so text can be consumed below while generating.
            try:
                self.model.generate(**generation_kwargs)
            except Exception as exc:  # re-raised in the caller after join()
                errors.append(exc)
                # Unblock the consumer loop, which would otherwise wait forever.
                streamer.end()

        thread = Thread(target=_generate)
        thread.start()
        generated_text = ""
        for new_text in streamer:
            generated_text += new_text.replace(self.tokenizer.eos_token, "")
            yield generated_text
        thread.join()
        if errors:
            raise errors[0]
        return generated_text

    def get_completion_without_streaming(self, prompt, system_prompt=None,
                                         message_history=None,
                                         max_new_tokens=512, temperature=0.0):
        """Generate a reply synchronously and return only the new text."""
        # Sampling requires a strictly positive temperature.
        temperature = max(temperature, 1e-2)
        inputs = self._encode_messages(
            self._build_messages(prompt, system_prompt, message_history)
        )
        generation_kwargs = dict(
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=0.95,
            do_sample=True,
            repetition_penalty=1.1,
        )
        outputs = self.model.generate(**inputs, **generation_kwargs)
        # For a decoder-only model the output echoes the prompt tokens first;
        # skip them so only the completion is decoded.
        prompt_length = inputs["input_ids"].shape[-1]
        return self.tokenizer.decode(
            outputs[0][prompt_length:], skip_special_tokens=True
        )


text_generator = ChatCompletion(model, tokenizer, system_prompt="You are a native Hindi speaker who can converse at expert level in both Hindi and colloquial Hinglish.")
gr.ChatInterface(text_generator.get_chat_completion).queue().launch(debug=True)