'''
This script creates a Gradio chatbot interface for the ibm-granite/granite-3.3-8b-instruct model.
Key Features:
- Loads the model and tokenizer from Hugging Face Hub.
- Uses a chat interface for interactive conversations.
- Manages chat history to maintain context.
- Reads the Hugging Face access token from a Spaces secret (HUGGINGFACE_TOKEN).
'''
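
# NOTE (assumption): the Space is expected to install gradio, torch, transformers,
# and accelerate (needed because from_pretrained is called with device_map) via
# requirements.txt, and to define the HUGGINGFACE_TOKEN secret in its settings.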
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
import os
# --- Configuration ---
MODEL_ID = "ibm-granite/granite-3.3-8b-instruct"
# --- Model and Tokenizer Loading ---
def load_model_and_tokenizer():
    '''Load the model and tokenizer, handling potential errors.'''
    try:
        # Securely get the Hugging Face token from secrets
        hf_token = os.getenv("HUGGINGFACE_TOKEN")
        if not hf_token:
            raise ValueError("HUGGINGFACE_TOKEN secret not found. Please add it to your Space settings.")
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            device_map=device,
            # Use bfloat16 on GPU; fall back to float32 on CPU, where bfloat16 is very slow
            torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32,
            token=hf_token
        )
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=hf_token)
        return model, tokenizer, device
    except Exception as e:
        # Re-raise with a user-friendly message, preserving the original traceback
        raise RuntimeError(f"Failed to load model or tokenizer: {e}") from e
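
# Load once at startup so every chat request reuses the same model and tokenizer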
model, tokenizer, device = load_model_and_tokenizer()
# --- Chatbot Logic ---
def chat_function(message, history):
    '''
    Process the user's message together with the chat history and
    return the model's response.
    '''
    # Set seed for reproducible sampling
    set_seed(42)
    # Format the conversation history for the model.
    # Assumes Gradio's tuple-style history: a list of (user, assistant) pairs.
    conv = []
    for user_msg, model_msg in history:
        conv.append({"role": "user", "content": user_msg})
        conv.append({"role": "assistant", "content": model_msg})
    conv.append({"role": "user", "content": message})
    # Tokenize the input
    input_ids = tokenizer.apply_chat_template(
        conv,
        return_tensors="pt",
        thinking=False,  # disable Granite's reasoning mode for a direct response
        add_generation_prompt=True
    ).to(device)
    # Generate the response
    output = model.generate(
        input_ids,
        max_new_tokens=1024,
        do_sample=True,
        top_k=50,
        top_p=0.95,
        temperature=0.7,
        pad_token_id=tokenizer.eos_token_id,  # silence the missing-pad-token warning
    )
    # Decode only the newly generated tokens, skipping the prompt
    prediction = tokenizer.decode(output[0, input_ids.shape[1]:], skip_special_tokens=True)
    return prediction
# --- Gradio Interface ---
def create_gradio_interface():
    '''Create and return the Gradio ChatInterface.'''
    return gr.ChatInterface(
        fn=chat_function,
        title="Granite 3.3 8B Chatbot",
        description="A chatbot powered by the ibm-granite/granite-3.3-8b-instruct model. Ask any question!",
        theme="soft",
        examples=[
            ["Hello, who are you?"],
            ["What is the capital of France?"],
            ["Explain the theory of relativity in simple terms."]
        ]
    )
# --- Main Execution ---
if __name__ == "__main__":
    chatbot_interface = create_gradio_interface()
    chatbot_interface.launch()
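
# When run locally, launch() serves the app at http://127.0.0.1:7860 by default;
# on Hugging Face Spaces, the platform hosts the app automatically.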