Spaces:
Runtime error
Runtime error
| from flask import Flask, request | |
| from flask_cors import CORS, cross_origin | |
| import torch | |
| import model | |
| import numpy as np | |
| from sentence_transformers import SentenceTransformer | |
| sentence_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2') | |
| embedding_dim = 384 | |
| hidden_dim = 512 | |
| num_layers = 1 | |
| output_dim = 180 | |
| num_epochs = 100 | |
| learning_rate = 0.001 | |
| lstm_model = model.LSTM(embedding_dim, hidden_dim, num_layers, output_dim) | |
| lstm_model.load_state_dict(torch.load('lstm.pt')) | |
| app = Flask(__name__) | |
| cors = CORS(app) | |
| app.config['CORS_HEADERS'] = 'Content-Type' | |
| def GeneratePosesJSON(input): | |
| with torch.no_grad(): | |
| processed_text = torch.tensor(sentence_model.encode(input), dtype=torch.float) | |
| output_poses = lstm_model(processed_text.unsqueeze(0)) | |
| people = output_poses.cpu().detach().numpy().reshape(5, 18, 2).tolist() | |
| newPeople = [] | |
| for person in people: | |
| newPerson = [] | |
| for keypoints in person: | |
| newPerson.append([keypoints[0], keypoints[1], 1]) | |
| newPeople.append(newPerson) | |
| data = np.array(newPeople).reshape(5, 54).tolist() | |
| formatted_data = [] | |
| for person in data: | |
| formatted_data.append({ "pose_keypoints_2d": person }) | |
| return { 'people': formatted_data, 'animals': [], 'canvas_width': 900, 'canvas_height': 300 } | |
| def hello(): | |
| return "Hello, World!" | |
| def generatePose(): | |
| text = request.args.get('text') | |
| data = GeneratePosesJSON(text) | |
| return data | |
| if __name__ == '__main__': | |
| app.run() |