Matteo Sirri commited on
Commit
f6bb7f6
β€’
1 Parent(s): 8499b06

perf: move tensor in gpu

Browse files
input_examples/input_examples/001.jpg β†’ 001.jpg RENAMED
File without changes
input_examples/input_examples/002.jpg β†’ 002.jpg RENAMED
File without changes
input_examples/input_examples/003.jpg β†’ 003.jpg RENAMED
File without changes
input_examples/input_examples/004.jpg β†’ 004.jpg RENAMED
File without changes
input_examples/input_examples/005.jpg β†’ 005.jpg RENAMED
File without changes
input_examples/input_examples/006.jpg β†’ 006.jpg RENAMED
File without changes
input_examples/input_examples/007.jpg β†’ 007.jpg RENAMED
File without changes
app.py CHANGED
@@ -9,6 +9,8 @@ from src.detection.graph_utils import add_bbox
9
  from src.detection.vision import presets
10
  logging.getLogger('PIL').setLevel(logging.CRITICAL)
11
 
 
 
12
 
13
  def load_model(baseline: bool = False):
14
  if baseline:
@@ -21,7 +23,6 @@ def load_model(baseline: bool = False):
21
  checkpoint = torch.load(
22
  "model_split_3_FT_MOT17.pth", map_location="cpu")
23
  model.load_state_dict(checkpoint["model"])
24
- device = torch.device('cuda:0')
25
  model.to(device)
26
  model.eval()
27
  return model
@@ -31,6 +32,7 @@ def frcnn_motsynth(image):
31
  model = load_model(baseline=True)
32
  transformEval = presets.DetectionPresetEval()
33
  image_tensor = transformEval(image, None)[0]
 
34
  prediction = model([image_tensor])[0]
35
  image_w_bbox = add_bbox(image_tensor, prediction, 0.80)
36
  torchvision.io.write_png(image_w_bbox, "custom_out.png")
@@ -41,6 +43,7 @@ def frcnn_coco(image):
41
  model = load_model(baseline=True)
42
  transformEval = presets.DetectionPresetEval()
43
  image_tensor = transformEval(image, None)[0]
 
44
  prediction = model([image_tensor])[0]
45
  image_w_bbox = add_bbox(image_tensor, prediction, 0.80)
46
  torchvision.io.write_png(image_w_bbox, "baseline_out.png")
@@ -49,7 +52,8 @@ def frcnn_coco(image):
49
 
50
  title = "Domain shift adaption on pedestrian detection with Faster R-CNN"
51
  description = "![alt text](http://www.aiacademy.unimore.it/media/news/ai-logo-white_2ND_EDITION.png)"
52
- examples = "input_examples"
 
53
 
54
  io_baseline = gr.Interface(frcnn_coco, gr.Image(type="pil"), gr.Image(
55
  type="file", shape=(1920, 1080), label="Baseline Model trained on COCO + FT on MOT17"))
@@ -58,4 +62,4 @@ io_custom = gr.Interface(frcnn_motsynth, gr.Image(type="pil"), gr.Image(
58
  type="file", shape=(1920, 1080), label="Faster R-CNN trained on MOTSynth + FT on MOT17"))
59
 
60
  gr.Parallel(io_baseline, io_custom, title=title,
61
- description=description, examples=examples,theme="huggingface").launch(enable_queue=True)
 
9
  from src.detection.vision import presets
10
  logging.getLogger('PIL').setLevel(logging.CRITICAL)
11
 
12
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
+
14
 
15
  def load_model(baseline: bool = False):
16
  if baseline:
 
23
  checkpoint = torch.load(
24
  "model_split_3_FT_MOT17.pth", map_location="cpu")
25
  model.load_state_dict(checkpoint["model"])
 
26
  model.to(device)
27
  model.eval()
28
  return model
 
32
  model = load_model(baseline=True)
33
  transformEval = presets.DetectionPresetEval()
34
  image_tensor = transformEval(image, None)[0]
35
+ image_tensor = image_tensor.to(device)
36
  prediction = model([image_tensor])[0]
37
  image_w_bbox = add_bbox(image_tensor, prediction, 0.80)
38
  torchvision.io.write_png(image_w_bbox, "custom_out.png")
 
43
  model = load_model(baseline=True)
44
  transformEval = presets.DetectionPresetEval()
45
  image_tensor = transformEval(image, None)[0]
46
+ image_tensor = image_tensor.to(device)
47
  prediction = model([image_tensor])[0]
48
  image_w_bbox = add_bbox(image_tensor, prediction, 0.80)
49
  torchvision.io.write_png(image_w_bbox, "baseline_out.png")
 
52
 
53
  title = "Domain shift adaption on pedestrian detection with Faster R-CNN"
54
  description = "![alt text](http://www.aiacademy.unimore.it/media/news/ai-logo-white_2ND_EDITION.png)"
55
+ examples = ["001.jpg", "002.jpg", "003.jpg",
56
+ "004.jpg", "005.jpg", "006.jpg", "007.jpg", ]
57
 
58
  io_baseline = gr.Interface(frcnn_coco, gr.Image(type="pil"), gr.Image(
59
  type="file", shape=(1920, 1080), label="Baseline Model trained on COCO + FT on MOT17"))
 
62
  type="file", shape=(1920, 1080), label="Faster R-CNN trained on MOTSynth + FT on MOT17"))
63
 
64
  gr.Parallel(io_baseline, io_custom, title=title,
65
+ description=description, examples=examples, theme="default").launch(enable_queue=True)
input_examples/log.csv DELETED
@@ -1,7 +0,0 @@
1
- "001.jpg"
2
- "002.jpg"
3
- "003.jpg"
4
- "004.jpg"
5
- "005.jpg"
6
- "006.jpg"
7
- "007.jpg"