# asl-detection-yolov5 / inference.py
# Header recovered from a Hugging Face web view (author: niki-stha,
# commit d1f4a84, "Create inference.py", 1.31 kB). Kept as a comment so
# the file remains valid Python.
import io

import torch
import torchvision.transforms as T
from PIL import Image
# Load the YOLOv5 model once at import time (module-level side effect:
# this downloads/loads weights and may hit the network).
# NOTE(review): torch.hub.load expects its first argument to be a GitHub
# 'owner/repo[:branch]' containing a hubconf.py (or a local dir with
# source='local'). Confirm 'niki-stha/asl-detection-yolov5' actually
# exposes a 'yolov5s' entrypoint — otherwise this raises at import.
model = torch.hub.load('niki-stha/asl-detection-yolov5', 'yolov5s')
# Set the device (GPU if available, otherwise CPU)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Move weights to the chosen device and put the model in eval mode
# (disables dropout / freezes batch-norm statistics for inference).
model.to(device).eval()
# Define the image transformation: resize to the model's 416x416 input
# and convert to a [0, 1] CHW float tensor. Only used by run_inference.
transform = T.Compose([
T.Resize((416, 416)),
T.ToTensor(),
])
# Inference function
def run_inference(image):
# Preprocess the image
image = transform(image).unsqueeze(0).to(device)
# Perform inference
results = model(image)
# Post-process the results
# (You can customize this part based on your specific requirements)
predictions = results.pandas().xyxy[0]
return predictions
# Example API endpoint
def inference_api(request):
# Get the image from the request (you may need to adapt this based on your API framework)
image_data = request.files['image'].read()
image = Image.open(io.BytesIO(image_data))
# Run inference
predictions = run_inference(image)
# Convert predictions to JSON or any other desired format
# (You may need to adapt this based on your API framework)
response = {
'predictions': predictions.to_dict(orient='records')
}
return jsonify(response)