import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed

# Load the fine-tuned model and tokenizer directly from the Hugging Face Hub.
# Mistral is a decoder-only model, so AutoModelForCausalLM is the correct auto
# class (AutoModelForSeq2SeqLM cannot load it), and from_pretrained() downloads
# the weights itself, so cloning the repo with huggingface_hub.Repository
# beforehand is unnecessary.
model_name = "adi1193/mistral-postv6"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# System prompt injected when Hinglish mode is enabled. The exact wording was
# not defined in the original source; this minimal version is inferred from the
# marker string that format_prompt() looks for in the conversation history.
Hinglish_Prompt = "[INST] You are a Hinglish LLM. [/INST]"


def format_prompt(message, history, enable_hinglish=False):
    prompt = ""
    # Add the Hinglish system prompt once, unless it already appears in the history
    if enable_hinglish and not any(
        "[INST] You are a Hinglish LLM." in user_prompt
        for user_prompt, bot_response in history
    ):
        prompt += Hinglish_Prompt
    for user_prompt, bot_response in history:
        prompt += f"[INST] {user_prompt} [/INST]"
        prompt += f" {bot_response} "
    prompt += f"[INST] {message} [/INST]"
    return prompt


def generate(prompt, history, temperature=0.9, max_new_tokens=256, top_p=0.95,
             repetition_penalty=1.0, enable_hinglish=False):
    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)

    # Only pass arguments that model.generate() actually accepts: the model,
    # tokenizer, and seed are not valid generation kwargs. max_new_tokens
    # already counts just the generated tokens, so the prompt length does not
    # need to be added to a max_length budget.
    generate_kwargs = {
        "max_new_tokens": max_new_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "repetition_penalty": repetition_penalty,
        "do_sample": True,
        "pad_token_id": tokenizer.eos_token_id,
    }
    set_seed(42)  # reproducible sampling

    formatted_prompt = format_prompt(prompt, history, enable_hinglish)
    input_ids = tokenizer.encode(formatted_prompt, return_tensors="pt")
    output = model.generate(input_ids, **generate_kwargs)
    # Decode only the newly generated tokens so the chat UI does not echo the prompt
    return tokenizer.decode(output[0][input_ids.shape[-1]:], skip_special_tokens=True)
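
# Worked example of the prompt layout built above (hypothetical turn values):
# format_prompt("How are you?", [("Hi", "Namaste!")]) returns
# "[INST] Hi [/INST] Namaste! [INST] How are you? [/INST]"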

additional_inputs = [
    gr.Slider(
        label="Temperature",
        value=0.9,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values produce more diverse outputs",
    ),
    gr.Slider(
        label="Max new tokens",
        value=256,
        minimum=0,
        maximum=1048,
        step=64,
        interactive=True,
        info="The maximum number of new tokens",
    ),
    gr.Slider(
        label="Top-p (nucleus sampling)",
        value=0.90,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values sample more low-probability tokens",
    ),
    gr.Slider(
        label="Repetition penalty",
        value=1.2,
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        interactive=True,
        info="Penalize repeated tokens",
    ),
    gr.Checkbox(
        label="Hinglish",
        value=False,
        interactive=True,
        info="Enables MistralTalk to respond in Hinglish (a mix of Hindi and English)",
    ),
]

css = """
#mkd {
    height: 500px;
    overflow: auto;
    border: 1px solid #ccc;
}
"""

with gr.Blocks(css=css) as demo:
    gr.HTML("<center><h1>MistralTalk🗣️</h1></center>")
") gr.HTML("

In this demo, you can chat with Mistral-8x7B model. 💬

") gr.HTML("

Learn more about the model here. 📚

") gr.ChatInterface( generate, additional_inputs=additional_inputs, theme = gr.themes.Soft(), examples=[["What is the interest?"], ["How does the universe work?"],["What can you do?"],["What is quantum mechanics?"],["Do you believe in an after life?"]] ) if __name__ == "__main__": demo.launch()