'''
This script creates a Gradio chatbot interface for the ibm-granite/granite-3.3-8b-instruct model.

Key features:
- Loads the model and tokenizer from the Hugging Face Hub.
- Uses a chat interface for interactive conversations.
- Manages chat history to maintain context.
- Handles API key management through Hugging Face Spaces secrets.
'''
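
# The Space is assumed to provide (not shown in this file):
#   - a requirements.txt listing gradio, torch, transformers, and accelerate
#     (accelerate is needed because device_map is passed to from_pretrained), and
#   - a HUGGINGFACE_TOKEN secret, set under Settings -> Variables and secrets.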

import os

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed

# --- Configuration ---
MODEL_ID = "ibm-granite/granite-3.3-8b-instruct"
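
# Note: an 8B-parameter model in bfloat16 needs roughly 16 GB just for the weights,
# so this app effectively requires GPU hardware; a CPU-only Space will likely run
# out of memory while loading the model.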


# --- Model and Tokenizer Loading ---
def load_model_and_tokenizer():
    '''Load the model and tokenizer, handling potential errors.'''
    try:
        # Securely get the Hugging Face token from the Space's secrets
        hf_token = os.getenv("HUGGINGFACE_TOKEN")
        if not hf_token:
            raise ValueError(
                "HUGGINGFACE_TOKEN secret not found. Please add it to your Space settings."
            )

        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            device_map=device,
            torch_dtype=torch.bfloat16,
            token=hf_token,
        )
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=hf_token)
        return model, tokenizer, device
    except Exception as e:
        # Provide a user-friendly error message
        raise RuntimeError(f"Failed to load model or tokenizer: {e}") from e


model, tokenizer, device = load_model_and_tokenizer()
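
# The model is loaded once at import time, so a missing HUGGINGFACE_TOKEN secret or an
# out-of-memory condition surfaces as a startup error for the whole Space.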


# --- Chatbot Logic ---
def chat_function(message, history):
    '''Process the user's message and return the model's response.'''
    # Set seed for reproducible sampling
    set_seed(42)

    # Format the conversation history for the model
    conv = []
    for user_msg, model_msg in history:
        conv.append({"role": "user", "content": user_msg})
        conv.append({"role": "assistant", "content": model_msg})
    conv.append({"role": "user", "content": message})

    # Tokenize the input
    input_ids = tokenizer.apply_chat_template(
        conv,
        return_tensors="pt",
        thinking=False,  # Set to False for a direct response
        add_generation_prompt=True,
    ).to(device)
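    # 'thinking' is not a standard tokenizer argument; apply_chat_template forwards extra
    # keyword arguments to the chat template, where the Granite template uses it to toggle
    # the model's reasoning ("thinking") mode.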
    # Generate the response
    output = model.generate(
        input_ids,
        max_new_tokens=1024,
        do_sample=True,
        top_k=50,
        top_p=0.95,
        temperature=0.7,
    )

    # Decode the prediction
    prediction = tokenizer.decode(output[0, input_ids.shape[1]:], skip_special_tokens=True)
    return prediction


# --- Gradio Interface ---
def create_gradio_interface():
    '''Create and return the Gradio ChatInterface.'''
    return gr.ChatInterface(
        fn=chat_function,
        title="Granite 3.3 8B Chatbot",
        description="A chatbot powered by the ibm-granite/granite-3.3-8b-instruct model. Ask any question!",
        theme="soft",
        examples=[
            ["Hello, who are you?"],
            ["What is the capital of France?"],
            ["Explain the theory of relativity in simple terms."],
        ],
    )


# --- Main Execution ---
if __name__ == "__main__":
    chatbot_interface = create_gradio_interface()
    chatbot_interface.launch()