|
|
import gradio as gr |
|
|
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig |
|
|
from peft import PeftModel |
|
|
import torch |
|
|
import os |
|
|
|
|
|
|
|
|
# Hugging Face access token taken from the environment; evaluates to None when
# HF_TOKEN is unset. It is passed as ``token=`` to the Hub loading calls.
hf_token = os.environ.get("HF_TOKEN")

# Hub id of the base checkpoint that the PEFT adapter is loaded on top of.
base_model_name = "mistralai/Mistral-7B-v0.1"

# bitsandbytes settings for loading the base weights in 4-bit:
# NF4 quantization, nested ("double") quantization of the quantization
# constants, and bfloat16 as the compute dtype.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)
|
|
# Tokenizer comes from the base checkpoint (not from the adapter repository).
tokenizer = AutoTokenizer.from_pretrained(base_model_name, token=hf_token)

# Base causal LM loaded with the 4-bit quantization settings above;
# device_map="auto" delegates weight placement across available devices.
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    token=hf_token,
    device_map="auto",
    quantization_config=quantization_config,
)

# Wrap the base model with the fine-tuned PEFT adapter weights from the Hub.
model = PeftModel.from_pretrained(
    base_model,
    "hin123123/theralingua-mistral-7b-word",
    token=hf_token,
)
|
|
|
|
|
|
|
|
# Text-generation pipeline around the adapter-wrapped model. The generation
# settings given here (up to 512 new tokens, sampling at temperature 0.7) are
# applied to every call of ``pipe``.
# NOTE(review): chatbot_fn feeds chat-format message lists, which the pipeline
# renders via ``tokenizer.chat_template``; confirm the base tokenizer defines a
# template that matches the adapter's fine-tuning prompt format.
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=512,
    do_sample=True,
    temperature=0.7,
    # When the tokenizer has no pad token (presumably true for the base Mistral
    # tokenizer — verify), ``generate`` falls back to EOS on every call and
    # logs a warning each time. Passing the id explicitly keeps the same
    # generation behavior without the per-request warning.
    pad_token_id=(
        tokenizer.pad_token_id
        if tokenizer.pad_token_id is not None
        else tokenizer.eos_token_id
    ),
)
|
|
|
|
|
|
|
|
# System message placed at the start of every conversation sent to the model.
_DEFAULT_SYSTEM_PROMPT = (
    "You are a helpful assistant trained to provide concise and accurate "
    "answers. For training-related queries, provide detailed steps."
)


def chatbot_fn(message, history, *, system_prompt=_DEFAULT_SYSTEM_PROMPT, generator=None):
    """Generate the assistant's reply to ``message`` given the prior chat.

    Args:
        message: The user's newest message.
        history: Prior turns as supplied by ``gr.ChatInterface``. Both the
            "tuples" format (``[[user, assistant], ...]``) and the "messages"
            format (``[{"role": ..., "content": ...}, ...]``) are accepted.
            ``None`` halves of a tuple turn are skipped; ``None`` history is
            treated as empty.
        system_prompt: System message prepended to the conversation.
        generator: Text-generation callable to use instead of the module-level
            ``pipe``. It receives the list of chat messages and must return the
            text-generation pipeline output structure
            (``[{"generated_text": [..., {"content": reply}]}]``).

    Returns:
        The ``content`` of the last message in the generated conversation,
        i.e. the model's new reply.
    """
    generate = pipe if generator is None else generator

    messages = [{"role": "system", "content": system_prompt}]
    for turn in history or []:
        if isinstance(turn, dict):
            # "messages"-format history: keep only role/content and drop any
            # extra keys Gradio may attach to a message.
            messages.append({"role": turn["role"], "content": turn["content"]})
            continue
        user_msg, bot_msg = turn
        # A tuples-format turn can have an empty (None) side; sending None
        # content into the chat template is never meaningful, so skip it.
        if user_msg is not None:
            messages.append({"role": "user", "content": user_msg})
        if bot_msg is not None:
            messages.append({"role": "assistant", "content": bot_msg})
    messages.append({"role": "user", "content": message})

    outputs = generate(messages)
    # With chat-format input the pipeline returns the whole conversation with
    # the newly generated assistant message appended last.
    return outputs[0]["generated_text"][-1]["content"]
|
|
|
|
|
|
|
|
# Gradio chat UI around chatbot_fn. The interface is bound to ``demo`` so that
# Gradio's auto-reload mode (``gradio <file>.py``, which looks for a variable
# named ``demo`` by default) and any importer can reach it; it is then launched
# exactly as before.
# NOTE(review): retry_btn / undo_btn / clear_btn are Gradio 4.x ChatInterface
# arguments; confirm the installed Gradio version still accepts them.
demo = gr.ChatInterface(
    fn=chatbot_fn,
    chatbot=gr.Chatbot(height=500),
    textbox=gr.Textbox(placeholder="Type your message here...", container=False, scale=7),
    title="Theralingua Mistral Chatbot",
    description="Chat with the Theralingua AI powered by Mistral-7B!",
    theme="soft",
    retry_btn=None,
    undo_btn="Undo",
    clear_btn="Clear",
)
demo.launch()