from typing import Any
from transformers import AutoImageProcessor, Swin2SRForImageSuperResolution, Swin2SRModel
import torch
import base64
import logging
import numpy as np
import gc
from PIL import Image
from io import BytesIO
import subprocess

logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

# check for GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
gpu_count = torch.cuda.device_count()


class EndpointHandler:
    def __init__(self, path=""):
        # load the image processor
        self.processor = AutoImageProcessor.from_pretrained("caidas/swin2SR-classical-sr-x2-64")
        if gpu_count > 1:
            # tell accelerate which modules must not be split across devices
            Swin2SRModel._no_split_modules = ["Swin2SREmbeddings", "Swin2SRStage"]
            Swin2SRForImageSuperResolution._no_split_modules = ["Swin2SREmbeddings", "Swin2SRStage"]
            # first load with an automatic device map to get a starting layout
            model = Swin2SRForImageSuperResolution.from_pretrained(
                "caidas/swin2SR-classical-sr-x2-64", device_map="auto"
            )
            logger.info(model.hf_device_map)
            # keep conv_after_body and upsample on the same device as the embeddings
            model.hf_device_map["swin2sr.conv_after_body"] = model.hf_device_map["swin2sr.embeddings"]
            model.hf_device_map["upsample"] = model.hf_device_map["swin2sr.embeddings"]
            # reload with the adjusted device map
            self.model = Swin2SRForImageSuperResolution.from_pretrained(
                "caidas/swin2SR-classical-sr-x2-64", device_map=model.hf_device_map
            )
            print(subprocess.run(["nvidia-smi"]))
        else:
            self.model = Swin2SRForImageSuperResolution.from_pretrained("caidas/swin2SR-classical-sr-x2-64")
            # move model to device (only safe when no device map is used)
            self.model.to(device)

    def __call__(self, data: Any):
        """
        Args:
            data (:obj:`dict`): payload with binary image data under "inputs"
        Return:
            A :obj:`string`: base64-encoded image string
        """
        image = data["inputs"]
        if gpu_count > 1:
            # with a device map, accelerate hooks dispatch inputs to the right devices
            inputs = self.processor(image, return_tensors="pt")
        else:
            inputs = self.processor(image, return_tensors="pt").to(device)
        try:
            with torch.no_grad():
                outputs = self.model(**inputs)
            print(subprocess.run(["nvidia-smi"]))

            # convert the reconstruction tensor (C, H, W) to a uint8 HWC array
            output = outputs.reconstruction.data.squeeze().float().cpu().clamp_(0, 1).numpy()
            output = np.moveaxis(output, source=0, destination=-1)
            output = (output * 255.0).round().astype(np.uint8)
            img = Image.fromarray(output)

            # encode the upscaled image as a base64 JPEG string
            buffered = BytesIO()
            img.save(buffered, format="JPEG")
            img_str = base64.b64encode(buffered.getvalue())
            return img_str.decode()
        except Exception as e:
            logger.error(str(e))
            # free GPU memory before returning the error
            del inputs
            gc.collect()
            torch.cuda.empty_cache()
            return {"error": str(e)}
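

# --- Local smoke test: a minimal sketch, not part of the endpoint contract ---
# The block below shows one way to exercise the handler outside of Inference
# Endpoints. The file name "sample.jpg" is an assumed placeholder; the handler
# only requires an image object under the "inputs" key.
if __name__ == "__main__":
    handler = EndpointHandler()
    sample = Image.open("sample.jpg").convert("RGB")  # hypothetical local test image
    result = handler({"inputs": sample})
    if isinstance(result, dict):
        print("inference failed:", result["error"])
    else:
        # decode the base64 string back into an image to verify the round trip
        upscaled = Image.open(BytesIO(base64.b64decode(result)))
        print("upscaled size:", upscaled.size)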