crodri commited on
Commit
6fb54a3
1 Parent(s): eb7d984

Upload 4 files

Browse files
Files changed (2) hide show
  1. app.py +118 -0
  2. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from dotenv import load_dotenv
3
+ import gradio as gr
4
+ from gradio.components import Textbox, Button, Slider
5
+ from AinaTheme import AinaGradioTheme
6
+
7
+ from meteocat_appv4 import generate
8
+
9
+ load_dotenv()
10
+
11
+
12
+ SHOW_MODEL_PARAMETERS_IN_UI = os.environ.get("SHOW_MODEL_PARAMETERS_IN_UI", default=True)
13
+ MAX_NEW_TOKENS = int(os.environ.get("MAX_NEW_TOKENS", default=200))
14
+
15
+ def submit_input(input_, repetition_penalty, temperature):
16
+ outputs = generate(input_, repetition_penalty, temperature)
17
+ if outputs is None:
18
+ gr.Warning("""
19
+ És possible que no hagi trobat el lloc o la data (de dijous a dilluns).
20
+ Només puc respondre a preguntes sobre el temps a alguna localitat en concret.
21
+ """)
22
+ return "", "", ""
23
+
24
+ print(outputs)
25
+ print(outputs["model_answer"], outputs["context"], outputs["ccma_response"])
26
+ return outputs["model_answer"], outputs["context"], outputs["ccma_response"]
27
+
28
+
29
+ def change_interactive(text):
30
+ input_state = text
31
+ intput_length = len(input_state.strip())
32
+ if intput_length > MAX_NEW_TOKENS :
33
+ return gr.update(interactive = True), gr.update(interactive = False)
34
+ elif input_state.strip() != "":
35
+ return gr.update(interactive = True), gr.update(interactive = True)
36
+ else:
37
+ return gr.update(interactive = False), gr.update(interactive = False)
38
+
39
+ def clean():
40
+ return (
41
+ None,
42
+ None,
43
+ None,
44
+ None,
45
+ gr.Slider.update(value=1.0),
46
+ gr.Slider.update(value=1.0),
47
+ )
48
+
49
+
50
+ with gr.Blocks(**AinaGradioTheme().get_kwargs()) as demo:
51
+ with gr.Row():
52
+ with gr.Column():
53
+ input_ = Textbox(
54
+ lines=11,
55
+ label="Input",
56
+ placeholder="e.g. Prompt example."
57
+ )
58
+ characters_counter = gr.Markdown(f"""<span id=counter> 0 / {MAX_NEW_TOKENS} </span>""")
59
+ with gr.Row():
60
+ clear_btn = Button(
61
+ "Clear",
62
+ interactive=False
63
+ )
64
+ submit_btn = Button(
65
+ "Submit",
66
+ variant="primary",
67
+ interactive=False
68
+ )
69
+ with gr.Accordion("Model parameters", open=False, visible=SHOW_MODEL_PARAMETERS_IN_UI):
70
+ repetition_penalty = Slider(
71
+ minimum=0.1,
72
+ maximum=10.0,
73
+ step=0.1,
74
+ value=0.85,
75
+ label="Repetition penalty"
76
+ )
77
+ temperature = Slider(
78
+ minimum=0.0,
79
+ maximum=2.0,
80
+ value=0.85,
81
+ label="Temperature"
82
+ )
83
+
84
+ with gr.Column():
85
+ output_answer = Textbox(
86
+ lines=9,
87
+ label="Model text",
88
+ interactive=False,
89
+ show_copy_button=True
90
+ )
91
+ output_context = Textbox(
92
+ lines=9,
93
+ label="Model context",
94
+ interactive=False,
95
+ show_copy_button=True
96
+ )
97
+ output_CCMA = Textbox(
98
+ lines=9,
99
+ label="CCMA text",
100
+ interactive=False,
101
+ show_copy_button=True
102
+ )
103
+
104
+
105
+ input_.change(fn=change_interactive, inputs=[input_], outputs=[clear_btn, submit_btn])
106
+
107
+
108
+ input_.change(fn=None, inputs=input_, _js=f"(i, m) => document.getElementById('counter').textContent = i.length + ' /' + {MAX_NEW_TOKENS}")
109
+
110
+
111
+
112
+ clear_btn.click(fn=clean, inputs=[], outputs=[input_, output_answer, output_context, output_CCMA, repetition_penalty, temperature], queue=False)
113
+ submit_btn.click(fn=submit_input, inputs=[input_, repetition_penalty, temperature], outputs=[output_answer, output_context, output_CCMA])
114
+
115
+
116
+
117
+ demo.queue(concurrency_count=1, api_open=False)
118
+ demo.launch(show_api=True, share=True, debug=True, server_name="84.88.187.178")
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ git+https://gitlab.bsc.es/projecte-aina/aina-gradio-theme.git@1.3.6
2
+ transformers==4.34.0
3
+ gradio==3.48.0
4
+ python-dotenv==1.0.0
5
+ pymongo==4.5.0
6
+ boto3==1.28.64
7
+ torch==2.1.0