# BlooMeteo / app.py
# Uploaded by crodri ("Upload 4 files"), commit 6fb54a3, 4.07 kB.
import os
from dotenv import load_dotenv
import gradio as gr
from gradio.components import Textbox, Button, Slider
from AinaTheme import AinaGradioTheme
from meteocat_appv4 import generate
# Load variables from a local .env file into os.environ before reading config.
load_dotenv()


def _env_flag(name, default):
    """Return the boolean value of environment variable *name*.

    ``os.environ`` only holds strings, so ``os.environ.get(name, True)`` would
    return the non-empty (and therefore truthy) string ``"False"`` when the
    variable is set to false. This helper treats ``"0"``, ``"false"``,
    ``"no"``, ``"off"`` and the empty string (case-insensitive, surrounding
    whitespace ignored) as ``False`` and any other value as ``True``.

    Args:
        name: Environment variable to read.
        default: Value returned when the variable is not set at all.
    """
    raw = os.environ.get(name)
    if raw is None:
        return default
    return raw.strip().lower() not in {"", "0", "false", "no", "off"}


# Whether the "Model parameters" accordion (sampling sliders) is shown.
SHOW_MODEL_PARAMETERS_IN_UI = _env_flag("SHOW_MODEL_PARAMETERS_IN_UI", default=True)
# Despite its name, this is used as the maximum number of input *characters*
# accepted by the UI (see the input counter and change_interactive).
MAX_NEW_TOKENS = int(os.environ.get("MAX_NEW_TOKENS", default=200))
def submit_input(input_, repetition_penalty, temperature):
    """Run ``generate`` on the user's question and fan out its results.

    Args:
        input_: Text from the input textbox.
        repetition_penalty: Value of the repetition-penalty slider.
        temperature: Value of the temperature slider.

    Returns:
        A 3-tuple for (model answer, model context, CCMA text). When
        ``generate`` returns ``None``, a Gradio warning toast is shown and
        three empty strings are returned. The Catalan message says, roughly,
        that the place or date (Thursday to Monday) may not have been found
        and that only weather questions about a specific locality can be
        answered.
    """
    result = generate(input_, repetition_penalty, temperature)
    if result is not None:
        # Debug output to stdout: the whole result, then the three fields.
        print(result)
        fields = (result["model_answer"], result["context"], result["ccma_response"])
        print(*fields)
        return fields

    gr.Warning(
        "\n"
        "És possible que no hagi trobat el lloc o la data (de dijous a dilluns).\n"
        "Només puc respondre a preguntes sobre el temps a alguna localitat en concret.\n"
    )
    return "", "", ""
def change_interactive(text):
    """Enable/disable the Clear and Submit buttons based on the input text.

    Args:
        text: Current value of the input textbox (may be ``None``).

    Returns:
        A pair of ``gr.update`` objects for (clear_btn, submit_btn):
        - stripped input longer than ``MAX_NEW_TOKENS`` characters:
          Clear enabled, Submit disabled;
        - other non-blank input: both enabled;
        - empty / whitespace-only / ``None`` input: both disabled.
    """
    # Guard against a None value (e.g. if the textbox is reset to None);
    # calling .strip() on it would raise AttributeError.
    stripped = (text or "").strip()
    if len(stripped) > MAX_NEW_TOKENS:
        return gr.update(interactive=True), gr.update(interactive=False)
    if stripped:
        return gr.update(interactive=True), gr.update(interactive=True)
    return gr.update(interactive=False), gr.update(interactive=False)
def clean():
    """Reset the UI to its initial state.

    Returns values for, in order: the input textbox, the three output
    textboxes (answer, context, CCMA), the repetition-penalty slider and the
    temperature slider.
    """
    # Empty strings rather than None, so the input's change handlers (which
    # call .strip() in Python and .length in JS) always receive a string.
    # Both sliders are created with value=0.85, so restore that instead of 1.0.
    return (
        "",
        "",
        "",
        "",
        gr.Slider.update(value=0.85),
        gr.Slider.update(value=0.85),
    )
# Build the UI: input + controls on the left, model outputs on the right.
with gr.Blocks(**AinaGradioTheme().get_kwargs()) as demo:
    with gr.Row():
        with gr.Column():
            input_ = Textbox(
                lines=11,
                label="Input",
                placeholder="e.g. Prompt example."
            )
            # Character counter; its text is updated client-side by the _js
            # change handler below.
            characters_counter = gr.Markdown(f"""<span id=counter> 0 / {MAX_NEW_TOKENS} </span>""")
            with gr.Row():
                # Both buttons start disabled; change_interactive enables them.
                clear_btn = Button(
                    "Clear",
                    interactive=False
                )
                submit_btn = Button(
                    "Submit",
                    variant="primary",
                    interactive=False
                )
            with gr.Accordion("Model parameters", open=False, visible=SHOW_MODEL_PARAMETERS_IN_UI):
                repetition_penalty = Slider(
                    minimum=0.1,
                    maximum=10.0,
                    step=0.1,
                    value=0.85,
                    label="Repetition penalty"
                )
                # NOTE(review): minimum=0.0 lets users pick temperature 0; confirm
                # `generate` handles that (some samplers require temperature > 0).
                temperature = Slider(
                    minimum=0.0,
                    maximum=2.0,
                    value=0.85,
                    label="Temperature"
                )
        with gr.Column():
            output_answer = Textbox(
                lines=9,
                label="Model text",
                interactive=False,
                show_copy_button=True
            )
            output_context = Textbox(
                lines=9,
                label="Model context",
                interactive=False,
                show_copy_button=True
            )
            output_CCMA = Textbox(
                lines=9,
                label="CCMA text",
                interactive=False,
                show_copy_button=True
            )

    input_.change(fn=change_interactive, inputs=[input_], outputs=[clear_btn, submit_btn])
    # Client-side counter update. `(i || '')` tolerates a null value, and the
    # " / " separator matches the counter's initial text.
    input_.change(fn=None, inputs=input_, _js=f"(i) => document.getElementById('counter').textContent = (i || '').length + ' / ' + {MAX_NEW_TOKENS}")
    clear_btn.click(fn=clean, inputs=[], outputs=[input_, output_answer, output_context, output_CCMA, repetition_penalty, temperature], queue=False)
    submit_btn.click(fn=submit_input, inputs=[input_, repetition_penalty, temperature], outputs=[output_answer, output_context, output_CCMA])

# Process one request at a time and keep the HTTP API closed.
demo.queue(concurrency_count=1, api_open=False)
# The bind address can be overridden via GRADIO_SERVER_NAME (the variable
# Gradio itself reads); the previous hard-coded address remains the default.
# NOTE(review): share=True also publishes a public *.gradio.live link, and
# show_api=True advertises an API that api_open=False closes; confirm both.
demo.launch(
    show_api=True,
    share=True,
    debug=True,
    server_name=os.environ.get("GRADIO_SERVER_NAME", "84.88.187.178"),
)