# NOTE(review): the lines "Spaces / Sleeping / Sleeping" above this file's
# imports were scrape residue from a hosting page (HuggingFace Spaces status
# banner), not Python code; kept here as a comment so the module parses.
from pytorch_model_factory import TorchModelFactory
from detectron2.data import MetadataCatalog
import torch
import torchvision.transforms as transforms
class PytorchPredictor:
    """Thin wrapper around plain-PyTorch models selected by ``cfg.TASK_TYPE``.

    Supported task types (literal config values, including the historical
    misspelling): ``"classfication"``, ``"feature"``, ``"semantic"``.
    """

    def __init__(self, cfg):
        """
        Args:
            cfg: detectron2-style config node; TASK_TYPE selects the model,
                DATASETS.TEST[0] supplies metadata for the semantic task.
        """
        self.cfg = cfg.clone()  # cfg can be modified by model
        self.task_type = cfg.TASK_TYPE
        # Default ImageNet-style preprocessing geometry (resize then crop).
        self.resize = 256
        self.crop = 224
        # BUGFIX: ensure self.model always exists. The original left it
        # undefined for an unrecognized TASK_TYPE, so __call__'s
        # `self.model is None` guard raised AttributeError instead of
        # returning None.
        self.model = None
        # NOTE: "classfication" (sic) is the literal value used by the
        # config and in result keys; do not fix the spelling here alone.
        if self.task_type == "classfication":
            self.model = TorchModelFactory.create_feature_extract_model("resnet")
        elif self.task_type == "feature":
            self.model = TorchModelFactory.create_feature_extract_model("resnet")
        elif self.task_type == "semantic":
            self.model = TorchModelFactory.create_semantic_model("deeplabv3")
            self.metadata = MetadataCatalog.get(cfg.DATASETS.TEST[0])
            if not hasattr(self.metadata, "stuff_classes"):
                # NOTE(review): this mutates the shared MetadataCatalog
                # entry in place — confirm that is intended.
                self.metadata.stuff_classes = self.metadata.thing_classes
            if len(self.metadata.stuff_classes) == 20:
                # A 20-entry list looks like Pascal-VOC classes without
                # background; prepend it so ids line up with model output.
                self.metadata.stuff_classes.insert(0, "background")
            print(self.metadata)
            # Segmentation runs on the full-resolution image: no resize/crop.
            self.resize = None
            self.crop = None

    def __call__(self, image):
        """
        Run the model on a single image.

        Args:
            image (PIL image): the input image.
                NOTE(review): the original docstring said "(H, W, C) in BGR
                order", but the ImageNet mean/std used in image_processor
                assume an RGB PIL image — confirm callers pass RGB.

        Returns:
            dict or None: task-specific result dict (see _post_processor),
            or None when no model was created for this TASK_TYPE.
        """
        if self.model is None:
            return None
        image = self.image_processor(image)
        input_batch = image.unsqueeze(0)  # add the batch dimension
        if torch.cuda.is_available():
            input_batch = input_batch.cuda()
        # Inference only — no autograd bookkeeping. (The original repeated
        # the `self.model is None` check here; it was redundant and removed.)
        with torch.no_grad():
            predictions = self.model(input_batch)
        return self._post_processor(predictions)

    def image_processor(self, input_image):
        """Convert a PIL image to a normalized CHW float tensor.

        Resize and center-crop are skipped when self.resize / self.crop are
        None (the semantic task); then ToTensor + standard ImageNet
        mean/std normalization are applied.
        """
        steps = []
        if self.resize is not None:
            steps.append(transforms.Resize(self.resize))
        if self.crop is not None:
            steps.append(transforms.CenterCrop(self.crop))
        steps.append(transforms.ToTensor())
        steps.append(transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                          std=[0.229, 0.224, 0.225]))
        return transforms.Compose(steps)(input_image)

    def _post_processor(self, output):
        """Convert raw model output into a task-specific result dict.

        Returns:
            - "classfication": {"classfication": [ {feature, score,
              pred_class} per batch row ]} (top-1 only)
            - "feature": {"features": <cpu tensor>}
            - "semantic": {"sem_segs": <(N, H, W) per-pixel class ids>}
            - None for any other task type.
        """
        result = None
        if self.task_type == "classfication":
            output = output.cpu()
            result = {"classfication": []}
            probabilities = torch.nn.functional.softmax(output, dim=1)
            for i, probability in enumerate(probabilities):
                # Only the top-1 class is kept (k=1 despite the old
                # top5_* variable names).
                top_prob, top_catid = torch.topk(probability, 1)
                result["classfication"].append(
                    {"feature": output[i], "score": top_prob, "pred_class": top_catid}
                )
        elif self.task_type == "feature":
            result = {"features": output.cpu()}
        elif self.task_type == "semantic":
            # torchvision segmentation models return {"out": (N, C, H, W)};
            # argmax over the channel dim gives per-pixel class ids.
            result = {"sem_segs": output["out"].argmax(1).cpu()}
        return result

    def release(self):
        """Free the model and (best-effort) reclaim GPU memory."""
        import gc
        # Drop the reference to the model object.
        del self.model
        # BUGFIX: the original compared cfg.MODEL.DEVICE to "gpu", a value a
        # detectron2 config never uses (its CUDA device string is "cuda"),
        # so the cache was never emptied. Gate on actual CUDA availability.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        # Force a collection pass so freed tensors are released promptly.
        gc.collect()