# Source: Hugging Face Space by eaglelandsonce — "Update app.py" (commit 395335a).
import gradio as gr
import os
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread
# Hugging Face access token (None when unset). NOTE(review): read here but not
# passed to from_pretrained below — presumably huggingface_hub picks it up from
# the environment for the gated Llama 3 repo; confirm before deploying.
HF_TOKEN = os.environ.get("HF_TOKEN", None)

# Static UI copy for the Gradio page.
DESCRIPTION = '''<div><h1 style="text-align: center;">Meta Llama3 8B with EmotiPal</h1></div>'''
LICENSE = """<p/>---Built with Meta Llama 3"""
# Chatbot placeholder is intentionally empty.
PLACEHOLDER = """"""

# Page-level CSS: center the title and style the duplicate-Space button.
css = """
h1 { text-align: center; display: block;}
#duplicate-button { margin: auto; color: white; background: #1565c0; border-radius: 100vh;}
"""

# Static HTML for the EmotiPal mascot panel (emoji pet + greeting heading).
# EmotiPal HTML
EMOTIPAL_HTML = """
<div style="font-family: Arial, sans-serif; text-align: center; background-color: #f0f0f0; padding: 20px; border-radius: 10px;">
<h1>EmotiPal</h1>
<div id="pet-image" style="width: 200px; height: 200px; margin: 20px auto; border-radius: 50%; background-color: #ddd; display: flex; justify-content: center; align-items: center; font-size: 100px;">
🐢
</div>
<h2 id="greeting">How are you feeling today?</h2>
</div>
"""

# Canned supportive message shown for each selectable mood (used by set_mood).
MOOD_MESSAGES = {
    "happy": "That's wonderful! Your happiness is contagious. Why not share your joy with someone today?",
    "sad": "I'm here for you. Remember, it's okay to feel sad sometimes. How about we do a quick gratitude exercise?",
    "angry": "I understand you're feeling frustrated. Let's try a deep breathing exercise to help calm down.",
    "anxious": "You're not alone in feeling anxious. How about we focus on something positive you're looking forward to?"
}

# Load the tokenizer and model once at import time; device_map="auto" lets
# accelerate place the weights on whatever device(s) are available.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct", device_map="auto")

# Token ids that end generation: the model's EOS plus Llama 3's "<|eot_id|>"
# end-of-turn marker.
terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>")
]
@spaces.GPU(duration=120)
def chat_llama3_8b(message: str, history: list, temperature: float, max_new_tokens: int) -> str:
    """Generate a reply with the Llama-3-8B-Instruct model.

    Args:
        message: The current user message.
        history: Prior (user, assistant) message pairs from the chat UI.
        temperature: Sampling temperature; 0 switches to greedy decoding.
        max_new_tokens: Maximum number of tokens to generate.

    Returns:
        The complete generated assistant reply as a single string.
    """
    # Rebuild the conversation in the role/content format the chat template expects.
    conversation = []
    for user, assistant in history:
        conversation.extend([
            {"role": "user", "content": user},
            {"role": "assistant", "content": assistant},
        ])
    conversation.append({"role": "user", "content": message})

    # add_generation_prompt=True appends the assistant header tokens so the
    # model starts a fresh reply instead of continuing the user's turn.
    input_ids = tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        eos_token_id=terminators,
    )
    if temperature == 0:
        # Greedy decoding: drop the sampling knobs entirely so transformers
        # does not warn about an unused temperature.
        generate_kwargs["do_sample"] = False
        del generate_kwargs["temperature"]

    # Run generation on a background thread and drain the streamer here.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
    return "".join(outputs)
# EmotiPal function
def set_mood(mood, query_input):
    """Refresh EmotiPal's panel for the selected mood.

    Returns a (greeting, pet emoji, support message, Llama-3 query) tuple
    feeding the four output textboxes in the UI.
    """
    emoji_by_mood = {"happy": "🐢", "sad": "🐾", "angry": "🐺", "anxious": "🐩"}

    greeting = f"I see you're feeling {mood} today."
    pet_image = emoji_by_mood.get(mood, "🐢")
    support_message = MOOD_MESSAGES.get(mood, "")

    # Fold the optional free-text question into the prompt for Llama 3.
    llama_query = (
        f"I'm feeling {mood}. {query_input}"
        if query_input
        else f"I'm feeling {mood}. Can you give me some advice or encouragement?"
    )

    return greeting, pet_image, support_message, llama_query
# Function to handle Llama 3 response
def get_llama_response(query, history):
    """Send the prepared query to Llama 3 and return its full reply."""
    # Fixed decoding settings for the EmotiPal side panel.
    return chat_llama3_8b(query, history, temperature=0.7, max_new_tokens=256)
# Gradio block
# Two-column layout: EmotiPal mood panel (left) and a full chat interface
# (right); both funnel into the same chat_llama3_8b generation function.
with gr.Blocks(fill_height=True, css=css) as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column(scale=1):
            # EmotiPal panel: mood picker plus textboxes filled by set_mood().
            gr.HTML(EMOTIPAL_HTML)
            mood_buttons = gr.Radio(
                ["happy", "sad", "angry", "anxious"],
                label="Select your mood",
                info="Click on your current mood",
            )
            query_input = gr.Textbox(label="Ask EmotiPal something (optional)")
            greeting_output = gr.Textbox(label="EmotiPal says:")
            pet_image_output = gr.Textbox(label="Pet")
            support_message_output = gr.Textbox(label="Support Message")
            llama_query_output = gr.Textbox(label="Query for Llama 3")
            llama_response_output = gr.Textbox(label="Llama 3 Response")

            # Selecting a mood refreshes the greeting/pet/support text and
            # pre-builds the query that will be sent to Llama 3.
            mood_buttons.change(
                set_mood,
                inputs=[mood_buttons, query_input],
                outputs=[greeting_output, pet_image_output, support_message_output, llama_query_output]
            )

            ask_llama_button = gr.Button("Ask Llama 3")
            # NOTE(review): gr.State([]) passes an empty history on every
            # click, so the side panel is stateless across questions —
            # presumably intentional; confirm.
            ask_llama_button.click(
                get_llama_response,
                inputs=[llama_query_output, gr.State([])],
                outputs=[llama_response_output]
            )
        with gr.Column(scale=2):
            # Free-form chat against the same model, with sampling controls
            # tucked into a collapsed accordion.
            chatbot = gr.Chatbot(height=450, placeholder=PLACEHOLDER, label='Gradio ChatInterface')
            gr.ChatInterface(
                fn=chat_llama3_8b,
                chatbot=chatbot,
                fill_height=True,
                additional_inputs_accordion=gr.Accordion(label="βš™οΈ Parameters", open=False, render=False),
                additional_inputs=[
                    gr.Slider(minimum=0, maximum=1, step=0.1, value=0.95, label="Temperature", render=False),
                    gr.Slider(minimum=128, maximum=4096, step=1, value=512, label="Max new tokens", render=False),
                ],
                examples=[
                    ['How to setup a human base on Mars? Give short answer.'],
                    ["Explain theory of relativity to me like I'm 8 years old."],
                    ['What is 9,000 * 9,000?'],
                    ['Write a pun-filled happy birthday message to my friend Alex.'],
                    ['Justify why a penguin might make a good king of the jungle.']
                ],
                cache_examples=False,
            )
    gr.Markdown(LICENSE)

# Launch the app only when executed as a script.
if __name__ == "__main__":
    demo.launch()