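"""corrector/app.py

"AIMestre": a Gradio front end for the FLOR-6.3B model from the AINA project
(https://huggingface.co/projecte-aina/FLOR-6.3B), queried through a Hugging Face
Inference Endpoint.
"""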
import os
from dotenv import load_dotenv
import gradio as gr
from gradio.components import Textbox, Button, Slider, Checkbox
from AinaTheme import theme
from huggingface_hub import InferenceClient
from urllib.error import HTTPError
load_dotenv()

def generate(prompt, model_parameters):
    # Query the inference endpoint; availability problems are surfaced to the UI
    # as Gradio warnings instead of raising.
    try:
        output = client.text_generation(prompt, **model_parameters, return_full_text=True)
        return output
    except HTTPError as err:
        if err.code == 400:
            gr.Warning("The inference endpoint is only available Monday through Friday, from 08:00 to 20:00 CET.")
    except Exception:
        gr.Warning("Inference endpoint is not available right now. Please try again later.")

client = InferenceClient(
    os.environ.get("HF_INFERENCE_ENDPOINT_URL"),
    token=os.environ.get("HF_INFERENCE_ENDPOINT_TOKEN")
)
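# For reference, generate() above expands the parameter dict into a call along
# these lines (a sketch; "Hola, " is an arbitrary example prompt, not from the app):
#
#   client.text_generation(
#       "Hola, ",
#       max_new_tokens=100, repetition_penalty=1.2, top_k=50, top_p=0.95,
#       do_sample=True, temperature=0.5, return_full_text=True,
#   )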
MAX_NEW_TOKENS = int(os.environ.get("MAX_NEW_TOKENS", default=100))
MAX_INPUT_CHARACTERS = int(os.environ.get("MAX_INPUT_CHARACTERS", default=100))
SHOW_MODEL_PARAMETERS_IN_UI = os.environ.get("SHOW_MODEL_PARAMETERS_IN_UI", default="True") == "True"
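# A minimal sketch of the .env file this app expects; the variable names come from
# the code above, but the values are placeholders, not the original deployment's:
#
#   HF_INFERENCE_ENDPOINT_URL=https://<your-endpoint>.endpoints.huggingface.cloud
#   HF_INFERENCE_ENDPOINT_TOKEN=hf_xxxxxxxxxxxxxxxx
#   MAX_NEW_TOKENS=100
#   MAX_INPUT_CHARACTERS=100
#   SHOW_MODEL_PARAMETERS_IN_UI=True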

def submit_input(input_, max_new_tokens, repetition_penalty, top_k, top_p, do_sample, temperature):
    # Validate the input, collect the generation parameters and run the model.
    if input_.strip() == "":
        gr.Warning("Cannot run inference on an empty input.")
        return None

    model_parameters = {
        "max_new_tokens": max_new_tokens,
        "repetition_penalty": repetition_penalty,
        "top_k": top_k,
        "top_p": top_p,
        "do_sample": do_sample,
        "temperature": temperature
    }

    output = generate(input_, model_parameters)
    return output


def change_interactive(text):
    # Keep the clear button always enabled; disable submit while the input is
    # empty or longer than MAX_INPUT_CHARACTERS.
    if len(text.strip()) > MAX_INPUT_CHARACTERS:
        return gr.update(interactive=True), gr.update(interactive=False)
    if len(text) == 0:
        return gr.update(interactive=True), gr.update(interactive=False)
    return gr.update(interactive=True), gr.update(interactive=True)


def clear():
    # Reset the input, the output and every model parameter to its default value.
    return (
        None,
        None,
        gr.update(value=MAX_NEW_TOKENS),
        gr.update(value=1.2),
        gr.update(value=50),
        gr.update(value=0.95),
        gr.update(value=True),
        gr.update(value=0.5),
    )


def gradio_app():
    with gr.Blocks(theme=theme) as demo:
        with gr.Row():
            with gr.Column(scale=0.1):
                gr.Image("ginesta_small.jpg", elem_id="flor-banner", scale=1, height=256, width=256, show_label=False, show_download_button=False, show_share_button=False)
            with gr.Column():
                gr.Markdown(
                    """# AIMestre
Basat en el model [Flor](https://huggingface.co/projecte-aina/FLOR-6.3B) del projecte AINA.
"""
                )
        with gr.Row(equal_height=True):
            with gr.Column(variant="panel"):
                placeholder_max_token = Textbox(
                    visible=False,
                    interactive=False,
                    value=MAX_INPUT_CHARACTERS
                )
                input_ = Textbox(
                    lines=11,
                    label="Posa aquí el teu escrit en català.",
                    placeholder="e.g. El mercat del barri és fantàstic hi pots trobar."
                )
                with gr.Row(variant="panel", equal_height=True):
                    gr.HTML("""<span id="countertext" style="display: flex; justify-content: start; color:#ef4444; font-weight: bold;"></span>""")
                    gr.HTML(f"""<span id="counter" style="display: flex; justify-content: end;"> <span id="inputlength">0</span>&nbsp;/&nbsp;{MAX_INPUT_CHARACTERS}</span>""")
                with gr.Row(variant="panel"):
                    with gr.Accordion("Model parameters", open=False, visible=SHOW_MODEL_PARAMETERS_IN_UI):
                        max_new_tokens = Slider(
                            minimum=1,
                            maximum=200,
                            step=1,
                            value=MAX_NEW_TOKENS,
                            label="Max tokens"
                        )
                        repetition_penalty = Slider(
                            minimum=0.1,
                            maximum=10,
                            step=0.1,
                            value=1.2,
                            label="Repetition penalty"
                        )
                        top_k = Slider(
                            minimum=1,
                            maximum=100,
                            step=1,
                            value=50,
                            label="Top k"
                        )
                        top_p = Slider(
                            minimum=0.01,
                            maximum=0.99,
                            value=0.95,
                            label="Top p"
                        )
                        do_sample = Checkbox(
                            value=True,
                            label="Do sample"
                        )
                        temperature = Slider(
                            minimum=0,
                            maximum=1,
                            value=0.5,
                            label="Temperature"
                        )
with gr.Column(variant="panel"):
output = Textbox(
lines=11,
label="El mestre diu...",
interactive=False,
show_copy_button=True
)
with gr.Row(variant="panel"):
clear_btn = Button(
"Clear",
)
submit_btn = Button(
"Submit",
variant="primary",
interactive=False
)
        with gr.Row():
            with gr.Column(scale=0.5):
                gr.Examples(
                    label="Short prompts:",
                    examples=[
                        ["""La capital de Suècia"""],
                    ],
                    inputs=[input_, max_new_tokens, repetition_penalty, top_k, top_p, do_sample, temperature],
                    outputs=output,
                    fn=submit_input,
                )
                gr.Examples(
                    label="Zero-shot prompts",
                    examples=[
                        ["Tradueix del Castellà al Català la següent frase: \"Eso es pan comido.\" \nTraducció:"],
                    ],
                    inputs=[input_, max_new_tokens, repetition_penalty, top_k, top_p, do_sample, temperature],
                    outputs=output,
                    fn=submit_input,
                )
                gr.Examples(
                    label="Few-Shot prompts:",
                    examples=[
                        ["""Oració: Els sons melòdics produeixen una sensació de calma i benestar en l'individu. \nParàfrasi: La música és molt relaxant i reconfortant.\n----\nOració: L'animal domèstic mostra una gran alegria i satisfacció. \nParàfrasi: El gos és molt feliç. \n----\nOració: El vehicle es va trencar i vaig haver de contactar amb el servei de remolc perquè el transportés. \nParàfrasi: El cotxe es va trencar i vaig haver de trucar la grua. \n----\nOració: El professor va explicar els conceptes de manera clara i concisa. \nParàfrasi:"""],
                    ],
                    inputs=[input_, max_new_tokens, repetition_penalty, top_k, top_p, do_sample, temperature],
                    outputs=output,
                    fn=submit_input,
                )
        input_.change(fn=change_interactive, inputs=[input_], outputs=[clear_btn, submit_btn], api_name=False)
        input_.change(fn=None, inputs=[input_], api_name=False, js=f"""(i) => document.getElementById('countertext').textContent = i.length > {MAX_INPUT_CHARACTERS} && 'Max length {MAX_INPUT_CHARACTERS} characters. ' || '' """)
        input_.change(fn=None, inputs=[input_, placeholder_max_token], api_name=False, js="""(i, m) => {
            document.getElementById('inputlength').textContent = i.length + ' '
            document.getElementById('inputlength').style.color = (i.length > m) ? "#ef4444" : "";
        }""")

        clear_btn.click(fn=clear, inputs=[], outputs=[input_, output, max_new_tokens, repetition_penalty, top_k, top_p, do_sample, temperature], queue=False, api_name=False)
        submit_btn.click(fn=submit_input, inputs=[input_, max_new_tokens, repetition_penalty, top_k, top_p, do_sample, temperature], outputs=[output], api_name="get-results")

    demo.launch(show_api=True)


if __name__ == "__main__":
    gradio_app()
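# A sketch of how the "get-results" endpoint registered above could be called
# programmatically with gradio_client. The URL is a placeholder, the prompt is the
# UI's placeholder text, and the parameter values mirror the UI defaults:
#
#   from gradio_client import Client
#
#   api = Client("http://localhost:7860")  # or the public Space URL
#   result = api.predict(
#       "El mercat del barri és fantàstic hi pots trobar.",  # input_
#       100,    # max_new_tokens
#       1.2,    # repetition_penalty
#       50,     # top_k
#       0.95,   # top_p
#       True,   # do_sample
#       0.5,    # temperature
#       api_name="/get-results",
#   )
#   print(result)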