|
import os |
|
from dotenv import load_dotenv |
|
import gradio as gr |
|
from gradio.components import Textbox, Button, Slider, Checkbox |
|
from AinaTheme import theme |
|
from huggingface_hub import InferenceClient |
|
from urllib.error import HTTPError |
|
|
|
# Load settings via python-dotenv before the os.environ lookups further down
# in this module (presumably from a local .env file — confirm deployment setup).
load_dotenv()
|
|
|
def generate(prompt, model_parameters):
    """Send ``prompt`` to the inference endpoint and return the generated text.

    Args:
        prompt: Text forwarded to ``client.text_generation``.
        model_parameters: Mapping of generation keyword arguments; it is
            unpacked into the ``client.text_generation`` call.

    Returns:
        The value returned by ``client.text_generation`` (requested with
        ``return_full_text=True``), or ``None`` when the request failed.
        Failures are reported through ``gr.Warning`` instead of being raised.
    """
    try:
        return client.text_generation(prompt, **model_parameters, return_full_text=True)
    except HTTPError as err:
        # NOTE(review): this is urllib.error.HTTPError (see the file imports);
        # confirm it is the type the client actually raises, otherwise the
        # 400 branch below is never reached.
        if err.code == 400:
            gr.Warning("The inference endpoint is only available Monday through Friday, from 08:00 to 20:00 CET.")
        else:
            # Any other HTTP status used to be dropped silently; tell the user.
            gr.Warning('Inference endpoint is not available right now. Please try again later.')
    except Exception:
        # Deliberately broad so the UI never crashes on an endpoint failure,
        # but no longer a bare ``except`` that would also trap
        # KeyboardInterrupt / SystemExit.
        gr.Warning('Inference endpoint is not available right now. Please try again later.')
    return None
|
|
|
|
|
# Shared inference client; endpoint URL and access token are read from the
# environment (either may be None when the variable is not set).
client = InferenceClient(
    os.getenv("HF_INFERENCE_ENDPOINT_URL"),
    token=os.getenv("HF_INFERENCE_ENDPOINT_TOKEN"),
)
|
|
|
# Limits and UI switches, each overridable through an environment variable.
# Default value of the "Max tokens" slider (100 when the variable is unset).
MAX_NEW_TOKENS = int(os.environ.get("MAX_NEW_TOKENS", default=100))
# Longest input (in characters) for which the Submit button stays enabled.
MAX_INPUT_CHARACTERS = int(os.environ.get("MAX_INPUT_CHARACTERS", default=100))
# The fallback has to be the *string* "True": the previous boolean default
# never compared equal to "True", so an unset variable yielded False.
SHOW_MODEL_PARAMETERS_IN_UI = os.environ.get("SHOW_MODEL_PARAMETERS_IN_UI", default="True") == "True"
|
|
|
|
|
def submit_input(input_, max_new_tokens, repetition_penalty, top_k, top_p, do_sample, temperature):
    """Validate the prompt, bundle the generation settings and run inference.

    Args:
        input_: Prompt text; ``None`` or whitespace-only text is rejected.
        max_new_tokens: Forwarded as the ``max_new_tokens`` generation parameter.
        repetition_penalty: Forwarded as ``repetition_penalty``.
        top_k: Forwarded as ``top_k``.
        top_p: Forwarded as ``top_p``.
        do_sample: Forwarded as ``do_sample``.
        temperature: Forwarded as ``temperature``.

    Returns:
        The result of ``generate`` for the prompt, or ``None`` when the prompt
        is empty (a ``gr.Warning`` is issued in that case).
    """
    # Treat a missing value like blank text so it produces the warning rather
    # than an AttributeError on ``.strip()``.
    # NOTE(review): MAX_INPUT_CHARACTERS is only enforced by disabling the
    # Submit button; callers of the "get-results" API are not length-checked.
    if input_ is None or input_.strip() == "":
        gr.Warning('Not possible to inference an empty input')
        return None

    model_parameters = {
        "max_new_tokens": max_new_tokens,
        "repetition_penalty": repetition_penalty,
        "top_k": top_k,
        "top_p": top_p,
        "do_sample": do_sample,
        "temperature": temperature,
    }

    return generate(input_, model_parameters)
|
|
|
def change_interactive(text):
    """Compute the interactive state of the Clear and Submit buttons.

    Wired to the input textbox's change event with outputs
    ``[clear_btn, submit_btn]``.

    Args:
        text: Current content of the input textbox (``None`` is treated as
            empty).

    Returns:
        A pair of ``gr.update`` objects: the first always sets
        ``interactive=True``; the second sets ``interactive=True`` only when
        the stripped text is non-empty and no longer than
        ``MAX_INPUT_CHARACTERS``.
    """
    # Measure the stripped text for both bounds so whitespace-only input keeps
    # Submit disabled, matching the blank-input rejection in submit_input.
    stripped_length = len((text or "").strip())
    can_submit = 0 < stripped_length <= MAX_INPUT_CHARACTERS
    return gr.update(interactive=True), gr.update(interactive=can_submit)
|
|
|
def clear():
    """Return the values that reset the form.

    The tuple lines up with the outputs of the Clear button handler: ``None``
    for the input and output textboxes, followed by one ``gr.update`` per
    model-parameter control (max new tokens, repetition penalty, top k,
    top p, do sample, temperature) carrying its default value.
    """
    parameter_defaults = (MAX_NEW_TOKENS, 1.2, 50, 0.95, True, 0.5)
    return (None, None, *(gr.update(value=default) for default in parameter_defaults))
|
|
|
def gradio_app() -> None:
    """Build the Gradio Blocks interface, wire its events and launch it.

    The page consists of a header (image + Markdown title), an input panel
    with a character counter and a collapsible "Model parameters" accordion,
    an output panel with Clear/Submit buttons, and three groups of example
    prompts. Nothing is returned; the function ends by calling
    ``demo.launch``.
    """
    # NOTE(review): the nesting of the layout contexts below was reconstructed
    # from a source whose indentation had been lost — confirm against the
    # rendered page.
    with gr.Blocks(theme=theme) as demo:
        # Header row: banner image next to the title/description.
        with gr.Row():
            # NOTE(review): `scale` is fractional here (and 0.5 further down);
            # confirm the installed Gradio version accepts non-integer values.
            with gr.Column(scale=0.1):
                gr.Image("ginesta_small.jpg", elem_id="flor-banner", scale=1, height=256, width=256, show_label=False, show_download_button = False, show_share_button = False)
            with gr.Column():
                gr.Markdown(
                    """# AIMestre

Basat en el model [Flor](https://huggingface.co/projecte-aina/FLOR-6.3B) del projecte AINA.

"""
                )
        with gr.Row(equal_height=True):
            # Left panel: prompt input, character counter and model parameters.
            with gr.Column(variant="panel"):
                # Hidden textbox holding MAX_INPUT_CHARACTERS; it is passed as
                # the second input (`m`) of the JS counter handler below.
                placeholder_max_token = Textbox(
                    visible=False,
                    interactive=False,
                    value= MAX_INPUT_CHARACTERS
                )
                input_ = Textbox(
                    lines=11,
                    label="Posa aquí el teu escrit en català.",
                    placeholder="e.g. El mercat del barri és fantàstic hi pots trobar."
                )
                # Two HTML spans updated from JavaScript: "countertext" shows
                # the over-length message, "inputlenght" the current length.
                # The id spelling "inputlenght" must match the JS handler.
                with gr.Row(variant="panel", equal_height=True):
                    gr.HTML("""<span id="countertext" style="display: flex; justify-content: start; color:#ef4444; font-weight: bold;"></span>""")
                    gr.HTML(f"""<span id="counter" style="display: flex; justify-content: end;"> <span id="inputlenght">0</span> / {MAX_INPUT_CHARACTERS}</span>""")

                with gr.Row(variant="panel"):
                    # Accordion visibility is controlled by SHOW_MODEL_PARAMETERS_IN_UI.
                    # The initial values below are the same ones clear() restores.
                    with gr.Accordion("Model parameters", open=False, visible=SHOW_MODEL_PARAMETERS_IN_UI):
                        max_new_tokens = Slider(
                            minimum=1,
                            maximum=200,
                            step=1,
                            value=MAX_NEW_TOKENS,
                            label="Max tokens"
                        )
                        repetition_penalty = Slider(
                            minimum=0.1,
                            maximum=10,
                            step=0.1,
                            value=1.2,
                            label="Repetition penalty"
                        )
                        top_k = Slider(
                            minimum=1,
                            maximum=100,
                            step=1,
                            value=50,
                            label="Top k"
                        )
                        top_p = Slider(
                            minimum=0.01,
                            maximum=0.99,
                            value=0.95,
                            label="Top p"
                        )
                        do_sample = Checkbox(
                            value=True,
                            label="Do sample"
                        )
                        temperature = Slider(
                            minimum=0,
                            maximum=1,
                            value=0.5,
                            label="Temperature"
                        )
            # Right panel: read-only output plus the action buttons.
            with gr.Column(variant="panel"):
                output = Textbox(
                    lines=11,
                    label="El mestre diu...",
                    interactive=False,
                    show_copy_button=True
                )
                with gr.Row(variant="panel"):
                    clear_btn = Button(
                        "Clear",
                    )
                    # Starts disabled; change_interactive enables it once the
                    # input textbox changes to acceptable content.
                    submit_btn = Button(
                        "Submit",
                        variant="primary",
                        interactive=False
                    )

        # Example prompts. Each example row supplies a single value (the
        # prompt) although seven input components are listed.
        with gr.Row():
            with gr.Column(scale=0.5):
                gr.Examples(
                    label="Short prompts:",
                    examples=[
                        ["""La capital de Suècia"""],
                    ],
                    inputs=[input_, max_new_tokens, repetition_penalty, top_k, top_p, do_sample, temperature],
                    outputs=output,
                    fn=submit_input,
                )

                gr.Examples(
                    label="Zero-shot prompts",
                    examples=[
                        ["Tradueix del Castellà al Català la següent frase: \"Eso es pan comido.\" \nTraducció:"],
                    ],
                    inputs=[input_, max_new_tokens, repetition_penalty, top_k, top_p, do_sample, temperature],
                    outputs=output,
                    fn=submit_input,
                )
                gr.Examples(
                    label="Few-Shot prompts:",
                    examples=[
                        ["""Oració: Els sons melòdics produeixen una sensació de calma i benestar en l'individu. \nParàfrasi: La música és molt relaxant i reconfortant.\n----\nOració: L'animal domèstic mostra una gran alegria i satisfacció. \nParàfrasi: El gos és molt feliç. \n----\nOració: El vehicle es va trencar i vaig haver de contactar amb el servei de remolc perquè el transportés. \nParàfrasi: El cotxe es va trencar i vaig haver de trucar la grua. \n----\nOració: El professor va explicar els conceptes de manera clara i concisa. \nParàfrasi:"""],
                    ],
                    inputs=[input_, max_new_tokens, repetition_penalty, top_k, top_p, do_sample, temperature],
                    outputs=output,
                    fn=submit_input,
                )



        # Python-side handler: recompute the interactive state of both buttons
        # whenever the input text changes.
        input_.change(fn=change_interactive, inputs=[input_], outputs=[clear_btn, submit_btn], api_name=False)

        # JS-only handler (fn=None): show the max-length message in
        # #countertext when the input is longer than MAX_INPUT_CHARACTERS,
        # otherwise clear it.
        input_.change(fn=None, inputs=[input_], api_name=False, js=f"""(i) => document.getElementById('countertext').textContent = i.length > {MAX_INPUT_CHARACTERS} && 'Max length {MAX_INPUT_CHARACTERS} characters. ' || '' """)

        # JS-only handler: write the current length into #inputlenght and
        # colour it #ef4444 when it exceeds `m` (the hidden textbox value).
        input_.change(fn=None, inputs=[input_, placeholder_max_token], api_name=False, js="""(i, m) => {
document.getElementById('inputlenght').textContent = i.length + ' '
document.getElementById('inputlenght').style.color = (i.length > m) ? "#ef4444" : "";
}""")

        # Only the submit handler is given an API name ("get-results"); the
        # other listeners pass api_name=False.
        clear_btn.click(fn=clear, inputs=[], outputs=[input_, output, max_new_tokens, repetition_penalty, top_k, top_p, do_sample, temperature], queue=False, api_name=False)
        submit_btn.click(fn=submit_input, inputs=[input_, max_new_tokens, repetition_penalty, top_k, top_p, do_sample, temperature], outputs=[output], api_name="get-results")

    demo.launch(show_api=True)
|
|
|
# Build and launch the UI only when this file is executed as a script.
if __name__ == "__main__":
    gradio_app()