# Hugging Face Space app — author: Islam YAHIAOUI, commit 31e6eb8 ("Update UI")
import json
import gradio as gr
from huggingface_hub import InferenceClient
import os
import requests
from rag import run_rag
"""
For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
"""
client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
def chat(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Stream a chat completion for *message*, augmented with RAG context.

    Builds an OpenAI-style message list from the system prompt and prior
    turns, rewrites the latest user query through ``run_rag`` to inject
    retrieved context, then streams tokens from the shared ``client``.

    Args:
        message: Latest user input from the chat box.
        history: Prior (user, assistant) turns; empty strings are skipped.
        system_message: System prompt placed first in the conversation.
        max_tokens: Generation cap forwarded to the model.
        temperature: Sampling temperature forwarded to the model.
        top_p: Nucleus-sampling threshold forwarded to the model.

    Yields:
        The accumulated response text after each streamed token, so the
        UI can render the reply incrementally.
    """
    messages = [{"role": "system", "content": system_message}]
    for user_turn, assistant_turn in history:
        if user_turn:
            messages.append({"role": "user", "content": user_turn})
        if assistant_turn:
            messages.append({"role": "assistant", "content": assistant_turn})

    # Augment the raw query with retrieved context before sending it to the model.
    augmented_query = run_rag(message, history)
    messages.append({"role": "user", "content": augmented_query})

    response = ""
    # NOTE: loop variable renamed from `message` — the original shadowed the
    # function parameter.
    for chunk in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        token = chunk.choices[0].delta.content
        # The final streamed delta frequently carries no content (None);
        # the original str(token) appended a literal "None" to the reply.
        if token:
            response += token
        yield response
"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
chatbot = gr.Chatbot(
label="Retrieval Augmented Generation News & Finance",
# avatar_images=[None, BOT_AVATAR],
show_copy_button=True,
likeable=True,
layout="bubble")
# Base theme: two Google fonts with generic system fallbacks.
theme = gr.themes.Base(
    font=[
        gr.themes.GoogleFont('Libre Franklin'),
        gr.themes.GoogleFont('Public Sans'),
        'system-ui',
        'sans-serif',
    ],
)

# Canned prompts shown beneath the chat box; each entry is a one-message example.
EXAMPLES = [
    ["Tell me about the latest news in the world ?"],
    ["Tell me about the increase in the price of Bitcoin ?"],
    ["Tell me about the actual situation in Ukraine ?"],
    ["Tell me about current situation in palestine ?"],
]
# Generation controls surfaced as "additional inputs" on the chat interface.
max_new_tokens = gr.Slider(
    label="Max new tokens",
    minimum=1,
    maximum=2048,
    step=1,
    value=512,
    interactive=True,
)

temperature = gr.Slider(
    label="Temperature",
    info="Higher values will produce more diverse outputs.",
    minimum=0.1,
    maximum=0.9,
    step=0.1,
    value=0.6,
    visible=True,
    interactive=True,
)

top_p = gr.Slider(
    label="Top-p (nucleus sampling)",
    info="Higher values is equivalent to sampling more low-probability tokens.",
    minimum=0.1,
    maximum=1,
    step=0.05,
    value=0.9,
    visible=True,
    interactive=True,
)
# CSS for the inner chat page: avatar sizing plus a styled duplicate button.
MAIN_CSS = """.gradio-container .avatar-container {height: 40px width: 40px !important;} #duplicate-button {margin: auto; color: white; background: #f1a139; border-radius: 100vh; margin-top: 2px; margin-bottom: 2px;}"""

with gr.Blocks(fill_height=True, css=MAIN_CSS) as main:
    gr.ChatInterface(
        chat,
        chatbot=chatbot,
        title="Retrieval Augmented Generation (RAG) Chatbot",
        description="A chatbot that uses a RAG model to generate responses based on the input query.",
        examples=EXAMPLES,
        theme=theme,
        fill_height=True,
        multimodal=True,
        additional_inputs=[max_new_tokens, temperature, top_p],
    )

# Outer shell: hide the footer and keep chat textboxes non-resizable.
APP_CSS = "footer {visibility: hidden}textbox{resize:none}"

with gr.Blocks(theme=theme, css=APP_CSS, title="RAG") as demo:
    gr.TabbedInterface([main], tab_names=["Chatbot"])

demo.launch()