# my_space/app.py
import gradio as gr
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel
import torch
import os
# Load Hugging Face token from environment variable
hf_token = os.getenv("HF_TOKEN")
# Load base model and tokenizer with 4-bit quantization
base_model_name = "mistralai/Mistral-7B-v0.1"
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,                      # 4-bit quantization for lower memory use
    bnb_4bit_compute_dtype=torch.bfloat16,  # compute in bfloat16 on GPU
    bnb_4bit_quant_type="nf4",              # NormalFloat4 quantization for efficiency
    bnb_4bit_use_double_quant=True          # double quantization for extra savings
)
tokenizer = AutoTokenizer.from_pretrained(
    base_model_name,
    token=hf_token
)
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    quantization_config=quantization_config,
    device_map="auto",  # place layers on the available GPU(s) automatically
    token=hf_token
)
# Load LoRA adapters
model = PeftModel.from_pretrained(
    base_model,
    "hin123123/theralingua-mistral-7b-word",
    token=hf_token
)
# Create pipeline
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=512,
    do_sample=True,
    temperature=0.7
)
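# Note (assumption, not verified against the adapter repo): calling the pipeline
# with a list of role/content messages, as chatbot_fn does below, requires the
# tokenizer to have a chat template. Mistral-7B-v0.1 is a base (non-instruct)
# model, so if neither it nor hin123123/theralingua-mistral-7b-word ships a
# template, tokenizer.chat_template must be set before the first request.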
# Chat function
def chatbot_fn(message, history):
    # Start every request with the system prompt, then replay the conversation so far.
    chat_history = [{"role": "system", "content": "You are a helpful assistant trained to provide concise and accurate answers. For training-related queries, provide detailed steps."}]
    # Gradio 4.x ChatInterface passes history as a list of (user, assistant) pairs.
    for user_msg, bot_msg in history:
        chat_history.append({"role": "user", "content": user_msg})
        chat_history.append({"role": "assistant", "content": bot_msg})
    chat_history.append({"role": "user", "content": message})
    # With chat-format input, the pipeline returns the full message list under
    # "generated_text"; the last entry is the newly generated assistant reply.
    response = pipe(chat_history)[0]["generated_text"][-1]["content"]
    return response
# Gradio interface
gr.ChatInterface(
    fn=chatbot_fn,
    chatbot=gr.Chatbot(height=500),
    textbox=gr.Textbox(placeholder="Type your message here...", container=False, scale=7),
    title="Theralingua Mistral Chatbot",
    description="Chat with the Theralingua AI powered by Mistral-7B!",
    theme="soft",
    retry_btn=None,
    undo_btn="Undo",
    clear_btn="Clear"
).launch()
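# Rough local-run sketch (assumed, not part of the original Space config):
#   pip install "gradio<5" transformers peft bitsandbytes accelerate torch
#   HF_TOKEN=<your token> python app.py
# 4-bit bitsandbytes loading requires a CUDA GPU, and the retry_btn / undo_btn /
# clear_btn arguments belong to the Gradio 4.x ChatInterface API (removed in Gradio 5).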