import gradio as gr
from keras.models import load_model
from PIL import Image, ImageOps
import numpy as np
from huggingface_hub import InferenceClient

# Load the Keras model
model = load_model("keras_model.h5", compile=False)

# Load class labels from a file
with open("labels.txt", "r") as file:
    class_names = [line.strip() for line in file]
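# Note: labels.txt is assumed to contain one class name per line (e.g. "Happy",
# "Sad", "Angry"); the actual labels depend on how keras_model.h5 was trained,
# so the names here are only an illustrative assumption.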

# Initialize the Hugging Face Inference Client
client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")

def classify_image(img):
    # Resize and normalize the image for model prediction
    image = ImageOps.fit(img, (224, 224), Image.Resampling.LANCZOS)
    image_array = np.asarray(image)
    normalized_image_array = (image_array.astype(np.float32) / 127.5) - 1
    data = normalized_image_array.reshape((1, 224, 224, 3))

    # Predict the emotion using the model
    prediction = model.predict(data)
    index = np.argmax(prediction)
    class_name = class_names[index]
    confidence_score = prediction[0][index]

    # Return the detected emotion and confidence score
    return {
        "Detected Emotion": class_name,
        "Confidence Score": f"{confidence_score:.2f}"
    }
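
# Illustrative helper (not part of the original app): classify an image file
# from disk so classify_image can be smoke-tested from a Python shell. The
# default path "sample_face.jpg" is a hypothetical example.
def classify_image_from_path(path="sample_face.jpg"):
    # Convert to RGB so the array reshapes cleanly to (1, 224, 224, 3)
    img = Image.open(path).convert("RGB")
    return classify_image(img)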

def respond(
    messages,
    system_message,
    max_tokens,
    temperature,
    top_p
):
    # Ensure messages are correctly formatted
    formatted_messages = []
    for message in messages:
        if message['content'] is None:
            message['content'] = ''  # Set to empty string if None
        formatted_messages.append(message)

    # Add system message at the beginning
    formatted_messages.insert(0, {"role": "system", "content": system_message})

    # Proceed with chat completion; the client is already bound to
    # HuggingFaceH4/zephyr-7b-beta, so no model argument is passed here
    try:
        response = ""
        for message in client.chat_completion(
            messages=formatted_messages,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            stream=True
        ):
            token = message.choices[0].delta.content
            if token:  # skip empty/None deltas in the stream
                response += token
        return response
    except Exception as e:
        print(f"Error: {e}")
        return "Sorry, there was an error processing your request."
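
# Illustrative non-streaming variant of respond() (a sketch, assuming the same
# zephyr-7b-beta endpoint also accepts non-streamed chat_completion calls).
# It is not wired into the Gradio app; shown only as an alternative when
# token-by-token streaming is unnecessary.
def respond_once(messages, system_message, max_tokens, temperature, top_p):
    formatted = [{"role": "system", "content": system_message}] + [
        {"role": m["role"], "content": m.get("content") or ""} for m in messages
    ]
    try:
        completion = client.chat_completion(
            messages=formatted,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
        )
        return completion.choices[0].message.content
    except Exception as e:
        print(f"Error: {e}")
        return "Sorry, there was an error processing your request."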

def emotion_and_chat(img, system_message, max_tokens, temperature, top_p):
    # Classify the image to detect emotion
    emotion_result = classify_image(img)
    detected_emotion = emotion_result["Detected Emotion"]

    # Start chatbot conversation based on the detected emotion
    initial_message = f"I detected that you're feeling {detected_emotion}. Let's talk about it."
    chat_history = [{"role": "user", "content": initial_message}]
    chat_response = respond(chat_history, system_message, max_tokens, temperature, top_p)

    # Return a (user, assistant) pair so the gr.Chatbot output can render it
    return [[initial_message, chat_response]]
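
# Illustrative end-to-end check outside the Gradio UI (assumes a local file
# "sample_face.jpg" exists; the filename is hypothetical):
#
#   history = emotion_and_chat(
#       Image.open("sample_face.jpg").convert("RGB"),
#       "You are a friendly Chatbot.", 512, 0.7, 0.95,
#   )
#   print(history[0][1])  # the model's reply to the detected emotion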

# Define custom CSS for styling
custom_css = """
body {
    font-family: Arial, sans-serif;
    background-color: #000;
    color: #f4f4f4;
}
.gradio-container {
    border-radius: 10px;
    padding: 20px;
    background: linear-gradient(135deg, #ff0000, #008000);
    box-shadow: 0 4px 15px rgba(0, 0, 0, 0.2);
}
.gradio-container h1 {
    font-family: Arial, sans-serif;
    font-size: 2.5em;
    text-align: center;
    color: #fff;
}
.gradio-container p {
    font-size: 1em;
    text-align: center;
    color: #c0c0c0;
}
.gradio-button {
    background-color: #ff0000;
    border: none;
    color: #fff;
    padding: 10px 20px;
    font-size: 1em;
    cursor: pointer;
    border-radius: 5px;
    transition: background-color 0.2s ease;
}
.gradio-button:hover {
    background-color: #ff4d4d;
}
#output-container {
    border-radius: 10px;
    background-color: #008000;
    padding: 20px;
    color: #fff;
}
#output-container h3 {
    font-family: Arial, sans-serif;
    font-size: 1.5em;
    color: #fff;
}
.gr-examples {
    text-align: center;
}
.gr-example-img {
    width: 100px;
    border-radius: 5px;
    margin: 5px;
    box-shadow: 0 4px 10px rgba(0, 0, 0, 0.2);
}
"""

# Define example inputs: with multiple input components, each example must
# supply a value for every input (image URL, system message, max new tokens,
# temperature, top-p)
examples = [
    ["https://firebasestorage.googleapis.com/v0/b/hisia-4b65b.appspot.com/o/a-captivating-ukiyo-e-inspired-poster-featuring-a--wTg7L-f2Tfiy6K8w6aWnKA-KbGU9GSKSDGBbbxrCO65Mg.jpeg?alt=media&token=64590de9-e265-44ac-a766-aeecd455ed5d",
     "You are a friendly Chatbot.", 512, 0.7, 0.95],
    ["https://firebasestorage.googleapis.com/v0/b/hisia-4b65b.appspot.com/o/poster-ai-themed-kenyan-female-silhoutte-written-l-PMIXpNWGQ8KaNNetQRVJuQ-B1TteyL-S5OTPZFXvfGybg.jpeg?alt=media&token=fc10f96d-403e-4f75-bd9c-810e0da36867",
     "You are a friendly Chatbot.", 512, 0.7, 0.95],
    ["https://firebasestorage.googleapis.com/v0/b/hisia-4b65b.appspot.com/o/poster-ai-themed-kenyan-male-silhoutte-written-log-z3fqBD5bQOOj6uqGd_iXLQ-4aBfNy0ZTgmLlTsZh1dzIA.jpeg?alt=media&token=f218f160-d38e-482f-97a9-5442c2f251a7",
     "You are a friendly Chatbot.", 512, 0.7, 0.95]
]

# Gradio Interface
interface = gr.Interface(
    fn=emotion_and_chat,
    inputs=[
        gr.Image(type="pil", label="Upload an Image"),
        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)")
    ],
    outputs=gr.Chatbot(label="Chat with the AI"),
    examples=examples,
    title="HISIA: Emotion Detector and Chatbot",
    description="Upload an image, and our AI will detect the emotion expressed in it and start a conversation with you.",
    allow_flagging="never",
    css=custom_css,
)

# Launch the Gradio interface
if __name__ == "__main__":
    interface.launch()
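
# On Hugging Face Spaces the plain launch() above is sufficient; for local
# debugging, interface.launch(debug=True) keeps tracebacks visible in the
# terminal (a suggestion, not part of the original app).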