tykiww committed on
Commit
c7759ea
·
verified ·
1 Parent(s): 7f8a5b3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -190
app.py CHANGED
@@ -1,205 +1,68 @@
1
- ##################################### Imports ######################################
2
- # Generic imports
3
- import spaces
4
- import gradio as gr
5
- import json
6
-
7
- # Specialized imports
8
- #from utilities.modeling import modeling
9
 
10
- # server import
11
- from server import submit_weights #, train_model, submit_weights
 
 
12
 
13
 
14
- # Module imports
15
- from utilities.setup import get_files
16
- from utilities.templates import prompt_template
 
 
 
17
 
18
- ########################### Global objects and functions ###########################
19
 
20
- conf = get_files.json_cfg()
 
 
 
 
21
 
22
- class update_visibility:
23
 
24
- def textbox_vis(radio):
25
- value = radio
26
- if value == "Hugging Face Hub Dataset":
27
- return gr.Dropdown(visible=bool(1))
28
- else:
29
- return gr.Dropdown(visible=bool(0))
30
 
31
- def textbox_button_vis(radio):
32
- value = radio
33
- if value == "Hugging Face Hub Dataset":
34
- return gr.Button(visible=bool(1))
35
- else:
36
- return gr.Button(visible=bool(0))
37
 
38
- def upload_vis(radio):
39
- value = radio
40
- if value == "Upload Your Own":
41
- return gr.UploadButton(visible=bool(1)) #make it visible
42
- else:
43
- return gr.UploadButton(visible=bool(0))
44
- @spaces.GPU
45
- def train(model_name,
46
- inject_prompt,
47
- dataset_predefined,
48
- peft,
49
- sft,
50
- max_seq_length,
51
- random_seed,
52
- num_epochs,
53
- max_steps,
54
- data_field,
55
- repository,
56
- model_out_name):
57
- """The model call"""
58
-
59
- # Get models
60
- # trainer = modeling(model_name, max_seq_length, random_seed,
61
- # peft, sft, dataset, data_field)
62
- # trainer_stats = trainer.train()
63
-
64
- # Return outputs of training.
65
 
66
- return f"Hello!! Using model: {model_name} with template: {inject_prompt}"
67
-
68
-
69
-
70
- ##################################### App UI #######################################
71
-
72
-
73
-
74
- def main():
75
- with gr.Blocks() as demo:
76
-
77
- with gr.Tabs():
78
- with gr.TabItem("About"):
79
- # About page!!
80
- gr.Markdown(get_files.load_markdown_file("README.md"))
81
-
82
- with gr.TabItem("Basic Setup"):
83
- gr.Markdown("# Select Model and Input details")
84
- # Select Model
85
- modelnames = conf['model']['choices']
86
- model_name = gr.Dropdown(label="Supported Models",
87
- choices=modelnames,
88
- value=modelnames[0])
89
- # Select Generic Model parameters
90
- repository = gr.Textbox(label="Your User Name",
91
- value=conf['model']['general']["repository"])
92
- model_out_name = gr.Textbox(label="Your Model Output Name",
93
- value=conf['model']['general']["model_name"])
94
- hf_token = gr.Textbox(label="Your Huggingface Token",
95
- type='password',
96
- value='')
97
-
98
-
99
- with gr.TabItem("Upload Data"):
100
- # Toggle dataset load types
101
- gr.Markdown("# Dataset Selection and Upload")
102
-
103
- dataset_choice = gr.Radio(label="Choose Dataset",
104
- choices=["Hugging Face Hub Dataset", "Upload Your Own"],
105
- value="Hugging Face Hub Dataset")
106
- dataset_predefined = gr.Textbox(label="Hugging Face Hub Training Dataset",
107
- value='yahma/alpaca-cleaned',
108
- visible=True)
109
- dataset_predefined_load = gr.Button("Upload Dataset (.csv, .jsonl, or .txt)")
110
-
111
- dataset_uploaded_load = gr.UploadButton(label="Upload Dataset (.csv, .jsonl, or .txt)",
112
- file_types=[".csv",".jsonl", ".txt"],
113
- visible=False)
114
- # Safety output to show if upload succeeded.
115
- data_snippet = gr.Markdown()
116
-
117
- # Visibility toggler
118
- dataset_choice.change(update_visibility.textbox_vis,
119
- dataset_choice,
120
- dataset_predefined)
121
- dataset_choice.change(update_visibility.upload_vis,
122
- dataset_choice,
123
- dataset_uploaded_load)
124
- dataset_choice.change(update_visibility.textbox_button_vis,
125
- dataset_choice,
126
- dataset_predefined_load)
127
- # Prompt template
128
- inject_prompt = gr.Textbox(label="Prompt Template",
129
- value=prompt_template())
130
- # Dataset buttons
131
- dataset_predefined_load.click(fn=get_files.predefined_dataset,
132
- inputs=dataset_predefined,
133
- outputs=data_snippet)
134
-
135
- dataset_uploaded_load.click(fn=get_files.uploaded_dataset,
136
- inputs=dataset_uploaded_load,
137
- outputs=data_snippet)
138
-
139
-
140
- with gr.TabItem("Train Model"):
141
- ##### Model Parameter Inputs #####
142
- gr.Markdown("# Model Parameter Selection")
143
-
144
- # Parameters
145
- data_field = gr.Textbox(label="Dataset Training Field Name",
146
- value=conf['model']['general']["dataset_text_field"])
147
- max_seq_length = gr.Textbox(label="Maximum sequence length",
148
- value=conf['model']['general']["max_seq_length"])
149
- random_seed = gr.Textbox(label="Seed",
150
- value=conf['model']['general']["seed"])
151
- num_epochs = gr.Textbox(label="Training Epochs",
152
- value=conf['model']['general']["num_train_epochs"])
153
- max_steps = gr.Textbox(label="Maximum steps",
154
- value=conf['model']['general']["max_steps"])
155
-
156
- # Hyperparameters (allow selection, but hide in accordion.)
157
- with gr.Accordion("Advanced Tuning", open=False):
158
-
159
- sftparams = conf['model']['general']
160
- # accordion container content
161
- dict_string = json.dumps(dict(conf['model']['peft']), indent=4)
162
- peft = gr.Textbox(label="PEFT Parameters (json)", value=dict_string)
163
-
164
- dict_string = json.dumps(dict(conf['model']['sft']), indent=4)
165
- sft = gr.Textbox(label="SFT Parameters (json)", value=dict_string)
166
-
167
- ##### Execution #####
168
-
169
- # Setup buttons
170
- tune_btn = gr.Button("Start Fine Tuning")
171
- gr.Markdown("### Model Progress")
172
- # Text output (for now)
173
- output = gr.Textbox(label="Output")
174
-
175
-
176
- # Data retrieval
177
-
178
-
179
- # Execute buttons
180
- tune_btn.click(fn=train,
181
- inputs=[model_name,
182
- inject_prompt,
183
- dataset_predefined,
184
- peft,
185
- sft,
186
- max_seq_length,
187
- random_seed,
188
- num_epochs,
189
- max_steps,
190
- data_field,
191
- repository,
192
- model_out_name
193
- ],
194
- outputs=output)
195
- # stop button
196
 
197
- # submit button
198
-
199
- # Launch baby
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
200
  demo.launch()
201
 
202
- ##################################### Launch #######################################
203
 
 
204
  if __name__ == "__main__":
205
- main()
 
 
 
1
+ ###########################
2
+ # UI for Meeting RAG Q&A. #
3
+ ###########################
 
 
 
 
 
4
 
5
+ ##################### Imports #####################
6
+ import gradio as gr
7
+ from utilities.setup import get_files
8
+ from server import EmbeddingService, QAService
9
 
10
 
11
+ #################### Functions ####################
12
+ def process_transcripts(files):
13
+ with EmbeddingService() as e:
14
+ e.run(files)
15
+ # some way to wait or a progress bar?
16
+ return 0
17
 
 
18
 
19
+ def retrieve_answer(question):
20
+ with QAService() as q:
21
+ q.run(question)
22
+ answer = retriever.answer()
23
+ return answer
24
 
 
25
 
26
+ ##################### Process #####################
27
+ def main(conf):
28
+ with gr.Blocks() as demo:
 
 
 
29
 
30
+ # Main page
31
+ with gr.TabItem(conf["layout"]["page_names"][0]):
32
+ gr.Markdown(get_files.load_markdown_file(conf["layout"]["About"]))
 
 
 
33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
+ # User config page
37
+ with gr.TabItem(conf["layout"]["page_names"][1]):
38
+ gr.Markdown("# Upload Transcript and Necessary Context")
39
+ gr.Markdown("Please wait as the transcript is being processed.")
40
+ load_file = gr.UploadButton(label="Upload Transcript (.vtt)",
41
+ file_types=[".vtt"])
42
+ goals = gr.Textbox(label="Goals for the Meeting",
43
+ value=conf["defaults"]["goals"]) # not incorporated yet. Will be with Q&A.
44
+ repository = gr.Textbox(label="Blank", visible=False) # since there is no output.
45
+ upload_button.upload(process_transcripts, load_file, repository)
46
+
47
+
48
+
49
+ # Meeting Question & Answer Page
50
+ with gr.TabItem(conf["layout"]["page_names"][2]):
51
+ question = gr.Textbox(label="Ask a Question",
52
+ value=conf["default"]["question"])
53
+ ask_button = gr.Button("Ask!")
54
+ model_output = gr.components.Textbox(label="Answer")
55
+ dataset_predefined_load.click(fn=retrieve_answer,
56
+ inputs=question,
57
+ outputs=model_output)
58
+
59
+
60
+
61
  demo.launch()
62
 
 
63
 
64
+ ##################### Execute #####################
65
  if __name__ == "__main__":
66
+ # Get config
67
+ conf = get_files.json_cfg()
68
+ main(conf)