PhotographerAlpha7 committed on
Commit
dda567b
·
verified ·
1 Parent(s): 1fc1892

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +63 -38
app.py CHANGED
@@ -7,10 +7,12 @@ import os
7
  import matplotlib.pyplot as plt
8
  import json
9
  import io
 
10
 
11
  # Variables globales pour stocker les colonnes détectées
12
  columns = []
13
 
 
14
  # Fonction pour lire le fichier et détecter les colonnes
15
  def read_file(data_file):
16
  global columns
@@ -25,16 +27,28 @@ def read_file(data_file):
25
  df = pd.read_excel(data_file.name)
26
  else:
27
  return "Invalid file format. Please upload a CSV, JSON, or Excel file."
28
-
29
  # Détecter les colonnes
30
  columns = df.columns.tolist()
31
  return columns
32
  except Exception as e:
33
  return f"An error occurred: {str(e)}"
34
 
 
 
 
 
 
 
 
 
35
  # Fonction pour entraîner le modèle
36
  def train_model(data_file, model_name, epochs, batch_size, learning_rate, output_dir, prompt_col, description_col):
37
  try:
 
 
 
 
38
  # Charger les données
39
  file_extension = os.path.splitext(data_file.name)[1]
40
  if file_extension == '.csv':
@@ -43,31 +57,31 @@ def train_model(data_file, model_name, epochs, batch_size, learning_rate, output
43
  df = pd.read_json(data_file.name)
44
  elif file_extension == '.xlsx':
45
  df = pd.read_excel(data_file.name)
46
-
47
  # Prévisualisation des données
48
  preview = df.head().to_string(index=False)
49
-
50
  # Préparer le texte d'entraînement
51
  df['text'] = df[prompt_col] + ': ' + df[description_col]
52
  dataset = Dataset.from_pandas(df[['text']])
53
-
54
  # Initialiser le tokenizer et le modèle GPT-2
55
  tokenizer = GPT2Tokenizer.from_pretrained(model_name)
56
  model = GPT2LMHeadModel.from_pretrained(model_name)
57
-
58
  # Ajouter un token de padding si nécessaire
59
  if tokenizer.pad_token is None:
60
  tokenizer.add_special_tokens({'pad_token': '[PAD]'})
61
  model.resize_token_embeddings(len(tokenizer))
62
-
63
  # Tokenizer les données
64
  def tokenize_function(examples):
65
  tokens = tokenizer(examples['text'], padding="max_length", truncation=True, max_length=128)
66
  tokens['labels'] = tokens['input_ids'].copy()
67
  return tokens
68
-
69
  tokenized_datasets = dataset.map(tokenize_function, batched=True)
70
-
71
  # Ajustement des hyperparamètres
72
  training_args = TrainingArguments(
73
  output_dir=output_dir,
@@ -87,7 +101,7 @@ def train_model(data_file, model_name, epochs, batch_size, learning_rate, output
87
  load_best_model_at_end=True,
88
  metric_for_best_model="eval_loss"
89
  )
90
-
91
  # Configuration du Trainer
92
  trainer = Trainer(
93
  model=model,
@@ -95,15 +109,15 @@ def train_model(data_file, model_name, epochs, batch_size, learning_rate, output
95
  train_dataset=tokenized_datasets,
96
  eval_dataset=tokenized_datasets,
97
  )
98
-
99
  # Entraînement et évaluation
100
  trainer.train()
101
  eval_results = trainer.evaluate()
102
-
103
  # Sauvegarder le modèle fine-tuné
104
  model.save_pretrained(output_dir)
105
  tokenizer.save_pretrained(output_dir)
106
-
107
  # Générer un graphique des pertes d'entraînement et de validation
108
  train_loss = [x['loss'] for x in trainer.state.log_history if 'loss' in x]
109
  eval_loss = [x['eval_loss'] for x in trainer.state.log_history if 'eval_loss' in x]
@@ -114,37 +128,41 @@ def train_model(data_file, model_name, epochs, batch_size, learning_rate, output
114
  plt.title('Training and Validation Loss')
115
  plt.legend()
116
  plt.savefig(os.path.join(output_dir, 'training_eval_loss.png'))
117
-
118
  return f"Training completed successfully.\nPreview of data:\n{preview}", eval_results
119
  except Exception as e:
120
  return f"An error occurred: {str(e)}"
121
 
 
122
  # Fonction de génération de texte
123
- def generate_text(prompt, temperature, top_k, max_length, repetition_penalty, use_comma):
124
  try:
125
  model_name = "./fine-tuned-gpt2"
126
  tokenizer = GPT2Tokenizer.from_pretrained(model_name)
127
  model = GPT2LMHeadModel.from_pretrained(model_name)
128
-
129
  if use_comma:
130
  prompt = prompt.replace('.', ',')
131
-
132
  inputs = tokenizer(prompt, return_tensors="pt", padding=True)
133
  attention_mask = inputs.attention_mask
134
  outputs = model.generate(
135
- inputs.input_ids,
136
  attention_mask=attention_mask,
137
- max_length=int(max_length),
138
- temperature=float(temperature),
139
- top_k=int(top_k),
 
140
  repetition_penalty=float(repetition_penalty),
141
- num_return_sequences=1,
142
  pad_token_id=tokenizer.eos_token_id
143
  )
144
- return tokenizer.decode(outputs[0], skip_special_tokens=True)
 
145
  except Exception as e:
146
  return f"An error occurred: {str(e)}"
147
 
 
148
  # Fonction pour configurer les presets
149
  def set_preset(preset):
150
  if preset == "Default":
@@ -154,52 +172,59 @@ def set_preset(preset):
154
  elif preset == "High Accuracy":
155
  return 10, 4, 1e-5
156
 
 
157
  # Interface Gradio
158
  with gr.Blocks() as ui:
159
- gr.Markdown("# Model-Fine-Tuner | by Dimonapatrick243")
160
-
161
  with gr.Tab("Train Model"):
162
  with gr.Row():
163
  data_file = gr.File(label="Upload Data File (CSV, JSON, Excel)")
164
  model_name = gr.Textbox(label="Model Name", value="gpt2")
165
  output_dir = gr.Textbox(label="Output Directory", value="./fine-tuned-gpt2")
166
-
167
  with gr.Row():
168
  preset = gr.Radio(["Default", "Fast Training", "High Accuracy"], label="Preset")
169
  epochs = gr.Number(label="Epochs", value=5)
170
  batch_size = gr.Number(label="Batch Size", value=8)
171
  learning_rate = gr.Number(label="Learning Rate", value=3e-5)
172
-
173
  preset.change(set_preset, preset, [epochs, batch_size, learning_rate])
174
-
175
  # Champs pour sélectionner les colonnes
176
  with gr.Row():
177
- design_col = gr.Dropdown(label="Design Column")
178
  description_col = gr.Dropdown(label="Description Column")
179
-
180
  # Détection des colonnes lors du téléchargement du fichier
181
- data_file.upload(read_file, inputs=data_file, outputs=[design_col, description_col])
182
-
183
  train_button = gr.Button("Train Model")
184
  train_output = gr.Textbox(label="Training Output")
185
  train_graph = gr.Image(label="Training and Validation Loss Graph")
186
-
187
- train_button.click(train_model, inputs=[data_file, model_name, epochs, batch_size, learning_rate, output_dir, design_col, description_col], outputs=[train_output, train_graph])
188
-
 
 
189
  with gr.Tab("Generate Text"):
190
  with gr.Row():
191
  with gr.Column():
192
  temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, value=0.7)
193
  top_k = gr.Slider(label="Top K", minimum=1, maximum=100, value=50)
 
194
  max_length = gr.Slider(label="Max Length", minimum=10, maximum=1024, value=128)
195
  repetition_penalty = gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, value=1.2)
196
  use_comma = gr.Checkbox(label="Use Comma", value=True)
197
-
 
198
  with gr.Column():
199
  prompt = gr.Textbox(label="Prompt")
200
  generate_button = gr.Button("Generate Text")
201
- generated_text = gr.Textbox(label="Generated Text")
202
-
203
- generate_button.click(generate_text, inputs=[prompt, temperature, top_k, max_length, repetition_penalty, use_comma], outputs=generated_text)
 
 
204
 
205
  ui.launch()
 
7
  import matplotlib.pyplot as plt
8
  import json
9
  import io
10
+ from datetime import datetime
11
 
12
  # Variables globales pour stocker les colonnes détectées
13
  columns = []
14
 
15
+
16
  # Fonction pour lire le fichier et détecter les colonnes
17
  def read_file(data_file):
18
  global columns
 
27
  df = pd.read_excel(data_file.name)
28
  else:
29
  return "Invalid file format. Please upload a CSV, JSON, or Excel file."
30
+
31
  # Détecter les colonnes
32
  columns = df.columns.tolist()
33
  return columns
34
  except Exception as e:
35
  return f"An error occurred: {str(e)}"
36
 
37
+
38
+ # Fonction pour valider les colonnes sélectionnées
39
def validate_columns(prompt_col, description_col):
    """Tell whether both selected column names were detected in the uploaded file.

    Args:
        prompt_col: Name chosen for the prompt column.
        description_col: Name chosen for the description column.

    Returns:
        True when both names are present in the module-level ``columns``
        list, False as soon as one of them is missing.
    """
    # Checked in the same order as before: prompt first, then description.
    return all(name in columns for name in (prompt_col, description_col))
43
+
44
+
45
  # Fonction pour entraîner le modèle
46
  def train_model(data_file, model_name, epochs, batch_size, learning_rate, output_dir, prompt_col, description_col):
47
  try:
48
+ # Valider les colonnes sélectionnées
49
+ if not validate_columns(prompt_col, description_col):
50
+ return "Invalid column selection. Please ensure the columns exist in the dataset."
51
+
52
  # Charger les données
53
  file_extension = os.path.splitext(data_file.name)[1]
54
  if file_extension == '.csv':
 
57
  df = pd.read_json(data_file.name)
58
  elif file_extension == '.xlsx':
59
  df = pd.read_excel(data_file.name)
60
+
61
  # Prévisualisation des données
62
  preview = df.head().to_string(index=False)
63
+
64
  # Préparer le texte d'entraînement
65
  df['text'] = df[prompt_col] + ': ' + df[description_col]
66
  dataset = Dataset.from_pandas(df[['text']])
67
+
68
  # Initialiser le tokenizer et le modèle GPT-2
69
  tokenizer = GPT2Tokenizer.from_pretrained(model_name)
70
  model = GPT2LMHeadModel.from_pretrained(model_name)
71
+
72
  # Ajouter un token de padding si nécessaire
73
  if tokenizer.pad_token is None:
74
  tokenizer.add_special_tokens({'pad_token': '[PAD]'})
75
  model.resize_token_embeddings(len(tokenizer))
76
+
77
  # Tokenizer les données
78
  def tokenize_function(examples):
79
  tokens = tokenizer(examples['text'], padding="max_length", truncation=True, max_length=128)
80
  tokens['labels'] = tokens['input_ids'].copy()
81
  return tokens
82
+
83
  tokenized_datasets = dataset.map(tokenize_function, batched=True)
84
+
85
  # Ajustement des hyperparamètres
86
  training_args = TrainingArguments(
87
  output_dir=output_dir,
 
101
  load_best_model_at_end=True,
102
  metric_for_best_model="eval_loss"
103
  )
104
+
105
  # Configuration du Trainer
106
  trainer = Trainer(
107
  model=model,
 
109
  train_dataset=tokenized_datasets,
110
  eval_dataset=tokenized_datasets,
111
  )
112
+
113
  # Entraînement et évaluation
114
  trainer.train()
115
  eval_results = trainer.evaluate()
116
+
117
  # Sauvegarder le modèle fine-tuné
118
  model.save_pretrained(output_dir)
119
  tokenizer.save_pretrained(output_dir)
120
+
121
  # Générer un graphique des pertes d'entraînement et de validation
122
  train_loss = [x['loss'] for x in trainer.state.log_history if 'loss' in x]
123
  eval_loss = [x['eval_loss'] for x in trainer.state.log_history if 'eval_loss' in x]
 
128
  plt.title('Training and Validation Loss')
129
  plt.legend()
130
  plt.savefig(os.path.join(output_dir, 'training_eval_loss.png'))
131
+
132
  return f"Training completed successfully.\nPreview of data:\n{preview}", eval_results
133
  except Exception as e:
134
  return f"An error occurred: {str(e)}"
135
 
136
+
137
  # Fonction de génération de texte
138
def generate_text(prompt, temperature, top_k, top_p, max_length, repetition_penalty, use_comma, batch_size):
    """Generate one or more continuations of ``prompt`` with the fine-tuned GPT-2 model.

    The tokenizer and model are loaded from the hard-coded ``./fine-tuned-gpt2``
    directory on every call.

    Args:
        prompt: Text to continue. A blank prompt is rejected.
        temperature: Sampling temperature (converted with ``float``).
        top_k: Top-k sampling cutoff (converted with ``int``).
        top_p: Nucleus sampling cutoff (converted with ``float``).
        max_length: Maximum total length passed to ``generate`` (converted with ``int``).
        repetition_penalty: Repetition penalty (converted with ``float``).
        use_comma: When true, every ``.`` in the prompt is replaced by ``,``
            before tokenization.
        batch_size: Number of sequences to return (converted with ``int``).

    Returns:
        A single string: the decoded sequences separated by a blank line, or
        a message starting with ``"An error occurred:"`` if anything fails.
        The function never raises.
    """
    try:
        # Reject blank prompts before paying the cost of loading the model.
        if not prompt or not prompt.strip():
            return "An error occurred: the prompt is empty."

        model_name = "./fine-tuned-gpt2"
        tokenizer = GPT2Tokenizer.from_pretrained(model_name)
        model = GPT2LMHeadModel.from_pretrained(model_name)

        if use_comma:
            prompt = prompt.replace('.', ',')

        inputs = tokenizer(prompt, return_tensors="pt", padding=True)
        outputs = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_length=int(max_length),
            # Sampling must be enabled explicitly: without it temperature,
            # top_k and top_p are ignored and asking for more than one
            # returned sequence is rejected by greedy decoding.
            do_sample=True,
            temperature=float(temperature),
            top_k=int(top_k),
            top_p=float(top_p),
            repetition_penalty=float(repetition_penalty),
            num_return_sequences=int(batch_size),
            pad_token_id=tokenizer.eos_token_id,
        )

        decoded = [tokenizer.decode(sequence, skip_special_tokens=True) for sequence in outputs]
        # Return plain text rather than a list so a text widget shows the
        # generations themselves instead of the list's repr.
        return "\n\n".join(decoded)
    except Exception as e:
        return f"An error occurred: {str(e)}"
164
 
165
+
166
  # Fonction pour configurer les presets
167
  def set_preset(preset):
168
  if preset == "Default":
 
172
  elif preset == "High Accuracy":
173
  return 10, 4, 1e-5
174
 
175
+
176
  # Interface Gradio
177
# NOTE(review): block nesting below was reconstructed from a flattened diff
# view (indentation was lost) — confirm the layout against the real app.py.
with gr.Blocks() as ui:
    gr.Markdown("# Fine-Tune GPT-2 UI Design Model")

    # ---- Training tab -------------------------------------------------
    with gr.Tab("Train Model"):
        # Data source, base model and destination directory.
        with gr.Row():
            data_file = gr.File(label="Upload Data File (CSV, JSON, Excel)")
            model_name = gr.Textbox(label="Model Name", value="gpt2")
            output_dir = gr.Textbox(label="Output Directory", value="./fine-tuned-gpt2")

        # Hyperparameters; choosing a preset overwrites the three numbers.
        with gr.Row():
            preset = gr.Radio(["Default", "Fast Training", "High Accuracy"], label="Preset")
            epochs = gr.Number(label="Epochs", value=5)
            batch_size = gr.Number(label="Batch Size", value=8)
            learning_rate = gr.Number(label="Learning Rate", value=3e-5)

        preset.change(set_preset, preset, [epochs, batch_size, learning_rate])

        # Dropdowns used to pick the prompt and description columns.
        with gr.Row():
            prompt_col = gr.Dropdown(label="Prompt Column")
            description_col = gr.Dropdown(label="Description Column")

        # Detect the columns as soon as a file is uploaded.
        data_file.upload(
            read_file,
            inputs=data_file,
            outputs=[prompt_col, description_col],
        )

        train_button = gr.Button("Train Model")
        train_output = gr.Textbox(label="Training Output")
        train_graph = gr.Image(label="Training and Validation Loss Graph")

        train_button.click(
            train_model,
            inputs=[
                data_file,
                model_name,
                epochs,
                batch_size,
                learning_rate,
                output_dir,
                prompt_col,
                description_col,
            ],
            outputs=[train_output, train_graph],
        )

    # ---- Generation tab -----------------------------------------------
    with gr.Tab("Generate Text"):
        with gr.Row():
            # Left column: sampling controls.
            with gr.Column():
                temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, value=0.7)
                top_k = gr.Slider(label="Top K", minimum=1, maximum=100, value=50)
                top_p = gr.Slider(label="Top P", minimum=0.0, maximum=1.0, value=0.9)
                max_length = gr.Slider(label="Max Length", minimum=10, maximum=1024, value=128)
                repetition_penalty = gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, value=1.2)
                use_comma = gr.Checkbox(label="Use Comma", value=True)
                # Rebinds ``batch_size``; the training click handler above
                # already holds a reference to the training component.
                batch_size = gr.Number(label="Batch Size", value=1, minimum=1)

            # Right column: prompt input and generated output.
            with gr.Column():
                prompt = gr.Textbox(label="Prompt")
                generate_button = gr.Button("Generate Text")
                generated_text = gr.Textbox(label="Generated Text", lines=20)

        generate_button.click(
            generate_text,
            inputs=[
                prompt,
                temperature,
                top_k,
                top_p,
                max_length,
                repetition_penalty,
                use_comma,
                batch_size,
            ],
            outputs=generated_text,
        )

ui.launch()