JulianPhillips's picture
Update app.py
6326d5f verified
raw
history blame
1.95 kB
from flask import Flask, request, jsonify
import torch
from PIL import Image
from io import BytesIO
import torchvision.transforms as transforms
from transformers import AutoModelForSequenceClassification, AutoTokenizer
# Load the Meta Sapiens pose model from a TorchScript archive.
# map_location='cpu' lets a checkpoint that was serialized on a CUDA device
# load on a CPU-only host. It also keeps the weights on the same device as the
# CPU tensors produced by `transform` in the /pose_estimation handler.
sapiens_model = torch.jit.load('/models/sapiens_pose/model.pt', map_location='cpu')
sapiens_model.eval()  # inference mode for layers such as dropout/batch-norm

# Load the MotionBERT classifier and its tokenizer from the same local directory.
motionbert_model = AutoModelForSequenceClassification.from_pretrained('/models/motionbert')
motionbert_tokenizer = AutoTokenizer.from_pretrained('/models/motionbert')
# The Flask application object that the route handlers below register on.
app = Flask(__name__)

# Target (height, width) every uploaded image is resized to before inference.
_RESIZE_TO = (256, 256)

# Image preprocessing pipeline:
#   1. resize the PIL image to _RESIZE_TO;
#   2. convert it to a PyTorch tensor (ToTensor scales 8-bit pixels to [0, 1]).
# NOTE(review): no mean/std normalization is applied here. Confirm this matches
# the preprocessing the Sapiens checkpoint was exported with.
transform = transforms.Compose(
    [
        transforms.Resize(_RESIZE_TO),
        transforms.ToTensor(),
    ]
)
@app.route('/pose_estimation', methods=['POST'])
def pose_estimation():
    """Run the Sapiens pose model on an uploaded image.

    Expects a multipart/form-data POST carrying the image file in the
    ``image`` field.

    Returns:
        200 with ``{"pose_result": <model output as nested lists>}`` on success.
        400 with ``{"error": ...}`` when the ``image`` field is missing or the
        upload cannot be decoded as an image.
        500 with ``{"error": ...}`` for any failure during preprocessing or
        inference.
    """
    upload = request.files.get('image')
    if upload is None:
        return jsonify({"error": "missing 'image' file field"}), 400

    try:
        # convert('RGB') normalizes grayscale / RGBA / palette uploads to three
        # channels, so every image yields a 3-channel tensor after `transform`.
        # It also forces a full decode, so truncated files fail here (as 400).
        img = Image.open(BytesIO(upload.read())).convert('RGB')
    except (OSError, ValueError, Image.DecompressionBombError) as e:
        # PIL's UnidentifiedImageError is a subclass of OSError.
        return jsonify({"error": f"invalid image: {e}"}), 400

    try:
        img_tensor = transform(img).unsqueeze(0)  # add a leading batch dimension
        with torch.no_grad():
            pose_result = sapiens_model(img_tensor)
        # NOTE(review): assumes the model returns a single tensor; a tuple/dict
        # output would fail at .tolist() — confirm against the exported model.
        return jsonify({"pose_result": pose_result.tolist()})
    except Exception as e:
        return jsonify({"error": str(e)}), 500
@app.route('/sequence_analysis', methods=['POST'])
def sequence_analysis():
    """Run the MotionBERT classifier on keypoint data posted as JSON.

    Expects a JSON object body with a ``keypoints`` field. Its value is passed
    unchanged to ``motionbert_tokenizer``.

    Returns:
        200 with ``{"sequence_analysis": <logits as nested lists>}`` on success.
        400 with ``{"error": ...}`` when the body is not a JSON object
        containing ``keypoints`` (including a wrong Content-Type or malformed
        JSON).
        500 with ``{"error": ...}`` for any failure during tokenization or
        inference.
    """
    # silent=True returns None instead of raising on a non-JSON Content-Type
    # or a malformed body, so those cases can be reported as client errors.
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict) or 'keypoints' not in payload:
        return jsonify({"error": "request body must be a JSON object with a 'keypoints' field"}), 400
    keypoints = payload['keypoints']

    try:
        # NOTE(review): a Hugging Face text tokenizer only accepts a string or
        # a list of strings. Numeric keypoint arrays will be rejected here, and
        # a batch of strings with different lengths needs padding=True. Confirm
        # the expected input format for /models/motionbert.
        inputs = motionbert_tokenizer(keypoints, return_tensors="pt")
        with torch.no_grad():
            sequence_output = motionbert_model(**inputs)
        return jsonify({"sequence_analysis": sequence_output.logits.tolist()})
    except Exception as e:
        return jsonify({"error": str(e)}), 500
if __name__ == '__main__':
    # Start Flask's built-in server on all network interfaces, port 7860.
    listen_host, listen_port = '0.0.0.0', 7860
    app.run(host=listen_host, port=listen_port)