import gradio as gr
import os
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread

HF_TOKEN = os.environ.get("HF_TOKEN", None)

DESCRIPTION = '''<div><h1 style="text-align: center;">Meta Llama 3 8B with EmotiPal</h1></div>'''

LICENSE = """
<p/>

---
Built with Meta Llama 3
"""

PLACEHOLDER = """"""

css = """
h1 { text-align: center; display: block; }
#duplicate-button { margin: auto; color: white; background: #1565c0; border-radius: 100vh; }
"""

EMOTIPAL_HTML = """
<div style="font-family: Arial, sans-serif; text-align: center; background-color: #f0f0f0; padding: 20px; border-radius: 10px;">
    <h1>EmotiPal</h1>
    <div id="pet-image" style="width: 200px; height: 200px; margin: 20px auto; border-radius: 50%; background-color: #ddd; display: flex; justify-content: center; align-items: center; font-size: 100px;">
        🐶
    </div>
    <h2 id="greeting">How are you feeling today?</h2>
</div>
"""

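# Canned supportive messages keyed by mood; shown alongside the model's reply.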
MOOD_MESSAGES = {
    "happy": "That's wonderful! Your happiness is contagious. Why not share your joy with someone today?",
    "sad": "I'm here for you. Remember, it's okay to feel sad sometimes. How about we do a quick gratitude exercise?",
    "angry": "I understand you're feeling frustrated. Let's try a deep breathing exercise to help calm down.",
    "anxious": "You're not alone in feeling anxious. How about we focus on something positive you're looking forward to?"
}

# The meta-llama checkpoints are gated, so pass the token when downloading.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct", token=HF_TOKEN)
model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct", device_map="auto", token=HF_TOKEN)

# Llama 3 ends assistant turns with <|eot_id|>, so treat it as a stop token alongside eos.
terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>")
]

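# On ZeroGPU Spaces, a GPU is attached only while this decorated function runs;
# duration=120 requests up to 120 seconds per call.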
@spaces.GPU(duration=120)
def chat_llama3_8b(message: str, history: list, temperature: float, max_new_tokens: int):
    """Generate a streaming response using the Llama 3 8B model."""
    conversation = []
    for user, assistant in history:
        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
    conversation.append({"role": "user", "content": message})

    # add_generation_prompt=True appends the assistant header so the model
    # answers the user rather than continuing the user's turn.
    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt").to(model.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        eos_token_id=terminators,
    )
    # Sampling at temperature 0 is ill-defined; fall back to greedy decoding.
    if temperature == 0:
        generate_kwargs['do_sample'] = False

    # Run generation on a background thread and stream partial output as it decodes.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)

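# Map the selected mood to a pet emoji, a greeting, and a prefilled query for the model.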
def set_mood(mood, query_input):
    pet_emoji = {"happy": "🐶", "sad": "🐾", "angry": "🐺", "anxious": "🐩"}
    greeting = f"I see you're feeling {mood} today."
    pet_image = pet_emoji.get(mood, "🐶")
    support_message = MOOD_MESSAGES.get(mood, "")

    if query_input:
        llama_query = f"I'm feeling {mood}. {query_input}"
    else:
        llama_query = f"I'm feeling {mood}. Can you give me some advice or encouragement?"

    return greeting, pet_image, support_message, llama_query

def get_llama_response(query, history):
    # chat_llama3_8b is a generator, so stream its partial outputs to the textbox.
    yield from chat_llama3_8b(query, history, temperature=0.7, max_new_tokens=256)

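# UI layout: the left column hosts the EmotiPal mood tracker, the right column
# a standard chat interface backed by the same model.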
with gr.Blocks(fill_height=True, css=css) as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        with gr.Column(scale=1):
            gr.HTML(EMOTIPAL_HTML)
            mood_buttons = gr.Radio(
                ["happy", "sad", "angry", "anxious"],
                label="Select your mood",
                info="Click on your current mood",
            )
            query_input = gr.Textbox(label="Ask EmotiPal something (optional)")
            greeting_output = gr.Textbox(label="EmotiPal says:")
            pet_image_output = gr.Textbox(label="Pet")
            support_message_output = gr.Textbox(label="Support Message")
            llama_query_output = gr.Textbox(label="Query for Llama 3")
            llama_response_output = gr.Textbox(label="Llama 3 Response")

            mood_buttons.change(
                set_mood,
                inputs=[mood_buttons, query_input],
                outputs=[greeting_output, pet_image_output, support_message_output, llama_query_output]
            )

            ask_llama_button = gr.Button("Ask Llama 3")
            ask_llama_button.click(
                get_llama_response,
                inputs=[llama_query_output, gr.State([])],
                outputs=[llama_response_output]
            )

        with gr.Column(scale=2):
            chatbot = gr.Chatbot(height=450, placeholder=PLACEHOLDER, label='Gradio ChatInterface')
            gr.ChatInterface(
                fn=chat_llama3_8b,
                chatbot=chatbot,
                fill_height=True,
                additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
                additional_inputs=[
                    gr.Slider(minimum=0, maximum=1, step=0.1, value=0.95, label="Temperature", render=False),
                    gr.Slider(minimum=128, maximum=4096, step=1, value=512, label="Max new tokens", render=False),
                ],
                examples=[
                    ['How to setup a human base on Mars? Give short answer.'],
                    ["Explain theory of relativity to me like I'm 8 years old."],
                    ['What is 9,000 * 9,000?'],
                    ['Write a pun-filled happy birthday message to my friend Alex.'],
                    ['Justify why a penguin might make a good king of the jungle.']
                ],
                cache_examples=False,
            )

    gr.Markdown(LICENSE)

if __name__ == "__main__":
    demo.launch()