File size: 9,804 Bytes
ed0dca2
a40632d
bdf3e70
38d7f73
 
2f1a468
 
1c03d71
 
bdf3e70
30da7cc
 
9e53c43
b1cf10f
30da7cc
b1cf10f
a40632d
 
aeb447f
a52308c
6664e37
aeb447f
 
 
 
 
 
6664e37
aeb447f
 
 
 
 
 
6664e37
aeb447f
 
 
 
 
a52308c
aeb447f
2f1a468
aeb447f
6ca76c1
 
 
aeb447f
 
02fc014
aeb447f
dfd7ad6
6d57565
5a49926
 
 
 
6ca76c1
5a49926
 
2f1a468
 
d8dc2cd
 
0e413cd
 
 
 
 
7b89069
 
d8dc2cd
2f1a468
 
 
 
 
 
 
 
 
 
 
 
30da7cc
2f1a468
 
 
 
 
 
 
 
42cd34a
bdf3e70
2f1a468
 
 
 
 
 
 
 
 
ed0dca2
52589e7
7b89069
 
17dfda2
 
7b89069
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f830874
7b89069
 
 
 
 
44d732c
7b89069
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
02fc014
7b89069
 
 
ab084df
4b6fd72
17dfda2
7b89069
2f1a468
7b89069
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2f1a468
7b89069
 
2f1a468
7b89069
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2f1a468
7b89069
2f1a468
7b89069
 
 
 
17dfda2
 
52589e7
 
 
 
17dfda2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
##################################### Imports ######################################
# Generic imports
import gradio as gr
import json

# Specialized imports
#from utilities.modeling import modeling
from datasets import load_dataset


# Module imports
from utilities.setup import get_json_cfg
from utilities.templates import prompt_template

########################### Global objects and functions ###########################

# Application configuration loaded once at import time via
# utilities.setup.get_json_cfg(); main() reads the model choices and the
# default model/PEFT/SFT parameters from it.
conf = get_json_cfg()

class update_visibility:
    """Gradio event handlers that toggle components from the dataset-source radio.

    Each handler receives the radio's current value and returns an update that
    shows or hides a single component.  ``gr.update`` is used (rather than a
    component constructor) because it applies to any component type — the
    targets here are a Textbox, a Button and an UploadButton.
    """

    # Radio choices that select each dataset source.
    HUB_CHOICE = "Hugging Face Hub Dataset"
    UPLOAD_CHOICE = "Upload Your Own"

    @staticmethod
    def textbox_vis(radio):
        """Show the Hub dataset-name textbox only when the Hub source is selected."""
        return gr.update(visible=radio == update_visibility.HUB_CHOICE)

    @staticmethod
    def textbox_button_vis(radio):
        """Show the Hub dataset load button only when the Hub source is selected."""
        return gr.update(visible=radio == update_visibility.HUB_CHOICE)

    @staticmethod
    def upload_vis(radio):
        """Show the file upload button only when the upload source is selected."""
        return gr.update(visible=radio == update_visibility.UPLOAD_CHOICE)

class get_datasets:
    """Gradio event handlers that load a training dataset into the module-level ``dataset``.

    Each handler returns a status string that the UI displays.
    """

    @staticmethod
    def predefined_dataset(dataset_name):
        """Load the ``train`` split of a Hugging Face Hub dataset.

        Args:
            dataset_name: Hub identifier, e.g. ``"yahma/alpaca-cleaned"``.

        Returns:
            A success message; errors raised by ``load_dataset`` propagate.
        """
        global dataset  # TODO: replace this module global with gr.State.
        dataset = load_dataset(dataset_name, split="train")
        return 'Successfully loaded dataset'

    @staticmethod
    def uploaded_dataset(file):
        """Parse an uploaded JSON Lines file into the module-level ``dataset``.

        Blank lines are skipped.  ``dataset`` is reset to ``[]`` up front and
        only receives the parsed records once the whole file has been read
        successfully, so a bad upload never leaves a half-filled dataset.

        Args:
            file: Path of the uploaded file, or ``None`` when nothing was
                uploaded.

        Returns:
            A status message for display in the UI.
        """
        global dataset  # TODO: replace this module global with gr.State.
        dataset = []
        if file is None:
            return "File not found. Please upload the file again."
        records = []
        try:
            # JSON Lines is UTF-8 by definition; don't rely on the locale default.
            with open(file, "r", encoding="utf-8") as handle:
                for line_no, raw_line in enumerate(handle, start=1):
                    line = raw_line.strip()
                    if not line:
                        continue  # tolerate blank / trailing empty lines
                    try:
                        records.append(json.loads(line))
                    except json.JSONDecodeError as err:
                        # e.g. a .csv/.txt upload that is not JSON Lines
                        return (f"Could not parse line {line_no} as JSON ({err.msg}). "
                                "Please upload a valid JSON Lines file.")
        except FileNotFoundError:
            return "File not found. Please upload the file again."
        except UnicodeDecodeError:
            return "File is not valid UTF-8 text. Please upload a JSON Lines file."
        dataset = records
        return "File retrieved."




def show_about():
    """Return the Markdown shown when the sidebar "About" button is pressed."""
    heading = "## About"
    body = ("This is an application for uploading datasets. "
            "You can upload files in .csv, .jsonl, or .txt format. "
            "The app will process the file and provide feedback.")
    return f"{heading}\n\n{body}"

def show_upload():
    """Return the Markdown introduction for the dataset-upload view."""
    sections = ("## Upload", "Use the button below to upload your dataset.")
    return "\n\n".join(sections)



def train(model_name,
          inject_prompt,
          dataset_predefined,
          peft,
          sft,
          max_seq_length,
          random_seed,
          num_epochs,
          max_steps,
          data_field,
          repository,
          model_out_name):
    """Fine-tune the selected model — currently a placeholder.

    Receives every value from the UI's "Start Fine Tuning" form.  Only
    ``model_name`` and ``inject_prompt`` are used for now; they are echoed
    back in a status string.

    Returns:
        A status message for the output textbox.
    """
    # TODO: run real training once utilities.modeling is wired in, e.g.
    #   trainer = modeling(model_name, max_seq_length, random_seed,
    #                      peft, sft, dataset, data_field)
    #   trainer_stats = trainer.train()
    greeting = "Hello!!"
    return f"{greeting} Using model: {model_name} with template: {inject_prompt}"


def submit_weights(model, repository, model_out_name, token, tokenizer=None):
    """Push a trained model (and optionally its tokenizer) to ``repository/model_out_name``.

    Args:
        model: Object exposing ``push_to_hub(repo, token=...)``.
        repository: Namespace part of the target repo id.
        model_out_name: Name part of the target repo id.
        token: Access token forwarded to ``push_to_hub``.
        tokenizer: Optional object exposing ``push_to_hub``; when given it is
            pushed to the same repo as the model.

    Returns:
        0 after the push calls return; exceptions from ``push_to_hub`` propagate.
    """
    repo = f"{repository}/{model_out_name}"
    model.push_to_hub(repo, token=token)
    # The tokenizer must be passed in explicitly; there is no global to fall back on.
    if tokenizer is not None:
        tokenizer.push_to_hub(repo, token=token)
    return 0

##################################### App UI #######################################



def main():
    """Build the fine-tuning configuration UI and launch the Gradio app.

    Layout: a narrow navigation sidebar (About / Upload Dataset buttons) and a
    main column containing:
    - model selection;
    - dataset selection and loading;
    - training parameters;
    - an "Advanced Tuning" accordion with the PEFT and SFT parameters as JSON text;
    - a button that calls ``train``.

    Default values are read from the module-level ``conf``.
    """
    with gr.Blocks() as demo:

        with gr.Row():
            with gr.Column(scale=1, min_width=200):  # Sidebar navigation
                gr.Markdown("### Navigation")
                btn_about = gr.Button("About")
                btn_upload = gr.Button("Upload Dataset")

            with gr.Column(scale=4):  # Main content area
                ##### Title Block #####
                gr.Markdown("# SLM Instruction Tuning with Unsloth")
                # Display area written to by the sidebar navigation buttons.
                main_content = gr.Markdown()

                ##### Initial Model Inputs #####
                gr.Markdown("### Model Inputs")

                # Select Model
                modelnames = conf['model']['choices']
                model_name = gr.Dropdown(label="Supported Models",
                                         choices=modelnames,
                                         value=modelnames[0])
                # Prompt template
                inject_prompt = gr.Textbox(label="Prompt Template",
                                           value=prompt_template())
                # Dataset choice
                dataset_choice = gr.Radio(label="Choose Dataset",
                                          choices=["Hugging Face Hub Dataset", "Upload Your Own"],
                                          value="Hugging Face Hub Dataset")

                dataset_predefined = gr.Textbox(label="Hugging Face Hub Training Dataset",
                                                value='yahma/alpaca-cleaned',
                                                visible=True)
                # Loads the Hub dataset named above (it does not upload a file).
                dataset_predefined_load = gr.Button("Load Hugging Face Hub Dataset")

                dataset_uploaded_load = gr.UploadButton(label="Upload Dataset (.csv, .jsonl, or .txt)",
                                                        file_types=[".csv", ".jsonl", ".txt"],
                                                        visible=False)
                # Status messages from the dataset loaders.
                data_snippet = gr.Markdown()

                # Show only the controls for the selected dataset source.
                dataset_choice.change(update_visibility.textbox_vis,
                                      dataset_choice,
                                      dataset_predefined)
                dataset_choice.change(update_visibility.upload_vis,
                                      dataset_choice,
                                      dataset_uploaded_load)
                dataset_choice.change(update_visibility.textbox_button_vis,
                                      dataset_choice,
                                      dataset_predefined_load)

                # Dataset loading
                dataset_predefined_load.click(fn=get_datasets.predefined_dataset,
                                              inputs=dataset_predefined,
                                              outputs=data_snippet)
                # An UploadButton hands over the chosen file through its
                # ``upload`` event; ``click`` fires before a file is selected.
                dataset_uploaded_load.upload(fn=get_datasets.uploaded_dataset,
                                             inputs=dataset_uploaded_load,
                                             outputs=data_snippet)

                ##### Model Parameter Inputs #####
                gr.Markdown("### Model Parameter Selection")
                general = conf['model']['general']
                data_field = gr.Textbox(label="Dataset Training Field Name",
                                        value=general["dataset_text_field"])
                max_seq_length = gr.Textbox(label="Maximum sequence length",
                                            value=general["max_seq_length"])
                random_seed = gr.Textbox(label="Seed",
                                         value=general["seed"])
                num_epochs = gr.Textbox(label="Training Epochs",
                                        value=general["num_train_epochs"])
                max_steps = gr.Textbox(label="Maximum steps",
                                       value=general["max_steps"])
                repository = gr.Textbox(label="Repository Name",
                                        value=general["repository"])
                model_out_name = gr.Textbox(label="Model Output Name",
                                            value=general["model_name"])

                # Hyperparameters (editable as JSON, hidden in an accordion).
                with gr.Accordion("Advanced Tuning", open=False):
                    peft = gr.Textbox(label="PEFT Parameters (json)",
                                      value=json.dumps(dict(conf['model']['peft']), indent=4))
                    sft = gr.Textbox(label="SFT Parameters (json)",
                                     value=json.dumps(dict(conf['model']['sft']), indent=4))

                ##### Execution #####
                tune_btn = gr.Button("Start Fine Tuning")
                gr.Markdown("### Model Progress")
                # Text output (for now)
                output = gr.Textbox(label="Output")

                tune_btn.click(fn=train,
                               inputs=[model_name,
                                       inject_prompt,
                                       dataset_predefined,
                                       peft,
                                       sft,
                                       max_seq_length,
                                       random_seed,
                                       num_epochs,
                                       max_steps,
                                       data_field,
                                       repository,
                                       model_out_name,
                                       ],
                               outputs=output)
                # TODO: stop button
                # TODO: submit button (see submit_weights)

        # Sidebar navigation
        btn_about.click(fn=show_about, outputs=main_content)
        btn_upload.click(fn=lambda: ("", gr.update(visible=True)),
                         outputs=[main_content, dataset_uploaded_load])

    # Launch once the Blocks context has closed and all events are registered.
    demo.launch()

##################################### Launch #######################################

if __name__ == "__main__":
    # Build and launch the UI only when run as a script, not when imported.
    main()