Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
import os | |
import gradio as gr | |
from gradio.components import Textbox, Button | |
from AinaTheme import theme | |
from urllib.error import HTTPError | |
from rag import RAG | |
from utils import setup | |
setup() | |
rag = RAG( | |
hf_token=os.getenv("HF_TOKEN"), | |
embeddings_model=os.getenv("EMBEDDINGS"), | |
model_name=os.getenv("MODEL"), | |
) | |
def generate(prompt): | |
try: | |
output = rag.get_response(prompt) | |
return output | |
except HTTPError as err: | |
if err.code == 400: | |
gr.Warning( | |
"The inference endpoint is only available Monday through Friday, from 08:00 to 20:00 CET." | |
) | |
except: | |
gr.Warning( | |
"Inference endpoint is not available right now. Please try again later." | |
) | |
def submit_input(input_): | |
if input_.strip() == "": | |
gr.Warning("Not possible to inference an empty input") | |
return None | |
output = generate(input_) | |
return output | |
def change_interactive(text): | |
if len(text) == 0: | |
return gr.update(interactive=True), gr.update(interactive=False) | |
return gr.update(interactive=True), gr.update(interactive=True) | |
def clear(): | |
return ( | |
None, | |
None, | |
) | |
def gradio_app(): | |
with gr.Blocks(theme=theme) as demo: | |
with gr.Row(): | |
with gr.Column(scale=0.1): | |
gr.Image("rag_image.jpg", elem_id="flor-banner", scale=1, height=256, width=256, show_label=False, show_download_button = False, show_share_button = False) | |
with gr.Column(): | |
gr.Markdown( | |
"""# Demo de Retrieval-Augmented Generation per documents legals | |
🔍 **Retrieval-Augmented Generation** (RAG) és una tecnologia de IA que permet interrogar un repositori de documents amb preguntes | |
en llenguatge natural, i combina tècniques de recuperació d'informació avançades amb models generatius per redactar una resposta | |
fent servir només la informació existent en els documents del repositori. | |
🎯 **Objectiu:** Aquest és un primer demostrador amb la normativa vigent publicada al Diari Oficial de la Generalitat de Catalunya, en el | |
repositori del EADOP (Entitat Autònoma del Diari Oficial i de Publicacions). Aquesta primera versió explora prop de 2000 documents en català, | |
i genera la resposta fent servir el model Flor6.3b entrenat amb el dataset de QA generativa BSC-LT/RAG_Multilingual. | |
⚠️ **Advertencies**: Primera versió experimental. El contingut generat per aquest model no està supervisat i pot ser incorrecte. | |
Si us plau, tingueu-ho en compte quan exploreu aquest recurs. | |
""" | |
) | |
with gr.Row(equal_height=True): | |
with gr.Column(variant="panel"): | |
input_ = Textbox( | |
lines=11, | |
label="Input", | |
placeholder="Quina és la finalitat del Servei Meteorològic de Catalunya?", | |
# value = "Quina és la finalitat del Servei Meteorològic de Catalunya?" | |
) | |
with gr.Column(variant="panel"): | |
output = Textbox( | |
lines=11, label="Output", interactive=False, show_copy_button=True | |
) | |
with gr.Row(variant="panel"): | |
clear_btn = Button( | |
"Clear", | |
) | |
submit_btn = Button("Submit", variant="primary", interactive=False) | |
input_.change( | |
fn=change_interactive, | |
inputs=[input_], | |
outputs=[clear_btn, submit_btn], | |
api_name=False, | |
) | |
input_.change( | |
fn=None, | |
inputs=[input_], | |
api_name=False, | |
js="""(i, m) => { | |
document.getElementById('inputlenght').textContent = i.length + ' ' | |
document.getElementById('inputlenght').style.color = (i.length > m) ? "#ef4444" : ""; | |
}""", | |
) | |
clear_btn.click( | |
fn=clear, inputs=[], outputs=[input_, output], queue=False, api_name=False | |
) | |
submit_btn.click( | |
fn=submit_input, inputs=[input_], outputs=[output], api_name="get-results" | |
) | |
with gr.Row(): | |
with gr.Column(scale=0.5): | |
gr.Examples( | |
label="Short prompts:", | |
examples=[ | |
[""" Què diu el decret sobre la senyalització de les begudes alcohòliques i el tabac a Catalunya? """], | |
], | |
inputs=input_, | |
outputs=output, | |
fn=submit_input, | |
) | |
gr.Examples( | |
label="Short prompts:", | |
examples=[ | |
[""" Quina és la finalitat del Servei Meterològic de Catalunya ? """], | |
], | |
inputs=input_, | |
outputs=output, | |
fn=submit_input, | |
) | |
demo.launch(show_api=True) | |
if __name__ == "__main__": | |
gradio_app() |