MBARI_Monterey_Bay_Benthic / tator_inference.py
Jordan Pierce
initial commit
6b38cd2
raw
history blame
No virus
3.1 kB
import os
import logging
from tempfile import TemporaryFile
import cv2
import numpy as np
from PIL import Image
import tator
import inference
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
# Read environment variables that are provided from TATOR
host = os.getenv('HOST')
token = os.getenv('TOKEN')
project_id = int(os.getenv('PROJECT_ID'))
media_ids = [int(id_) for id_ in os.getenv('MEDIA_IDS').split(',')]
frames_per_inference = int(os.getenv('FRAMES_PER_INFERENCE', 30))
# Set up the TATOR API.
api = tator.get_api(host, token)
# Iterate through each video.
for media_id in media_ids:
# Download video.
media = api.get_media(media_id)
logger.info(f"Downloading {media.name}...")
out_path = f"/tmp/{media.name}"
for progress in tator.util.download_media(api, media, out_path):
logger.info(f"Download progress: {progress}%")
# Do inference on each video.
logger.info(f"Doing inference on {media.name}...")
localizations = []
vid = cv2.VideoCapture(out_path)
frame_number = 0
# Read *every* frame from the video, break when at the end.
while True:
ret, frame = vid.read()
if not ret:
break
# Create a temporary file, access the image data, save data to file.
framefile = TemporaryFile(suffix='.jpg')
im = Image.fromarray(frame)
im.save(framefile)
# For every N frames, make a prediction; append prediction results
# to a list, increase the frame count.
if frame_number % frames_per_inference == 0:
spec = {}
# Predictions contains all information inside pandas dataframe
predictions = inference.run_inference(framefile)
for i, r in predictions.pandas().xyxy[0].iterrows:
spec['media_id'] = media_id
spec['type'] = None # Unsure, docs not specific
spec['frame'] = frame_number
x, y, x2, y2 = r['xmin'], r['ymin'], r['xmax'], r['ymax']
w, h = x2 - x, y2 - y
spec['x'] = x
spec['y'] = y
spec['width'] = w
spec['height'] = h
spec['class_category'] = r['name']
spec['confidence'] = r['confidence']
localizations.append(spec)
frame_number += 1
# End interaction with video properly.
vid.release()
logger.info(f"Uploading object detections on {media.name}...")
# Create the localizations in the video.
num_created = 0
for response in tator.util.chunked_create(api.create_localization_list,
project_id,
localization_spec=localizations):
num_created += len(response.id)
# Output pretty logging information.
logger.info(f"Successfully created {num_created} localizations on "
f"{media.name}!")
logger.info("-------------------------------------------------")
logger.info(f"Completed inference on {len(media_ids)} files.")