Spaces:
Runtime error
Runtime error
Update models/models/detectors/pam.py
Browse files
models/models/detectors/pam.py
CHANGED
@@ -135,9 +135,9 @@ class PoseAnythingModel(BasePose):
|
|
135 |
"""Defines the computation performed at every call."""
|
136 |
str_dict = json.loads(str)
|
137 |
|
138 |
-
str_dict["img_q"] = torch.tensor(str_dict["img_q"], dtype=torch.float32)
|
139 |
-
str_dict["target_weight_s"] = torch.tensor(str_dict["target_weight_s"], dtype=torch.float32)
|
140 |
-
str_dict["target_s"] = torch.tensor(str_dict["target_s"], dtype=torch.float32)
|
141 |
|
142 |
str_dict['img_metas'][0]['sample_joints_3d'][0] = torch.tensor(str_dict['img_metas'][0]['sample_joints_3d'][0])
|
143 |
str_dict['img_metas'][0]['query_joints_3d'] = torch.tensor(str_dict['img_metas'][0]['query_joints_3d'])
|
|
|
135 |
"""Defines the computation performed at every call."""
|
136 |
str_dict = json.loads(str)
|
137 |
|
138 |
+
str_dict["img_q"] = torch.tensor(str_dict["img_q"], dtype=torch.float32).cuda()
|
139 |
+
str_dict["target_weight_s"] = torch.tensor(str_dict["target_weight_s"], dtype=torch.float32).cuda()
|
140 |
+
str_dict["target_s"] = torch.tensor(str_dict["target_s"], dtype=torch.float32).cuda()
|
141 |
|
142 |
str_dict['img_metas'][0]['sample_joints_3d'][0] = torch.tensor(str_dict['img_metas'][0]['sample_joints_3d'][0])
|
143 |
str_dict['img_metas'][0]['query_joints_3d'] = torch.tensor(str_dict['img_metas'][0]['query_joints_3d'])
|