rmayormartins committed on
Commit 6cd28dd · 1 Parent(s): e453a8d

Uploading files
Files changed (3):
  1. README.md +14 -7
  2. app.py +355 -141
  3. requirements.txt +12 -6
README.md CHANGED
@@ -1,13 +1,20 @@
  ---
- title: Image Classifier Interactive
- emoji: 🖼
- colorFrom: purple
- colorTo: red
+ title: interactive-image-classifier
+ emoji: 🤖
+ colorFrom: blue
+ colorTo: green
  sdk: gradio
- sdk_version: 4.26.0
+ sdk_version: "4.12.0"
  app_file: app.py
  pinned: false
- license: ecl-2.0
  ---
 
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # Interactive Image Classifier
+
+ Upload images (.jpg, .png) for each class, then follow the interactive image-classification workflow: train, evaluate, predict, and export.
+
+ ## Test version 1 (16/05)
+
+ - Ramon Mayor Martins
+ - E-mail: [rmayormartins@gmail.com](mailto:rmayormartins@gmail.com)
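For readers who prefer code to clicking through tabs, here is a minimal sketch of the same workflow driven programmatically, assuming the helpers defined in app.py (below) are importable; the hyperparameters are illustrative, and you must copy your images into the created class folders between the setup and preparation steps.

```python
# Sketch only: drives app.py's helpers directly instead of through the UI.
from app import setup_classes, prepare_data, start_training, evaluate_model, export_model
import app  # needed to read module-level globals mutated by the helpers

setup_classes(2)                                  # creates dataset/class_0 and dataset/class_1
# ... copy your .jpg/.png files into those folders here ...
prepare_data(batch_size=32, resize=(224, 224))    # 70/20/10 train/val/test split
start_training('ResNet18', epochs=5, lr=0.001)    # fine-tunes and saves modelo.pth
print(evaluate_model(app.test_loader))            # per-class precision/recall report
export_model('onnx')                              # writes modelo_exportado.onnx
```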
app.py CHANGED
@@ -1,146 +1,360 @@
  import gradio as gr
  import numpy as np
- import random
- from diffusers import DiffusionPipeline
  import torch
 
- device = "cuda" if torch.cuda.is_available() else "cpu"
-
- if torch.cuda.is_available():
-     torch.cuda.max_memory_allocated(device=device)
-     pipe = DiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16", use_safetensors=True)
-     pipe.enable_xformers_memory_efficient_attention()
-     pipe = pipe.to(device)
- else:
-     pipe = DiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo", use_safetensors=True)
-     pipe = pipe.to(device)
-
- MAX_SEED = np.iinfo(np.int32).max
- MAX_IMAGE_SIZE = 1024
-
- def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps):
-
-     if randomize_seed:
-         seed = random.randint(0, MAX_SEED)
-
-     generator = torch.Generator().manual_seed(seed)
-
-     image = pipe(
-         prompt = prompt,
-         negative_prompt = negative_prompt,
-         guidance_scale = guidance_scale,
-         num_inference_steps = num_inference_steps,
-         width = width,
-         height = height,
-         generator = generator
-     ).images[0]
-
-     return image
-
- examples = [
-     "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
-     "An astronaut riding a green horse",
-     "A delicious ceviche cheesecake slice",
- ]
-
- css="""
- #col-container {
-     margin: 0 auto;
-     max-width: 520px;
  }
- """
-
- if torch.cuda.is_available():
-     power_device = "GPU"
- else:
-     power_device = "CPU"
-
- with gr.Blocks(css=css) as demo:
-
-     with gr.Column(elem_id="col-container"):
-         gr.Markdown(f"""
-         # Text-to-Image Gradio Template
-         Currently running on {power_device}.
-         """)
-
-         with gr.Row():
-
-             prompt = gr.Text(
-                 label="Prompt",
-                 show_label=False,
-                 max_lines=1,
-                 placeholder="Enter your prompt",
-                 container=False,
-             )
-
-             run_button = gr.Button("Run", scale=0)
-
-         result = gr.Image(label="Result", show_label=False)
-
-         with gr.Accordion("Advanced Settings", open=False):
-
-             negative_prompt = gr.Text(
-                 label="Negative prompt",
-                 max_lines=1,
-                 placeholder="Enter a negative prompt",
-                 visible=False,
-             )
-
-             seed = gr.Slider(
-                 label="Seed",
-                 minimum=0,
-                 maximum=MAX_SEED,
-                 step=1,
-                 value=0,
-             )
-
-             randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
-
-             with gr.Row():
-
-                 width = gr.Slider(
-                     label="Width",
-                     minimum=256,
-                     maximum=MAX_IMAGE_SIZE,
-                     step=32,
-                     value=512,
-                 )
-
-                 height = gr.Slider(
-                     label="Height",
-                     minimum=256,
-                     maximum=MAX_IMAGE_SIZE,
-                     step=32,
-                     value=512,
-                 )
-
-             with gr.Row():
-
-                 guidance_scale = gr.Slider(
-                     label="Guidance scale",
-                     minimum=0.0,
-                     maximum=10.0,
-                     step=0.1,
-                     value=0.0,
-                 )
-
-                 num_inference_steps = gr.Slider(
-                     label="Number of inference steps",
-                     minimum=1,
-                     maximum=12,
-                     step=1,
-                     value=2,
-                 )
-
-         gr.Examples(
-             examples = examples,
-             inputs = [prompt]
-         )
-
-     run_button.click(
-         fn = infer,
-         inputs = [prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
-         outputs = [result]
-     )
-
- demo.queue().launch()
  import gradio as gr
+ import os
+ import shutil
+ from sklearn.metrics import classification_report, confusion_matrix
+ import matplotlib.pyplot as plt
+ import seaborn as sns
  import numpy as np
+ import io
  import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ from torchvision import datasets, transforms, models
+ from torch.utils.data import DataLoader, random_split
+ from PIL import Image
+ import joblib  # used for .pkl export
 
+ # Available backbone architectures
+ model_dict = {
+     'AlexNet': models.alexnet,
+     'ResNet18': models.resnet18,
+     'ResNet34': models.resnet34,
+     'ResNet50': models.resnet50,
+     'MobileNetV2': models.mobilenet_v2
  }
+
+ # Global state shared across the Gradio callbacks
+ model = None
+ train_loader = None
+ val_loader = None
+ test_loader = None
+ dataset_path = 'dataset'
+ class_dirs = []
+ test_dataset_path = 'test_dataset'
+ test_class_dirs = []
+ num_classes = 2  # default number of classes
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # Create one directory per class under dataset_path
+ def setup_classes(num_classes_value):
+     global class_dirs, dataset_path, num_classes
+
+     num_classes = int(num_classes_value)
+     # Start from a clean dataset directory
+     if os.path.exists(dataset_path):
+         shutil.rmtree(dataset_path)
+     os.makedirs(dataset_path)
+
+     class_dirs = [os.path.join(dataset_path, f'class_{i}') for i in range(num_classes)]
+     for class_dir in class_dirs:
+         os.makedirs(class_dir)
+
+     return f"Created {num_classes} class directories."
+
+ # Copy the uploaded images into the chosen class directory
+ def upload_images(class_id, images):
+     class_dir = class_dirs[int(class_id)]
+     for image in images:
+         shutil.copy(image, class_dir)
+     return f"Images saved to class {class_id}."
+
+ # Build train/validation/test loaders from the uploaded dataset
+ def prepare_data(batch_size=32, resize=(224, 224)):
+     global train_loader, val_loader, test_loader, num_classes
+
+     transform = transforms.Compose([
+         transforms.Resize(resize),
+         transforms.ToTensor(),
+     ])
+
+     dataset = datasets.ImageFolder(dataset_path, transform=transform)
+
+     if len(dataset.classes) != num_classes:
+         return f"Error: the number of detected classes ({len(dataset.classes)}) does not match the expected number ({num_classes}). Check your images."
+
+     # 70/20/10 train/validation/test split
+     train_size = int(0.7 * len(dataset))
+     val_size = int(0.2 * len(dataset))
+     test_size = len(dataset) - train_size - val_size
+     train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])
+
+     train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
+     val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
+     test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
+
+     return "Data preparation completed successfully."
+
+ # Fine-tune a pretrained backbone on the uploaded dataset
+ def start_training(model_name, epochs, lr):
+     global model, train_loader, val_loader, device
+
+     if train_loader is None or val_loader is None:
+         return "Error: data not prepared."
+
+     model = model_dict[model_name](pretrained=True)
+     # Replace the final classification layer; ResNets expose it as .fc,
+     # while AlexNet and MobileNetV2 keep it at the end of .classifier.
+     if hasattr(model, 'fc'):
+         model.fc = nn.Linear(model.fc.in_features, num_classes)
+     else:
+         model.classifier[-1] = nn.Linear(model.classifier[-1].in_features, num_classes)
+     model = model.to(device)
+
+     criterion = nn.CrossEntropyLoss()
+     optimizer = optim.Adam(model.parameters(), lr=float(lr))
+
+     for epoch in range(int(epochs)):
+         model.train()
+         running_loss = 0.0
+         for inputs, labels in train_loader:
+             inputs, labels = inputs.to(device), labels.to(device)
+             optimizer.zero_grad()
+             outputs = model(inputs)
+             loss = criterion(outputs, labels)
+             loss.backward()
+             optimizer.step()
+             running_loss += loss.item()
+
+         print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss/len(train_loader)}")
+
+     torch.save(model.state_dict(), 'modelo.pth')
+     return "Training completed successfully. Model saved."
+
+ # Compute a classification report over the given loader
+ def evaluate_model(loader):
+     global model, device, num_classes
+
+     if model is None:
+         return "Error: model not trained."
+
+     if loader is None:
+         return "Error: test dataset is not prepared."
+
+     model.eval()
+     all_preds = []
+     all_labels = []
+     try:
+         with torch.no_grad():
+             for inputs, labels in loader:
+                 inputs, labels = inputs.to(device), labels.to(device)
+                 outputs = model(inputs)
+                 _, preds = torch.max(outputs, 1)
+                 all_preds.extend(preds.cpu().numpy())
+                 all_labels.extend(labels.cpu().numpy())
+
+         report = classification_report(all_labels, all_preds, labels=list(range(num_classes)), target_names=[f"class_{i}" for i in range(num_classes)], zero_division=0)
+         return report
+     except Exception as e:
+         return f"Error during evaluation: {str(e)}"
+
+ # Render the confusion matrix for the given loader as an image
+ def show_confusion_matrix(loader):
+     global model, device, num_classes
+
+     if model is None:
+         return "Error: model not trained."
+
+     model.eval()
+     all_preds = []
+     all_labels = []
+     with torch.no_grad():
+         for inputs, labels in loader:
+             inputs, labels = inputs.to(device), labels.to(device)
+             outputs = model(inputs)
+             _, preds = torch.max(outputs, 1)
+             all_preds.extend(preds.cpu().numpy())
+             all_labels.extend(labels.cpu().numpy())
+
+     cm = confusion_matrix(all_labels, all_preds, labels=list(range(num_classes)))
+
+     plt.figure(figsize=(6, 4.8))
+     sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=[f"class_{i}" for i in range(num_classes)], yticklabels=[f"class_{i}" for i in range(num_classes)])
+     plt.xlabel('Predictions')
+     plt.ylabel('Actuals')
+     buf = io.BytesIO()
+     plt.savefig(buf, format='png')
+     plt.close()
+     buf.seek(0)
+     return Image.open(buf)
+
+ # Classify the uploaded images with the trained model
+ def predict_images(images):
+     global model, device, num_classes
+
+     if model is None:
+         return "Error: model not trained."
+
+     transform = transforms.Compose([
+         transforms.Resize((224, 224)),
+         transforms.ToTensor(),
+     ])
+
+     model.eval()
+     results = []
+
+     for image in images:
+         try:
+             img = transform(Image.open(image)).unsqueeze(0).to(device)
+             with torch.no_grad():
+                 outputs = model(img)
+                 _, preds = torch.max(outputs, 1)
+                 predicted_class = preds.item()
+             results.append(f"Image {os.path.basename(image)} - predicted class: class_{predicted_class}")
+         except Exception as e:
+             results.append(f"Error processing image {image}: {str(e)}")
+
+     return "\n".join(results)
+
+ # Export the trained model as .pth, .onnx or .pkl
+ def export_model(format):
+     global model
+
+     if model is None:
+         return "Error: model not trained."
+
+     file_path = f"modelo_exportado.{format}"
+     if format == "pth":
+         torch.save(model.state_dict(), file_path)
+     elif format == "onnx":
+         try:
+             dummy_input = torch.randn(1, 3, 224, 224).to(device)
+             torch.onnx.export(model, dummy_input, file_path, export_params=True, opset_version=10, input_names=['input'], output_names=['output'])
+         except Exception as e:
+             return f"Error exporting to ONNX: {str(e)}"
+     elif format == "pkl":
+         joblib.dump(model, file_path)
+     else:
+         return f"Format {format} is not supported."
+
+     return f"Model successfully exported to {file_path}"
+
+ # Create one directory per class under test_dataset_path
+ def setup_test_classes():
+     global test_class_dirs, test_dataset_path
+
+     if os.path.exists(test_dataset_path):
+         shutil.rmtree(test_dataset_path)
+     os.makedirs(test_dataset_path)
+
+     test_class_dirs = [os.path.join(test_dataset_path, f'class_{i}') for i in range(num_classes)]
+     for class_dir in test_class_dirs:
+         os.makedirs(class_dir)
+
+     return f"Created {num_classes} test class directories."
+
+ # Copy the uploaded test images into the chosen class directory
+ def upload_test_images(class_id, images):
+     class_dir = test_class_dirs[int(class_id)]
+     for image in images:
+         shutil.copy(image, class_dir)
+     return f"Test images saved to class {class_id}."
+
+ # Build a loader for the held-out test images
+ def prepare_test_data(batch_size=32, resize=(224, 224)):
+     global test_loader, num_classes
+
+     transform = transforms.Compose([
+         transforms.Resize(resize),
+         transforms.ToTensor(),
+     ])
+
+     test_dataset = datasets.ImageFolder(test_dataset_path, transform=transform)
+
+     if len(test_dataset.classes) != num_classes:
+         return f"Error: the number of detected classes ({len(test_dataset.classes)}) does not match the expected number ({num_classes}). Check your images."
+
+     test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
+
+     return "Test data preparation completed successfully."
+
+ # Build the Gradio interface
+ def main():
+     with gr.Blocks() as demo:
+         gr.Markdown("# Image Classification Training")
+
+         with gr.Tab("Set Up Classes"):
+             num_classes_input = gr.Number(label="Number of Classes", value=2, precision=0)
+             setup_button = gr.Button("Set Up Classes")
+             setup_output = gr.Textbox()
+             setup_button.click(setup_classes, inputs=num_classes_input, outputs=setup_output)
+
+         with gr.Tab("Upload Images"):
+             # Note: the upload rows are built once at launch for the initial
+             # num_classes value; changing it later does not add new rows.
+             upload_inputs = []
+             for i in range(num_classes):
+                 with gr.Column():
+                     gr.Markdown(f"### Class {i}")
+                     class_id = gr.Number(label=f"Class {i} ID", value=i, precision=0)
+                     images = gr.File(label="Upload Images", file_count="multiple", type="filepath")
+                     upload_button = gr.Button("Upload")
+                     upload_output = gr.Textbox()
+
+                     upload_inputs.append((class_id, images, upload_button, upload_output))
+                     upload_button.click(upload_images, inputs=[class_id, images], outputs=upload_output)
+
+         with gr.Tab("Data Preparation"):
+             batch_size = gr.Number(label="Batch Size", value=32)
+             resize = gr.Textbox(label="Resize (e.g. 224,224)", value="224,224")
+             prepare_button = gr.Button("Prepare Data")
+             prepare_output = gr.Textbox()
+             prepare_button.click(lambda batch_size, resize: prepare_data(batch_size=batch_size, resize=tuple(map(int, resize.split(',')))), inputs=[batch_size, resize], outputs=prepare_output)
+
+         with gr.Tab("Training"):
+             model_name = gr.Dropdown(label="Model", choices=list(model_dict.keys()))
+             epochs = gr.Number(label="Epochs", value=30)
+             lr = gr.Number(label="Learning Rate", value=0.001)
+             train_button = gr.Button("Start Training")
+             train_output = gr.Textbox()
+             train_button.click(start_training, inputs=[model_name, epochs, lr], outputs=train_output)
+
+         with gr.Tab("Model Evaluation"):
+             eval_button = gr.Button("Evaluate Model")
+             eval_output = gr.Textbox()
+             eval_button.click(lambda: evaluate_model(test_loader), outputs=eval_output)
+
+             cm_button = gr.Button("Show Confusion Matrix")
+             cm_output = gr.Image()
+             cm_button.click(lambda: show_confusion_matrix(test_loader), outputs=cm_output)
+
+         with gr.Tab("Prediction and Evaluation"):
+             predict_images_input = gr.File(label="Upload Images for Prediction", file_count="multiple", type="filepath")
+             predict_button = gr.Button("Predict")
+             predict_output = gr.Textbox()
+             predict_button.click(predict_images, inputs=predict_images_input, outputs=predict_output)
+
+             gr.Markdown("### Upload Test Images")
+             setup_test_button = gr.Button("Set Up Test Directories")
+             setup_test_output = gr.Textbox()
+             setup_test_button.click(setup_test_classes, outputs=setup_test_output)
+
+             upload_test_inputs = []
+             for i in range(num_classes):
+                 with gr.Column():
+                     gr.Markdown(f"### Test Class {i}")
+                     test_class_id = gr.Number(label=f"Class {i} ID", value=i, precision=0)
+                     test_images = gr.File(label="Upload Test Images", file_count="multiple", type="filepath")
+                     upload_test_button = gr.Button("Upload Test Images")
+                     upload_test_output = gr.Textbox()
+
+                     upload_test_inputs.append((test_class_id, test_images, upload_test_button, upload_test_output))
+                     upload_test_button.click(upload_test_images, inputs=[test_class_id, test_images], outputs=upload_test_output)
+
+             prepare_test_button = gr.Button("Prepare Test Data")
+             prepare_test_output = gr.Textbox()
+             prepare_test_button.click(lambda batch_size, resize: prepare_test_data(batch_size=batch_size, resize=tuple(map(int, resize.split(',')))), inputs=[batch_size, resize], outputs=prepare_test_output)
+
+             eval_test_button = gr.Button("Evaluate Test Set")
+             eval_test_output = gr.Textbox()
+             eval_test_button.click(lambda: evaluate_model(test_loader), outputs=eval_test_output)
+
+             cm_test_button = gr.Button("Show Test Set Confusion Matrix")
+             cm_test_output = gr.Image()
+             cm_test_button.click(lambda: show_confusion_matrix(test_loader), outputs=cm_test_output)
+
+         with gr.Tab("Export"):
+             export_format = gr.Radio(label="Format", choices=["pth", "onnx", "pkl"])
+             export_button = gr.Button("Export Model")
+             export_output = gr.Textbox()
+             export_button.click(export_model, inputs=export_format, outputs=export_output)
+
+     demo.launch()
+
+ if __name__ == "__main__":
+     main()
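The app saves its fine-tuned weights to modelo.pth as a state_dict, so reloading outside the app requires rebuilding the same architecture first. A minimal reload sketch, assuming a ResNet18 trained with num_classes=2 and a hypothetical example.jpg input; the preprocessing mirrors what the app uses at prediction time:

```python
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image

# Rebuild the architecture that was trained, then load the saved weights.
model = models.resnet18()
model.fc = nn.Linear(model.fc.in_features, 2)  # assumes num_classes=2
model.load_state_dict(torch.load('modelo.pth', map_location='cpu'))
model.eval()

# Same resize + tensor preprocessing as app.py's predict_images.
transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
img = transform(Image.open('example.jpg')).unsqueeze(0)  # hypothetical input file
with torch.no_grad():
    pred = model(img).argmax(1).item()
print(f"class_{pred}")
```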
requirements.txt CHANGED
@@ -1,6 +1,12 @@
- accelerate
- diffusers
- invisible_watermark
- torch
- transformers
- xformers
+ torch==1.11.0
+ torchvision==0.12.0
+ scikit-learn==0.24.2
+ matplotlib==3.4.2
+ seaborn==0.11.1
+ numpy==1.21.0
+ Pillow==8.2.0
+ gradio==4.12.0
+ joblib==1.0.1
+ onnx==1.10.1
+ onnx-tf==1.8.0
+ tensorflow==2.16.0
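Since onnx is pinned above, a quick structural sanity check of the exported graph can be run outside the app; this is a sketch, where modelo_exportado.onnx is the file name export_model writes:

```python
import onnx

# Load the exported graph and validate its structure.
m = onnx.load("modelo_exportado.onnx")
onnx.checker.check_model(m)  # raises if the graph is malformed
print(onnx.helper.printable_graph(m.graph))
```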