srinuksv commited on
Commit
9ca41bc
1 Parent(s): 6add49c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -12
app.py CHANGED
@@ -9,19 +9,19 @@ model = torch.load(model_path, map_location=torch.device('cpu'))
9
  model.eval()
10
 
11
  # Define classes if not already defined
12
- classes = ['creatures',
13
- 'fish',
14
- 'jellyfish',
15
- 'penguin',
16
- 'puffin',
17
- 'shark',
18
- 'starfish',
19
- 'stingray'] # List of class labels
20
 
21
  # Define function for processing video
22
  def process_video(input_video):
 
 
 
 
 
 
 
23
  output_path = 'video_output.avi'
24
- cap = cv2.VideoCapture(input_video.name)
25
  fps = cap.get(cv2.CAP_PROP_FPS)
26
  fourcc = cv2.VideoWriter_fourcc(*'XVID')
27
  out = cv2.VideoWriter(output_path, fourcc, fps, (int(cap.get(3)), int(cap.get(4))))
@@ -33,7 +33,7 @@ def process_video(input_video):
33
  if not ret:
34
  break
35
 
36
- img = torch.tensor(frame.transpose(2, 0, 1) / 255.0, dtype=torch.float32).to(device)
37
  img = img.unsqueeze(0)
38
 
39
  with torch.no_grad():
@@ -59,8 +59,8 @@ def process_video(input_video):
59
 
60
  return output_path
61
 
62
- video_input = gr.Video(label="Input Video")
63
- processed_video = gr.Image(label="Processed Video") # No 'outputs' submodule
64
 
65
  interface = gr.Interface(
66
  fn=process_video,
 
9
  model.eval()
10
 
11
  # Define classes if not already defined
12
+ classes = ['creatures', 'fish', 'jellyfish', 'penguin', 'puffin', 'shark', 'starfish', 'stingray'] # List of class labels
 
 
 
 
 
 
 
13
 
14
  # Define function for processing video
15
  def process_video(input_video):
16
+ if isinstance(input_video, str):
17
+ # This is the case when the input is a filename
18
+ input_video_path = input_video
19
+ else:
20
+ # This is the case when the input is a file object
21
+ input_video_path = input_video.name
22
+
23
  output_path = 'video_output.avi'
24
+ cap = cv2.VideoCapture(input_video_path)
25
  fps = cap.get(cv2.CAP_PROP_FPS)
26
  fourcc = cv2.VideoWriter_fourcc(*'XVID')
27
  out = cv2.VideoWriter(output_path, fourcc, fps, (int(cap.get(3)), int(cap.get(4))))
 
33
  if not ret:
34
  break
35
 
36
+ img = torch.tensor(frame.transpose(2, 0, 1) / 255.0, dtype=torch.float32)
37
  img = img.unsqueeze(0)
38
 
39
  with torch.no_grad():
 
59
 
60
  return output_path
61
 
62
+ video_input = gr.inputs.Video(label="Input Video")
63
+ processed_video = gr.outputs.Video(label="Processed Video")
64
 
65
  interface = gr.Interface(
66
  fn=process_video,