merve HF staff commited on
Commit
b942818
·
verified ·
1 Parent(s): 6e6de0a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -6
app.py CHANGED
@@ -45,6 +45,7 @@ def annotate_image(
45
  @spaces.GPU
46
  def process_video(
47
  input_video,
 
48
  progress=gr.Progress(track_tqdm=True)
49
  ):
50
  video_info = sv.VideoInfo.from_video_path(input_video)
@@ -59,7 +60,7 @@ def process_video(
59
  with sv.VideoSink(result_file_path, video_info=video_info) as sink:
60
  for _ in tqdm(range(total), desc="Processing video.."):
61
  frame = next(frame_generator)
62
- results = query(Image.fromarray(frame))
63
  final_labels = []
64
  detections = []
65
 
@@ -76,13 +77,13 @@ def process_video(
76
  return result_file_path
77
 
78
 
79
- def query(image):
80
  inputs = processor(images=image, return_tensors="pt").to(device)
81
  with torch.no_grad():
82
  outputs = model(**inputs)
83
  target_sizes = torch.Tensor([image.size])
84
 
85
- results = processor.post_process_object_detection(outputs=outputs, threshold=0.6, target_sizes=target_sizes)
86
  return results
87
 
88
  with gr.Blocks() as demo:
@@ -94,6 +95,7 @@ with gr.Blocks() as demo:
94
  input_video = gr.Video(
95
  label='Input Video'
96
  )
 
97
  submit = gr.Button()
98
  with gr.Column():
99
  output_video = gr.Video(
@@ -101,16 +103,17 @@ with gr.Blocks() as demo:
101
  )
102
  gr.Examples(
103
  fn=process_video,
104
- examples=[["./cats.mp4"]],
105
  inputs=[
106
- input_video
 
107
  ],
108
  outputs=output_video
109
  )
110
 
111
  submit.click(
112
  fn=process_video,
113
- inputs=input_video,
114
  outputs=output_video
115
  )
116
 
 
45
  @spaces.GPU
46
  def process_video(
47
  input_video,
48
+ confidence_threshold,
49
  progress=gr.Progress(track_tqdm=True)
50
  ):
51
  video_info = sv.VideoInfo.from_video_path(input_video)
 
60
  with sv.VideoSink(result_file_path, video_info=video_info) as sink:
61
  for _ in tqdm(range(total), desc="Processing video.."):
62
  frame = next(frame_generator)
63
+ results = query(Image.fromarray(frame), confidence_threshold)
64
  final_labels = []
65
  detections = []
66
 
 
77
  return result_file_path
78
 
79
 
80
+ def query(image, confidence_threshold):
81
  inputs = processor(images=image, return_tensors="pt").to(device)
82
  with torch.no_grad():
83
  outputs = model(**inputs)
84
  target_sizes = torch.Tensor([image.size])
85
 
86
+ results = processor.post_process_object_detection(outputs=outputs, threshold=confidence_threshold, target_sizes=target_sizes)
87
  return results
88
 
89
  with gr.Blocks() as demo:
 
95
  input_video = gr.Video(
96
  label='Input Video'
97
  )
98
+ conf = gr.Slider(label="Confidence Threshold", minimum=0.1, maximum=1.0, value=0.6, step=0.05)
99
  submit = gr.Button()
100
  with gr.Column():
101
  output_video = gr.Video(
 
103
  )
104
  gr.Examples(
105
  fn=process_video,
106
+ examples=[["./cats.mp4", 0.6]],
107
  inputs=[
108
+ input_video,
109
+ conf
110
  ],
111
  outputs=output_video
112
  )
113
 
114
  submit.click(
115
  fn=process_video,
116
+ inputs=[input_video, conf],
117
  outputs=output_video
118
  )
119