DHEIVER commited on
Commit
5208152
1 Parent(s): f2142fc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -6
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import torch
2
  from torchvision import transforms
3
  import gradio as gr
 
4
 
5
  # Carregue o dicionário contendo o modelo PyTorch treinado
6
  model_dict = torch.load("best.pt", map_location=torch.device('cpu')) # Use 'cpu' se não estiver usando GPU
@@ -13,7 +14,6 @@ model.eval()
13
 
14
  # Transformação de pré-processamento
15
  preprocess = transforms.Compose([
16
- transforms.ToPILImage(),
17
  transforms.Resize((224, 224)),
18
  transforms.ToTensor(),
19
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
@@ -29,12 +29,34 @@ def predict(image):
29
  with torch.no_grad():
30
  output = model(input_batch)
31
 
32
- # Post-processamento, se necessário
33
- # ...
 
 
 
34
 
35
- # Retorna o resultado da inferência
36
- return output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
  # Interface Gradio
39
- iface = gr.Interface(fn=predict, inputs="image", outputs="text")
40
  iface.launch()
 
1
  import torch
2
  from torchvision import transforms
3
  import gradio as gr
4
+ from PIL import Image, ImageDraw
5
 
6
  # Carregue o dicionário contendo o modelo PyTorch treinado
7
  model_dict = torch.load("best.pt", map_location=torch.device('cpu')) # Use 'cpu' se não estiver usando GPU
 
14
 
15
  # Transformação de pré-processamento
16
  preprocess = transforms.Compose([
 
17
  transforms.Resize((224, 224)),
18
  transforms.ToTensor(),
19
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
 
29
  with torch.no_grad():
30
  output = model(input_batch)
31
 
32
+ # Post-processamento para extrair informações de detecção
33
+ # (substitua as linhas abaixo de acordo com a estrutura de saída do seu modelo)
34
+ boxes = output['boxes']
35
+ labels = output['labels']
36
+ scores = output['scores']
37
 
38
+ # Visualize a imagem com caixas de detecção
39
+ image_with_boxes = visualize_detections(image, boxes, labels, scores)
40
+
41
+ # Retorna a imagem com caixas de detecção
42
+ return image_with_boxes
43
+
44
+ # Função para visualizar as caixas de detecção na imagem
45
+ def visualize_detections(image, boxes, labels, scores):
46
+ # Converta a imagem para o formato PIL
47
+ image_pil = transforms.ToPILImage()(image)
48
+
49
+ # Crie um objeto ImageDraw para desenhar caixas na imagem
50
+ draw = ImageDraw.Draw(image_pil)
51
+
52
+ # Desenhe as caixas de detecção na imagem
53
+ for box, label, score in zip(boxes, labels, scores):
54
+ box = [round(coord, 2) for coord in box.tolist()] # Arredonde as coordenadas da caixa
55
+ draw.rectangle(box, outline="red", width=3)
56
+ draw.text((box[0], box[1]), f"Label: {label}\nScore: {score:.2f}", fill="red")
57
+
58
+ return image_pil
59
 
60
  # Interface Gradio
61
+ iface = gr.Interface(fn=predict, inputs="image", outputs="image")
62
  iface.launch()