# Hugging Face Spaces page header (scraped residue): Space status "Sleeping".
import gradio as gr | |
from openai import OpenAI | |
import time | |
import html | |
def predict(message, history, character, api_key, progress=gr.Progress()):
    """Stream a GPT-4 reply to *message*, yielding the accumulated text.

    Args:
        message: The new user message.
        history: Prior turns as ``(user_text, assistant_text)`` pairs. A turn
            whose assistant text is ``None``/empty (e.g. a reply that failed or
            never streamed) is sent as a user message only. The Chat Completions
            API requires assistant ``content`` unless tool calls are present, so
            forwarding ``None`` would make every later request fail.
        character: Currently unused. NOTE(review): the UI lets the user pick a
            character, but no system prompt is built from it; confirm intent.
        api_key: OpenAI API key supplied by the user for this request.
        progress: Gradio progress tracker wrapped around the response stream.

    Yields:
        The reply accumulated so far, once per non-empty content delta.
    """
    client = OpenAI(api_key=api_key)

    messages = []
    for human, assistant in history:
        messages.append({"role": "user", "content": human})
        # Skip missing replies (same truthiness rule format_history uses).
        if assistant:
            messages.append({"role": "assistant", "content": assistant})
    messages.append({"role": "user", "content": message})

    response = client.chat.completions.create(
        model="gpt-4",
        messages=messages,
        temperature=1.0,
        stream=True,
    )

    partial_message = ""
    for chunk in progress.tqdm(response, desc="Generating"):
        # Defensive: a stream chunk may arrive with an empty `choices` list
        # (e.g. usage-only chunks); indexing [0] would raise IndexError.
        if not chunk.choices:
            continue
        delta = chunk.choices[0].delta.content
        if delta:
            partial_message += delta
            yield partial_message
            # Brief pause between UI updates.
            time.sleep(0.01)
def format_history(history):
    """Render the chat *history* as HTML for the ``chat-display`` panel.

    Args:
        history: Sequence of ``(user_text, ai_text)`` pairs. ``ai_text`` is
            ``None`` (or empty) while a reply is still pending; in that case
            only the user's message is rendered.

    Returns:
        One ``<div class="message ...">`` per message, concatenated. All text
        is HTML-escaped before newlines are turned into ``<br>``, so user or
        model content cannot inject markup.
    """
    def _render(text, css_class, speaker):
        # Escape first so the only tags reaching the browser are our own.
        body = html.escape(text).replace('\n', '<br>')
        return f'<div class="message {css_class}"><strong>{speaker}:</strong> {body}</div>'

    parts = []
    for human, ai in history:
        parts.append(_render(human, "user-message", "You"))
        if ai:
            parts.append(_render(ai, "ai-message", "AI"))
    # This runs on every streamed chunk. join() avoids re-copying the growing
    # string on each += .
    return "".join(parts)
# Stylesheet passed to gr.Blocks(css=...) below.
# - #chat-display: fixed-height scrollable panel for the gr.HTML transcript
#   (elem_id="chat-display"), with custom scrollbar colours. The
#   ::-webkit-scrollbar rules are vendor-prefixed and only take effect in
#   WebKit/Blink-based browsers.
# - .message / .user-message / .ai-message: the per-message <div>s emitted by
#   format_history(). Each message gets its own scrollbar once it exceeds 300px.
css = """
#chat-display {
    height: 600px;
    overflow-y: auto;
    border: 1px solid #ccc;
    padding: 10px;
    margin-bottom: 10px;
}
#chat-display::-webkit-scrollbar {
    width: 10px;
}
#chat-display::-webkit-scrollbar-track {
    background: #f1f1f1;
}
#chat-display::-webkit-scrollbar-thumb {
    background: #888;
}
#chat-display::-webkit-scrollbar-thumb:hover {
    background: #555;
}
.message {
    margin-bottom: 10px;
    max-height: 300px;
    overflow-y: auto;
    word-wrap: break-word;
}
.user-message {
    background-color: #e6f3ff;
    padding: 5px;
    border-radius: 5px;
}
.ai-message {
    background-color: #f0f0f0;
    padding: 5px;
    border-radius: 5px;
}
"""
js = """ | |
function maintainScroll(element_id) { | |
let element = document.getElementById(element_id); | |
let shouldScroll = element.scrollTop + element.clientHeight === element.scrollHeight; | |
let previousScrollTop = element.scrollTop; | |
return function() { | |
if (!shouldScroll) { | |
element.scrollTop = previousScrollTop; | |
} else { | |
element.scrollTop = element.scrollHeight; | |
} | |
} | |
} | |
let scrollMaintainer = maintainScroll('chat-display'); | |
setInterval(scrollMaintainer, 100); | |
// Add event listener for Ctrl+Enter and prevent default Enter behavior | |
document.addEventListener('DOMContentLoaded', (event) => { | |
const textbox = document.querySelector('#your_message textarea'); | |
textbox.addEventListener('keydown', function(e) { | |
if (e.ctrlKey && e.key === 'Enter') { | |
e.preventDefault(); | |
document.querySelector('#your_message button').click(); | |
} else if (e.key === 'Enter' && !e.shiftKey) { | |
e.preventDefault(); | |
const start = this.selectionStart; | |
const end = this.selectionEnd; | |
this.value = this.value.substring(0, start) + "\\n" + this.value.substring(end); | |
this.selectionStart = this.selectionEnd = start + 1; | |
} | |
}); | |
}); | |
""" | |
with gr.Blocks(css=css, js=js) as demo:
    gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>My Chatbot</h1>")

    # Conversation as a list of [user_message, ai_reply] pairs. ai_reply stays
    # None until the model streams its first chunk.
    chat_history = gr.State([])
    # Rendered transcript; sized and scrolled by the #chat-display CSS rule.
    chat_display = gr.HTML(elem_id="chat-display")
    msg = gr.Textbox(
        label="Your message",
        lines=2,
        max_lines=10,
        placeholder="Type your message here... (Press Ctrl+Enter to send, Enter for new line)",
        elem_id="your_message",
    )
    clear = gr.Button("Clear")
    dropdown = gr.Dropdown(
        [f"Character {n}" for n in range(1, 14)],
        label="Characters",
        info="Select the character that you'd like to speak to",
        value="Character 1",
    )
    api_key = gr.Textbox(type="password", label="OpenAI API Key")

    def user(user_message, history):
        """Record the submitted message as a pending turn.

        Returns ``""`` to clear the textbox, the updated history, and the
        history's HTML.
        """
        history.append([user_message, None])
        return "", history, format_history(history)

    def bot(history, character, key):
        """Stream the model's reply into the last (pending) history entry.

        After each streamed chunk, yields the updated history and its HTML so
        the transcript updates live. The parameter is named ``key`` so it does
        not shadow the ``api_key`` component above.
        """
        user_message = history[-1][0]
        for partial in predict(user_message, history[:-1], character, key):
            history[-1][1] = partial
            yield history, format_history(history)

    msg.submit(user, [msg, chat_history], [msg, chat_history, chat_display]).then(
        bot, [chat_history, dropdown, api_key], [chat_history, chat_display]
    )
    # gr.HTML holds a string, so clear it with "" rather than a list.
    clear.click(lambda: ([], ""), None, [chat_history, chat_display], queue=False)
    # Switching character starts a fresh conversation.
    dropdown.change(lambda _: ([], ""), dropdown, [chat_history, chat_display])
# Enable the request queue (required for streaming generator handlers), then
# start the server; queue() returns the Blocks instance, so the calls chain.
demo.queue().launch(max_threads=20)