Mrmusculo commited on
Commit
2e2ab76
1 Parent(s): 3a0bfe5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -40
app.py CHANGED
@@ -1,5 +1,5 @@
1
  import pickle
2
- import pandas as pd
3
  import numpy as np
4
 
5
  import requests
@@ -14,6 +14,7 @@ import PIL
14
 
15
  from transparent_background import Remover
16
  import torch
 
17
 
18
  import time
19
 
@@ -22,6 +23,9 @@ from PIL import Image
22
  import requests
23
  from io import BytesIO
24
 
 
 
 
25
  class BackgroundRemover(Remover):
26
  def __init__(self, model_bytes, device=None):
27
  """
@@ -36,7 +40,7 @@ class BackgroundRemover(Remover):
36
 
37
 
38
  # get the path of the script that defines this class
39
- script_path = os.path.abspath(__file__)
40
 
41
  # construct the path to the arial.ttf file relative to the script location
42
  font_path = os.path.join(os.path.dirname(script_path), "arial.ttf")
@@ -152,9 +156,7 @@ class BackgroundRemover(Remover):
152
  model_bytes = file.read()
153
 
154
  return model_bytes
155
-
156
-
157
-
158
  def show_image(url: str):
159
  response = requests.get(url)
160
  img = Image.open(BytesIO(response.content))
@@ -171,50 +173,25 @@ def do_predictions(url):
171
  # Set up data transformations
172
  data_transforms = {
173
  'train': transforms.Compose([
174
- #transforms.Resize(512), #256
175
- #transforms.CenterCrop(480), # 224
176
- #transforms.Resize((256, 256)),
177
  transforms.Resize((384, 384)),
178
  transforms.ToTensor(),
179
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
180
  ]),
181
  'val': transforms.Compose([
182
- #transforms.Resize(512), #256
183
- #transforms.CenterCrop(480), # 224
184
- #transforms.Resize((256, 256)),
185
- transforms.Resize((384, 284)),
186
  transforms.ToTensor(),
187
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
188
  ]),
189
  }
190
-
191
- # Crear un modelo con la misma arquitectura
192
- detect_model = models.resnet50(weights=None) # Cambiar 'pretrained' por 'weights'
193
- num_ftrs = detect_model.fc.in_features
194
- num_classes = 2
195
- detect_model.fc = nn.Linear(num_ftrs, num_classes)
196
- detect_model = detect_model.to(device)
197
-
198
- # Cargar los pesos guardados
199
- model_weights_path = 'white_background_detection/resnet50_finetuned_weights.pth'
200
- detect_model.load_state_dict(torch.load(model_weights_path))
201
-
202
- # Cambiar el modelo a modo de evaluación
203
- detect_model.eval()
204
-
205
- print("")
206
 
207
- prediction, predicted_probability, inference_time = predict_single_image_detection(img, detect_model, data_transforms['val'], "cuda:0")
208
 
209
- if prediction=="real":
210
- out = transform_model.predict(img, comparison=False)
211
- return prediction, predicted_probability, img, out,
212
- else:
213
- return prediction, predicted_probability, img, None
214
-
215
- iface = gr.Interface(fn=do_predictions, inputs="text", outputs=["text", "text", "image", "image"], examples=[["https://http2.mlstatic.com/D_NQ_NP_2X_823376-MLU29226703936_012019-F.webp"],
216
- ["https://http2.mlstatic.com/D_781350-MLA53584851929_022023-F.jpg"]])
217
- #iface.outputs[0].set_title("Predicción")
218
- #iface.outputs[1].set_title("Clase")
219
- #iface.outputs[2].set_title("Probabilidad")
220
  iface.launch(share=True)
 
1
  import pickle
2
+ import pandas as p
3
  import numpy as np
4
 
5
  import requests
 
14
 
15
  from transparent_background import Remover
16
  import torch
17
+ import torch.nn.functional as F
18
 
19
  import time
20
 
 
23
  import requests
24
  from io import BytesIO
25
 
26
+ from torchvision import datasets, models, transforms
27
+
28
+
29
  class BackgroundRemover(Remover):
30
  def __init__(self, model_bytes, device=None):
31
  """
 
40
 
41
 
42
  # get the path of the script that defines this class
43
+ script_path = "" #os.path.abspath(__file__)
44
 
45
  # construct the path to the arial.ttf file relative to the script location
46
  font_path = os.path.join(os.path.dirname(script_path), "arial.ttf")
 
156
  model_bytes = file.read()
157
 
158
  return model_bytes
159
+
 
 
160
  def show_image(url: str):
161
  response = requests.get(url)
162
  img = Image.open(BytesIO(response.content))
 
173
  # Set up data transformations
174
  data_transforms = {
175
  'train': transforms.Compose([
 
 
 
176
  transforms.Resize((384, 384)),
177
  transforms.ToTensor(),
178
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
179
  ]),
180
  'val': transforms.Compose([
181
+ transforms.Resize((384, 384)),
 
 
 
182
  transforms.ToTensor(),
183
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
184
  ]),
185
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
186
 
187
+ out = transform_model.predict(img, comparison=False)
188
 
189
+ return img, out
190
+
191
+ iface = gr.Interface(fn=do_predictions, inputs="text",
192
+ examples=[["https://http2.mlstatic.com/D_NQ_NP_2X_823376-MLU29226703936_012019-F.webp"],
193
+ ["https://http2.mlstatic.com/D_781350-MLA53584851929_022023-F.jpg"]],
194
+ outputs=["image", "image"],
195
+ )
196
+
 
 
 
197
  iface.launch(share=True)