File size: 4,072 Bytes
6fb54a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import os
from dotenv import load_dotenv
import gradio as gr
from gradio.components import Textbox, Button, Slider
from AinaTheme import AinaGradioTheme

from meteocat_appv4 import generate

load_dotenv()


def env_flag(name, default=True):
    """Read environment variable *name* as a boolean.

    Args:
        name: Name of the environment variable.
        default: Value returned when the variable is not set at all.

    Returns:
        ``default`` when the variable is unset; otherwise ``False`` if its
        value (trimmed, case-insensitive) is one of ``""``, ``"0"``,
        ``"false"``, ``"no"`` or ``"off"``, and ``True`` for anything else.
    """
    raw = os.environ.get(name)
    if raw is None:
        return default
    return raw.strip().lower() not in {"", "0", "false", "no", "off"}


# Environment values are always strings, so a plain ``os.environ.get`` would
# turn "False" / "0" into a truthy string; parse it into a real bool instead.
SHOW_MODEL_PARAMETERS_IN_UI = env_flag("SHOW_MODEL_PARAMETERS_IN_UI", default=True)
# Integer limit read from the environment (200 when unset).
MAX_NEW_TOKENS = int(os.environ.get("MAX_NEW_TOKENS", default=200))

def submit_input(input_, repetition_penalty, temperature):
    """Call ``generate`` with the prompt and parameters and unpack its result.

    Args:
        input_: Prompt text taken from the input textbox.
        repetition_penalty: Value of the "Repetition penalty" slider.
        temperature: Value of the "Temperature" slider.

    Returns:
        A 3-tuple ``(model_answer, context, ccma_response)`` read from the
        mapping returned by ``generate``, or three empty strings when
        ``generate`` returns ``None`` (a Gradio warning is issued in that case).
    """
    result = generate(input_, repetition_penalty, temperature)

    if result is None:
        # Nothing to show: tell the user why and blank all three outputs.
        gr.Warning("""
                   És possible que no hagi trobat el lloc o la data (de dijous a dilluns). 
                   Només puc respondre a preguntes sobre el temps a alguna localitat en concret.
                   """)
        return "", "", ""

    # Log the raw result first, then the three fields that go to the UI.
    print(result)
    answer_fields = (
        result["model_answer"],
        result["context"],
        result["ccma_response"],
    )
    print(*answer_fields)
    return answer_fields

   
def change_interactive(text):
    """Decide whether the Clear and Submit buttons should be clickable.

    Args:
        text: Current content of the input textbox.

    Returns:
        Two ``gr.update`` objects, in order ``(clear_button, submit_button)``:
        * stripped text longer than ``MAX_NEW_TOKENS`` characters ->
          Clear enabled, Submit disabled;
        * non-empty stripped text within the limit -> both enabled;
        * empty / whitespace-only text -> both disabled.
    """
    stripped = text.strip()
    over_limit = len(stripped) > MAX_NEW_TOKENS
    has_content = stripped != ""

    # Clear is usable whenever there is something to clear (or the limit is hit);
    # Submit additionally requires the text to be within the limit.
    can_clear = over_limit or has_content
    can_submit = has_content and not over_limit
    return gr.update(interactive=can_clear), gr.update(interactive=can_submit)

def clean():
    """Build the reset values applied when the Clear button is clicked.

    Returns:
        A 6-tuple: four ``None`` values followed by two slider updates that
        set the slider value to ``1.0``.
    """
    # NOTE(review): the sliders in the UI are created with value=0.85, but the
    # reset below uses 1.0 -- confirm which of the two is the intended default.
    blank_fields = (None,) * 4
    slider_resets = tuple(gr.Slider.update(value=1.0) for _ in range(2))
    return blank_fields + slider_resets
       

# Build the two-column demo UI, wire its events, then queue and launch it.
with gr.Blocks(**AinaGradioTheme().get_kwargs()) as demo:
    with gr.Row():
        # Left column: prompt input, character counter, buttons, parameters.
        with gr.Column():
            input_ = Textbox(
                lines=11,
                label="Input",
                placeholder="e.g. Prompt example."
            )
            # The <span id=counter> is rewritten client-side on every input change.
            characters_counter = gr.Markdown(f"""<span id=counter> 0 / {MAX_NEW_TOKENS} </span>""")
            with gr.Row():
                # Both buttons start disabled; change_interactive toggles them.
                clear_btn = Button(
                    "Clear",
                    interactive=False
                )
                submit_btn = Button(
                    "Submit",
                    variant="primary",
                    interactive=False
                )
            with gr.Accordion("Model parameters", open=False, visible=SHOW_MODEL_PARAMETERS_IN_UI):
                repetition_penalty = Slider(
                    minimum=0.1,
                    maximum=10.0,
                    step=0.1,
                    value=0.85,
                    label="Repetition penalty"
                )
                temperature = Slider(
                    minimum=0.0,
                    maximum=2.0,
                    value=0.85,
                    label="Temperature"
                )

        # Right column: the three read-only outputs filled by submit_input.
        with gr.Column():
            output_answer = Textbox(
                lines=9,
                label="Model text",
                interactive=False,
                show_copy_button=True
            )
            output_context = Textbox(
                lines=9,
                label="Model context",
                interactive=False,
                show_copy_button=True
            )
            output_CCMA = Textbox(
                lines=9,
                label="CCMA text",
                interactive=False,
                show_copy_button=True
            )

    # Enable/disable the buttons as the prompt text changes.
    input_.change(fn=change_interactive, inputs=[input_], outputs=[clear_btn, submit_btn])

    # Client-side character counter. The separator is " / " so the text keeps
    # the same "N / MAX" format as the initial markup (it used to render "N /MAX").
    input_.change(
        fn=None,
        inputs=input_,
        _js=f"(i) => document.getElementById('counter').textContent = i.length + ' / ' + {MAX_NEW_TOKENS}",
    )

    # Clear resets the input, the three outputs and both sliders (6 outputs,
    # matching the 6 values returned by clean).
    clear_btn.click(fn=clean, inputs=[], outputs=[input_, output_answer, output_context, output_CCMA,  repetition_penalty, temperature], queue=False)
    submit_btn.click(fn=submit_input, inputs=[input_, repetition_penalty, temperature], outputs=[output_answer, output_context, output_CCMA])

    # Queue with a single concurrent worker and the queue API closed.
    demo.queue(concurrency_count=1, api_open=False)
    # NOTE(review): the bind address is hard-coded and share/debug are enabled;
    # confirm this is intended outside of the original deployment host.
    demo.launch(show_api=True, share=True, debug=True, server_name="84.88.187.178")