File size: 8,819 Bytes
6cba6d8
 
 
 
4c8718d
db29093
73c5116
6cba6d8
 
 
db29093
73c5116
 
 
 
 
 
 
 
 
db29093
 
 
 
 
 
 
6cba6d8
 
 
 
 
db29093
6cba6d8
 
 
db29093
6cba6d8
 
 
 
 
 
 
 
 
db29093
6cba6d8
 
 
 
 
 
73c5116
 
6cba6d8
 
 
 
 
 
db29093
 
 
 
 
 
6cba6d8
 
 
4c8718d
6cba6d8
 
 
 
 
ce62d0f
6cba6d8
ce62d0f
6cba6d8
 
 
4c8718d
6cba6d8
 
 
 
 
 
 
 
ce62d0f
6cba6d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ce62d0f
6cba6d8
 
 
 
 
 
 
 
 
 
73c5116
6cba6d8
 
 
 
 
 
 
 
 
db29093
6cba6d8
 
 
 
 
 
 
 
 
db29093
6cba6d8
 
 
 
 
 
 
 
db29093
6cba6d8
 
 
 
 
 
db51483
6cba6d8
db51483
6cba6d8
db51483
6cba6d8
 
 
 
db29093
 
6cba6d8
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
import os
from dotenv import load_dotenv
import gradio as gr
from gradio.components import Textbox, Button, Slider, Checkbox
from AinaTheme import theme
from huggingface_hub import InferenceClient
from urllib.error import HTTPError

# Run before any os.environ.get(...) call below so the endpoint URL, token and
# UI limits can be supplied through the environment (presumably from a local
# .env file -- confirm against the python-dotenv documentation).
load_dotenv()

def generate(prompt, model_parameters):
    """Run text generation for *prompt* on the configured inference endpoint.

    Args:
        prompt: Text sent to the model.
        model_parameters: Keyword arguments forwarded unchanged to
            ``client.text_generation`` (e.g. ``max_new_tokens``, ``top_k``).

    Returns:
        The value returned by ``client.text_generation`` (called with
        ``return_full_text=True``), or ``None`` when the request fails. A
        failure is reported to the user through ``gr.Warning`` instead of
        being raised.
    """
    try:
        return client.text_generation(prompt, **model_parameters, return_full_text=True)
    except Exception as err:
        # The HTTP status may sit on the error itself (urllib-style ``code``)
        # or on an attached response object (requests-style ``status_code``),
        # depending on which HTTP stack raised it -- look in both places.
        status = getattr(err, "code", None)
        if status is None:
            status = getattr(getattr(err, "response", None), "status_code", None)

        if status == 400:
            gr.Warning("The inference endpoint is only available Monday through Friday, from 08:00 to 20:00 CET.")
        else:
            # Any other failure gets the generic notice so the user is never
            # left without feedback.
            gr.Warning('Inference endpoint is not available right now. Please try again later.')
        return None


# Client for the hosted text-generation endpoint; URL and token are read from
# the environment.
client = InferenceClient(
    os.environ.get("HF_INFERENCE_ENDPOINT_URL"),
    token=os.environ.get("HF_INFERENCE_ENDPOINT_TOKEN")
)

# Initial value of the "Max tokens" slider.
MAX_NEW_TOKENS = int(os.environ.get("MAX_NEW_TOKENS", default=100))
# Longest input (in characters) for which the submit button stays enabled.
MAX_INPUT_CHARACTERS = int(os.environ.get("MAX_INPUT_CHARACTERS", default=100))
# Whether the "Model parameters" accordion is visible. The fallback must be the
# *string* "True": a boolean default never equals the string it is compared
# with, which would leave the flag False whenever the variable is unset.
# The comparison ignores case and surrounding whitespace.
SHOW_MODEL_PARAMETERS_IN_UI = (
    os.environ.get("SHOW_MODEL_PARAMETERS_IN_UI", default="True").strip().lower() == "true"
)


def submit_input(input_, max_new_tokens, repetition_penalty, top_k, top_p, do_sample, temperature):
    """Validate the prompt and send it, with the sampling settings, to ``generate``.

    Args:
        input_: Prompt text. ``None`` is treated like an empty prompt.
        max_new_tokens: Forwarded as the ``max_new_tokens`` model parameter.
        repetition_penalty: Forwarded as the ``repetition_penalty`` model parameter.
        top_k: Forwarded as the ``top_k`` model parameter.
        top_p: Forwarded as the ``top_p`` model parameter.
        do_sample: Forwarded as the ``do_sample`` model parameter.
        temperature: Forwarded as the ``temperature`` model parameter.

    Returns:
        Whatever ``generate`` returns, or ``None`` (after showing a warning)
        when the prompt is missing or contains only whitespace.
    """
    # This handler is also exposed as the "get-results" API endpoint, so the
    # prompt can arrive as None instead of a string; guard before .strip().
    if input_ is None or not input_.strip():
        gr.Warning('Not possible to inference an empty input')
        return None

    model_parameters = {
        "max_new_tokens": max_new_tokens,
        "repetition_penalty": repetition_penalty,
        "top_k": top_k,
        "top_p": top_p,
        "do_sample": do_sample,
        "temperature": temperature,
    }

    return generate(input_, model_parameters)
    
def change_interactive(text):
    """Compute the interactive state of the buttons for the current input text.

    Args:
        text: Current content of the input textbox; ``None`` counts as empty.

    Returns:
        A pair of ``gr.update`` objects, for the clear button and the submit
        button respectively. The clear button is always interactive. The
        submit button is interactive only when the whitespace-stripped text is
        non-empty and no longer than ``MAX_INPUT_CHARACTERS``.
    """
    # Measure the stripped text for both bounds: whitespace-only input is
    # rejected by submit_input, so the button should not invite it.
    stripped_length = len((text or "").strip())
    can_submit = 0 < stripped_length <= MAX_INPUT_CHARACTERS
    return gr.update(interactive=True), gr.update(interactive=can_submit)
    
def clear():
    """Return the values that reset the UI to its initial state.

    The result is an 8-tuple: ``None`` for each of the first two outputs,
    followed by one ``gr.update`` per model parameter carrying its default
    value.
    """
    # Defaults in output order: max new tokens, repetition penalty, top-k,
    # top-p, do-sample, temperature.
    parameter_defaults = [MAX_NEW_TOKENS, 1.2, 50, 0.95, True, 0.5]
    resets = [gr.update(value=default) for default in parameter_defaults]
    return (None, None, *resets)
    
def gradio_app():
    """Build the Gradio Blocks interface, wire its events, and launch it.

    Components are created in this order: a banner row (image plus Markdown
    title), a row with two panel columns (prompt input, character counter and
    model-parameter controls in the first; read-only output and the
    Clear/Submit buttons in the second), and a row of example prompts.
    Afterwards the input-change and button-click handlers are registered and
    ``demo.launch`` is called, so this function starts the app and returns
    nothing.
    """
    with gr.Blocks(theme=theme) as demo:
        # Banner: logo image next to the title text.
        with gr.Row():
            # NOTE(review): scale is a float here (and 0.5 for the examples
            # column further down); Gradio documents `scale` as an integer --
            # confirm the intended layout.
            with gr.Column(scale=0.1):
                gr.Image("ginesta_small.jpg", elem_id="flor-banner", scale=1, height=256, width=256, show_label=False, show_download_button = False, show_share_button = False)
            with gr.Column():
                gr.Markdown(
                    """# AIMestre
                   
                    Basat en el model [Flor](https://huggingface.co/projecte-aina/FLOR-6.3B) del projecte AINA.
                 
                    """
                )
        with gr.Row(equal_height=True):
            # First column: prompt, character counter and model parameters.
            with gr.Column(variant="panel"):
                # Hidden, read-only textbox whose only use is to hand
                # MAX_INPUT_CHARACTERS to the JS counter callback registered
                # on input_.change below.
                placeholder_max_token = Textbox(
                    visible=False,
                    interactive=False,
                    value= MAX_INPUT_CHARACTERS
                )
                input_ = Textbox(
                    lines=11,
                    label="Posa aquí el teu escrit en català.",
                    placeholder="e.g. El mercat del barri és fantàstic hi pots trobar."
                )
                # Two spans rewritten from JS: the over-limit message
                # ("countertext") and the live length counter ("inputlenght").
                # The ids are looked up verbatim by the JS callbacks below, so
                # their spelling must stay in sync with those callbacks.
                with gr.Row(variant="panel", equal_height=True):
                    gr.HTML("""<span id="countertext" style="display: flex; justify-content: start; color:#ef4444; font-weight: bold;"></span>""")
                    gr.HTML(f"""<span id="counter" style="display: flex; justify-content: end;"> <span id="inputlenght">0</span>&nbsp;/&nbsp;{MAX_INPUT_CHARACTERS}</span>""")

                with gr.Row(variant="panel"):
                    # Sampling controls; the accordion's visibility is driven
                    # by SHOW_MODEL_PARAMETERS_IN_UI. The initial values here
                    # are the same ones clear() restores.
                    with gr.Accordion("Model parameters", open=False, visible=SHOW_MODEL_PARAMETERS_IN_UI):
                        max_new_tokens = Slider(
                            minimum=1,
                            maximum=200,
                            step=1,
                            value=MAX_NEW_TOKENS,
                            label="Max tokens"
                        )
                        repetition_penalty = Slider(
                            minimum=0.1,
                            maximum=10,
                            step=0.1,
                            value=1.2,
                            label="Repetition penalty"
                        )
                        top_k = Slider(
                            minimum=1,
                            maximum=100,
                            step=1,
                            value=50,
                            label="Top k"
                        )
                        top_p = Slider(
                            minimum=0.01,
                            maximum=0.99,
                            value=0.95,
                            label="Top p"
                        )  
                        do_sample = Checkbox(
                            value=True, 
                            label="Do sample"
                        )
                        temperature = Slider(
                            minimum=0, 
                            maximum=1,
                            value=0.5,
                            label="Temperature"
                        )
            # Second column: read-only model output and the action buttons.
            with gr.Column(variant="panel"):
                output = Textbox(
                    lines=11,
                    label="El mestre diu...",
                    interactive=False,
                    show_copy_button=True
                )
                with gr.Row(variant="panel"):
                    clear_btn = Button(
                        "Clear", 
                    )
                    # Created disabled; change_interactive updates its state
                    # whenever the input text changes.
                    submit_btn = Button(
                        "Submit", 
                        variant="primary",
                        interactive=False
                    )

        # Example prompts, grouped by prompting style.
        # NOTE(review): each example row holds a single value (the prompt)
        # although seven input components are listed -- presumably only the
        # prompt box is filled; confirm against the gr.Examples documentation.
        with gr.Row():
            with gr.Column(scale=0.5):
                gr.Examples(
                    label="Short prompts:",
                    examples=[
                        ["""La capital de Suècia"""],
                    ],
                    inputs=[input_, max_new_tokens, repetition_penalty, top_k, top_p, do_sample, temperature],
                    outputs=output,
                    fn=submit_input,
                )

                gr.Examples(
                    label="Zero-shot prompts",
                    examples=[
                        ["Tradueix del Castellà al Català la següent frase: \"Eso es pan comido.\" \nTraducció:"],
                    ],
                    inputs=[input_, max_new_tokens, repetition_penalty, top_k, top_p, do_sample, temperature],
                    outputs=output,
                    fn=submit_input,
                )
                gr.Examples(
                    label="Few-Shot prompts:",
                    examples=[
                        ["""Oració: Els sons melòdics produeixen una sensació de calma i benestar en l'individu. \nParàfrasi: La música és molt relaxant i reconfortant.\n----\nOració: L'animal domèstic mostra una gran alegria i satisfacció. \nParàfrasi: El gos és molt feliç. \n----\nOració: El vehicle es va trencar i vaig haver de contactar amb el servei de remolc perquè el transportés. \nParàfrasi: El cotxe es va trencar i vaig haver de trucar la grua. \n----\nOració: El professor va explicar els conceptes de manera clara i concisa. \nParàfrasi:"""],
                    ],
                    inputs=[input_, max_new_tokens, repetition_penalty, top_k, top_p, do_sample, temperature],
                    outputs=output,
                    fn=submit_input,
                )

            
           
        # Three handlers run on every input change:
        # 1) Python side: recompute which buttons are interactive.
        input_.change(fn=change_interactive, inputs=[input_], outputs=[clear_btn, submit_btn], api_name=False)
      
        # 2) JS only (fn=None): show or clear the over-limit message.
        input_.change(fn=None, inputs=[input_], api_name=False, js=f"""(i) => document.getElementById('countertext').textContent =  i.length > {MAX_INPUT_CHARACTERS} && 'Max length {MAX_INPUT_CHARACTERS} characters. ' || '' """)

        # 3) JS only (fn=None): refresh the length counter and colour it when
        #    the length exceeds the limit passed in via placeholder_max_token.
        input_.change(fn=None, inputs=[input_, placeholder_max_token], api_name=False, js="""(i, m) => {
            document.getElementById('inputlenght').textContent = i.length + '  '
            document.getElementById('inputlenght').style.color =  (i.length > m) ? "#ef4444" : "";
        }""")
    
        # Clear resets input, output and every parameter control; it is not
        # exposed through the API (api_name=False).
        clear_btn.click(fn=clear, inputs=[], outputs=[input_, output, max_new_tokens, repetition_penalty, top_k, top_p, do_sample, temperature], queue=False, api_name=False)
        # Submit runs the model and is registered under the API name "get-results".
        submit_btn.click(fn=submit_input, inputs=[input_, max_new_tokens, repetition_penalty, top_k, top_p, do_sample, temperature], outputs=[output], api_name="get-results")

        # Launched while still inside the Blocks context.
        demo.launch(show_api=True)

# Build and launch the UI only when this file is executed as a script.
if __name__ == "__main__":
    gradio_app()