import os
from threading import Thread

import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
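
# Note: a minimal requirements.txt for this Space would likely need gradio, spaces,
# transformers, torch, and accelerate (device_map="auto" relies on accelerate).
# That list is an assumption; the dependency file is not part of this source.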

HF_TOKEN = os.environ.get("HF_TOKEN", None)

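# Front-end for EmotiPal: static HTML/CSS plus a small JavaScript layer for the mood
# buttons. The getLlamaResponse() stub in the script is a placeholder; the actual model
# call is wired through the Gradio components defined at the bottom of this file.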
EMOTIPAL_HTML = """
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>EmotiPal</title>
    <style>
        body {
            font-family: Arial, sans-serif;
            display: flex;
            justify-content: center;
            align-items: center;
            height: 100vh;
            margin: 0;
            background-color: #f0f0f0;
        }
        .container {
            background-color: white;
            padding: 20px;
            border-radius: 10px;
            box-shadow: 0 0 10px rgba(0,0,0,0.1);
            text-align: center;
        }
        #pet-image {
            width: 200px;
            height: 200px;
            margin: 20px auto;
            border-radius: 50%;
            background-color: #ddd;
            display: flex;
            justify-content: center;
            align-items: center;
            font-size: 100px;
        }
        button {
            margin: 10px;
            padding: 10px 20px;
            font-size: 16px;
            cursor: pointer;
        }
    </style>
</head>
<body>
    <div class="container">
        <h1>EmotiPal</h1>
        <div id="pet-image">🐶</div>
        <h2 id="greeting">How are you feeling today?</h2>
        <div>
            <button onclick="setMood('happy')">😊 Happy</button>
            <button onclick="setMood('sad')">😢 Sad</button>
            <button onclick="setMood('angry')">😠 Angry</button>
            <button onclick="setMood('anxious')">😰 Anxious</button>
        </div>
        <p id="support-message"></p>
        <input type="text" id="query-input" placeholder="Ask EmotiPal something (optional)">
        <button onclick="askLlama()">Ask Llama 3</button>
        <p id="llama-response"></p>
    </div>

    <script>
        const supportMessages = {
            happy: "That's wonderful! Your happiness is contagious. Why not share your joy with someone today?",
            sad: "I'm here for you. Remember, it's okay to feel sad sometimes. How about we do a quick gratitude exercise?",
            angry: "I understand you're feeling frustrated. Let's try a deep breathing exercise to help calm down.",
            anxious: "You're not alone in feeling anxious. How about we focus on something positive you're looking forward to?"
        };

        let currentMood = '';

        function setMood(mood) {
            currentMood = mood;
            const greeting = document.getElementById('greeting');
            const supportMessage = document.getElementById('support-message');
            const petImage = document.getElementById('pet-image');

            greeting.textContent = `I see you're feeling ${mood} today.`;
            supportMessage.textContent = supportMessages[mood];

            switch (mood) {
                case 'happy':
                    petImage.textContent = '🐶';
                    break;
                case 'sad':
                    petImage.textContent = '🐾';
                    break;
                case 'angry':
                    petImage.textContent = '🐺';
                    break;
                case 'anxious':
                    petImage.textContent = '🐩';
                    break;
            }
        }

        function askLlama() {
            const queryInput = document.getElementById('query-input').value;
            const llamaQuery = currentMood ?
                `I'm feeling ${currentMood}. ${queryInput}` :
                queryInput || "Can you give me some advice or encouragement?";

            // This function needs to be implemented in the Python backend
            getLlamaResponse(llamaQuery).then(response => {
                document.getElementById('llama-response').textContent = response;
            });
        }

        // This function will be replaced by the actual implementation in Gradio
        function getLlamaResponse(query) {
            return new Promise(resolve => {
                setTimeout(() => {
                    resolve("This is a placeholder response from Llama 3. Implement the actual API call in the Gradio backend.");
                }, 1000);
            });
        }
    </script>
</body>
</html>
"""
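
# Meta-Llama-3-8B-Instruct is a gated checkpoint, so the HF_TOKEN read above is passed
# to from_pretrained() to authorize the download.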
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct", token=HF_TOKEN)
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B-Instruct",
    device_map="auto",
    token=HF_TOKEN,
)

# Stop generation at either the standard EOS token or Llama 3's end-of-turn token.
terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>"),
]

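
# @spaces.GPU requests a ZeroGPU allocation for the duration of this call (up to 120 s here).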
@spaces.GPU(duration=120)
def chat_llama3_8b(message: str, history: list, temperature: float, max_new_tokens: int) -> str:
    """Generate a response with Llama 3 8B, collecting the streamed tokens into one string."""
    conversation = []
    for user, assistant in history:
        conversation.extend([
            {"role": "user", "content": user},
            {"role": "assistant", "content": assistant},
        ])
    conversation.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        eos_token_id=terminators,
    )
    # Sampling with temperature 0 is undefined, so fall back to greedy decoding.
    if temperature == 0:
        generate_kwargs["do_sample"] = False

    # Run generation in a background thread and drain the streamer on this one.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
    return "".join(outputs)

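
# One-shot helper used by the click handler below: empty chat history, fixed sampling settings.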
def get_llama_response(query):
    return chat_llama3_8b(query, [], temperature=0.7, max_new_tokens=256)


with gr.Blocks() as demo:
    gr.HTML(EMOTIPAL_HTML)
    # The JavaScript in EMOTIPAL_HTML only ships a placeholder getLlamaResponse(), so the
    # model is exposed through regular Gradio components rather than hidden textboxes fed
    # into a nested gr.Interface, which never fires inside Blocks.
    query_input = gr.Textbox(label="Ask EmotiPal something")
    ask_button = gr.Button("Ask Llama 3")
    response_output = gr.Textbox(label="Llama 3's response")
    ask_button.click(fn=get_llama_response, inputs=query_input, outputs=response_output)


if __name__ == "__main__":
    demo.launch()