|
import argparse |
|
import logging |
|
import random |
|
import uuid |
|
import numpy as np |
|
from transformers import pipeline |
|
from diffusers import ControlNetModel, DiffusionPipeline, DPMSolverMultistepScheduler, StableDiffusionControlNetPipeline, UniPCMultistepScheduler
|
from diffusers.utils import export_to_video, load_image
|
from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan, SpeechT5ForSpeechToSpeech |
|
from transformers import BlipProcessor, BlipForConditionalGeneration |
|
from transformers import TrOCRProcessor, VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer |
|
from datasets import load_dataset |
|
from PIL import Image |
|
import io |
|
from torchvision import transforms |
|
import torch |
|
import torchaudio |
|
from speechbrain.pretrained import WaveformEnhancement |
|
import joblib |
|
from huggingface_hub import hf_hub_url, cached_download |
|
from transformers import AutoImageProcessor, TimesformerForVideoClassification |
|
from transformers import MaskFormerFeatureExtractor, MaskFormerForInstanceSegmentation, AutoFeatureExtractor |
|
from controlnet_aux import OpenposeDetector, MLSDdetector, HEDdetector, CannyDetector, MidasDetector |
|
from controlnet_aux.open_pose.body import Body |
|
from controlnet_aux.mlsd.models.mbv2_mlsd_large import MobileV2_MLSD_Large |
|
from controlnet_aux.hed import Network |
|
from transformers import DPTForDepthEstimation, DPTFeatureExtractor |
|
import warnings |
|
import time |
|
from espnet2.bin.tts_inference import Text2Speech |
|
import soundfile as sf |
|
from asteroid.models import BaseModel |
|
import traceback |
|
import os |
|
import yaml |
|
|
|
warnings.filterwarnings("ignore") |
|
|
|
parser = argparse.ArgumentParser() |
|
parser.add_argument("--config", type=str, default="config.yaml") |
|
args = parser.parse_args() |
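
# When this module is imported (e.g. by the Gradio app) rather than run
# directly, fall back to the Gradio-specific config.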
|
|
|
if __name__ != "__main__": |
|
args.config = "config.gradio.yaml" |
|
|
|
logger = logging.getLogger(__name__) |
|
logger.setLevel(logging.INFO) |
|
handler = logging.StreamHandler() |
|
handler.setLevel(logging.INFO) |
|
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') |
|
handler.setFormatter(formatter) |
|
logger.addHandler(handler) |
|
|
|
with open(args.config, "r") as f:
    config = yaml.safe_load(f)
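
# "inference_mode: huggingface" delegates all inference to the hosted
# Inference API, so no local pipelines are loaded.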
|
|
|
local_deployment = config["local_deployment"] |
|
if config["inference_mode"] == "huggingface": |
|
local_deployment = "none" |
|
|
|
PROXY = None |
|
if config["proxy"]: |
|
PROXY = { |
|
"https": config["proxy"], |
|
} |
|
|
|
start = time.time() |
|
|
|
|
|
local_models = ""  # optional local path prefix for model weights; empty pulls from the Hugging Face Hub
|
|
|
|
|
def load_pipes(local_deployment): |
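    """Instantiate the model pipelines for the requested deployment tier.

    "minimal" loads only the ControlNet stack, "standard" adds the commonly
    used pipelines, and "full" additionally loads the heavyweight models.
    Returns a dict mapping model id to its components and target device.
    """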
|
other_pipes = {} |
|
standard_pipes = {} |
|
controlnet_sd_pipes = {} |
|
if local_deployment in ["full"]: |
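        # heavyweight pipelines that are only loaded in "full" mode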
|
        other_pipes = {
"damo-vilab/text-to-video-ms-1.7b": { |
|
"model": DiffusionPipeline.from_pretrained(f"{local_models}damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16, variant="fp16"), |
|
"device": "cuda:0" |
|
            },
"JorisCos/DCCRNet_Libri1Mix_enhsingle_16k": { |
|
"model": BaseModel.from_pretrained("JorisCos/DCCRNet_Libri1Mix_enhsingle_16k"), |
|
"device": "cuda:0" |
|
            },
"microsoft/speecht5_vc":{ |
|
"processor": SpeechT5Processor.from_pretrained(f"{local_models}microsoft/speecht5_vc"), |
|
"model": SpeechT5ForSpeechToSpeech.from_pretrained(f"{local_models}microsoft/speecht5_vc"), |
|
"vocoder": SpeechT5HifiGan.from_pretrained(f"{local_models}microsoft/speecht5_hifigan"), |
|
"embeddings_dataset": load_dataset(f"{local_models}Matthijs/cmu-arctic-xvectors", split="validation"), |
|
"device": "cuda:0" |
|
            },
"facebook/maskformer-swin-base-coco": { |
|
"feature_extractor": MaskFormerFeatureExtractor.from_pretrained(f"{local_models}facebook/maskformer-swin-base-coco"), |
|
"model": MaskFormerForInstanceSegmentation.from_pretrained(f"{local_models}facebook/maskformer-swin-base-coco"), |
|
"device": "cuda:0" |
|
}, |
|
"Intel/dpt-hybrid-midas": { |
|
"model": DPTForDepthEstimation.from_pretrained(f"{local_models}Intel/dpt-hybrid-midas", low_cpu_mem_usage=True), |
|
"feature_extractor": DPTFeatureExtractor.from_pretrained(f"{local_models}Intel/dpt-hybrid-midas"), |
|
"device": "cuda:0" |
|
} |
|
} |
|
|
|
if local_deployment in ["full", "standard"]: |
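        # commonly used pipelines shared by the "standard" and "full" tiers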
|
        standard_pipes = {
"espnet/kan-bayashi_ljspeech_vits": { |
|
"model": Text2Speech.from_pretrained("espnet/kan-bayashi_ljspeech_vits"), |
|
"device": "cuda:0" |
|
            },
"runwayml/stable-diffusion-v1-5": { |
|
"model": DiffusionPipeline.from_pretrained(f"{local_models}runwayml/stable-diffusion-v1-5"), |
|
"device": "cuda:0" |
|
            },
"openai/whisper-base": { |
|
"model": pipeline(task="automatic-speech-recognition", model=f"{local_models}openai/whisper-base"), |
|
"device": "cuda:0" |
|
            },
"Intel/dpt-large": { |
|
"model": pipeline(task="depth-estimation", model=f"{local_models}Intel/dpt-large"), |
|
"device": "cuda:0" |
|
            },
"facebook/detr-resnet-50-panoptic": { |
|
"model": pipeline(task="image-segmentation", model=f"{local_models}facebook/detr-resnet-50-panoptic"), |
|
"device": "cuda:0" |
|
}, |
|
"facebook/detr-resnet-101": { |
|
"model": pipeline(task="object-detection", model=f"{local_models}facebook/detr-resnet-101"), |
|
"device": "cuda:0" |
|
            },
"impira/layoutlm-document-qa": { |
|
"model": pipeline(task="document-question-answering", model=f"{local_models}impira/layoutlm-document-qa"), |
|
"device": "cuda:0" |
|
}, |
|
"ydshieh/vit-gpt2-coco-en": { |
|
"model": pipeline(task="image-to-text", model=f"{local_models}ydshieh/vit-gpt2-coco-en"), |
|
"device": "cuda:0" |
|
}, |
|
"dandelin/vilt-b32-finetuned-vqa": { |
|
"model": pipeline(task="visual-question-answering", model=f"{local_models}dandelin/vilt-b32-finetuned-vqa"), |
|
"device": "cuda:0" |
|
} |
|
} |
|
|
|
if local_deployment in ["full", "standard", "minimal"]: |
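        # the ControlNet stack is loaded in every tier, including "minimal"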
|
|
|
controlnet = ControlNetModel.from_pretrained(f"{local_models}lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16) |
|
controlnetpipe = StableDiffusionControlNetPipeline.from_pretrained( |
|
f"{local_models}runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16 |
|
) |
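
        # The detectors below turn input images into control maps. The single
        # StableDiffusionControlNetPipeline above is shared by all of the
        # sd-controlnet-* entries; their "control" nets are swapped into it at
        # inference time. The HED detector is reused for scribble control.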
|
|
|
|
|
hed_network = HEDdetector.from_pretrained('lllyasviel/ControlNet') |
|
|
|
controlnet_sd_pipes = { |
|
"openpose-control": { |
|
"model": OpenposeDetector.from_pretrained('lllyasviel/ControlNet') |
|
}, |
|
"mlsd-control": { |
|
"model": MLSDdetector.from_pretrained('lllyasviel/ControlNet') |
|
}, |
|
"hed-control": { |
|
"model": hed_network |
|
}, |
|
"scribble-control": { |
|
"model": hed_network |
|
}, |
|
"midas-control": { |
|
"model": MidasDetector.from_pretrained('lllyasviel/ControlNet') |
|
}, |
|
"canny-control": { |
|
"model": CannyDetector() |
|
}, |
|
"lllyasviel/sd-controlnet-canny":{ |
|
"control": controlnet, |
|
"model": controlnetpipe, |
|
"device": "cuda:0" |
|
}, |
|
"lllyasviel/sd-controlnet-depth":{ |
|
"control": ControlNetModel.from_pretrained(f"{local_models}lllyasviel/sd-controlnet-depth", torch_dtype=torch.float16), |
|
"model": controlnetpipe, |
|
"device": "cuda:0" |
|
}, |
|
"lllyasviel/sd-controlnet-hed":{ |
|
"control": ControlNetModel.from_pretrained(f"{local_models}lllyasviel/sd-controlnet-hed", torch_dtype=torch.float16), |
|
"model": controlnetpipe, |
|
"device": "cuda:0" |
|
}, |
|
"lllyasviel/sd-controlnet-mlsd":{ |
|
"control": ControlNetModel.from_pretrained(f"{local_models}lllyasviel/sd-controlnet-mlsd", torch_dtype=torch.float16), |
|
"model": controlnetpipe, |
|
"device": "cuda:0" |
|
}, |
|
"lllyasviel/sd-controlnet-openpose":{ |
|
"control": ControlNetModel.from_pretrained(f"{local_models}lllyasviel/sd-controlnet-openpose", torch_dtype=torch.float16), |
|
"model": controlnetpipe, |
|
"device": "cuda:0" |
|
}, |
|
"lllyasviel/sd-controlnet-scribble":{ |
|
"control": ControlNetModel.from_pretrained(f"{local_models}lllyasviel/sd-controlnet-scribble", torch_dtype=torch.float16), |
|
"model": controlnetpipe, |
|
"device": "cuda:0" |
|
}, |
|
"lllyasviel/sd-controlnet-seg":{ |
|
"control": ControlNetModel.from_pretrained(f"{local_models}lllyasviel/sd-controlnet-seg", torch_dtype=torch.float16), |
|
"model": controlnetpipe, |
|
"device": "cuda:0" |
|
} |
|
} |
|
pipes = {**standard_pipes, **other_pipes, **controlnet_sd_pipes} |
|
return pipes |
|
|
|
pipes = load_pipes(local_deployment) |
|
|
|
end = time.time() |
|
during = end - start |
|
|
|
print(f"[ ready ] {during}s") |
|
|
|
def running(): |
|
return {"running": True} |
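
# Health check: report whether a given model id was loaded at startup and is
# not explicitly disabled.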
|
|
|
def status(model_id): |
|
disabled_models = ["microsoft/trocr-base-printed", "microsoft/trocr-base-handwritten"] |
|
if model_id in pipes.keys() and model_id not in disabled_models: |
|
print(f"[ check {model_id} ] success") |
|
return {"loaded": True} |
|
else: |
|
print(f"[ check {model_id} ] failed") |
|
return {"loaded": False} |
|
|
|
def models(model_id, data): |
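    """Run a single inference request against a loaded pipeline.

    A simple "using" flag serializes access per model; weights are moved to
    the GPU for the duration of the call and back to the CPU afterwards.
    """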
|
    if model_id not in pipes:
        return {"error": {"message": "model not found"}}
    while pipes[model_id].get("using"):
|
print(f"[ inference {model_id} ] waiting") |
|
time.sleep(0.1) |
|
pipes[model_id]["using"] = True |
|
print(f"[ inference {model_id} ] start") |
|
|
|
start = time.time() |
|
|
|
pipe = pipes[model_id]["model"] |
|
|
|
if "device" in pipes[model_id]: |
|
try: |
|
pipe.to(pipes[model_id]["device"]) |
|
        except Exception:
|
pipe.device = torch.device(pipes[model_id]["device"]) |
|
pipe.model.to(pipes[model_id]["device"]) |
|
|
|
result = None |
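
    # dispatch on the model id; each branch writes its artifact under public/
    # and returns a client-facing path or the raw pipeline output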
|
try: |
|
|
|
if model_id == "damo-vilab/text-to-video-ms-1.7b": |
|
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) |
|
|
|
prompt = data["text"] |
|
video_frames = pipe(prompt, num_inference_steps=50, num_frames=40).frames |
|
file_name = str(uuid.uuid4())[:4] |
|
video_path = export_to_video(video_frames, f"public/videos/{file_name}.mp4") |
|
|
|
new_file_name = str(uuid.uuid4())[:4] |
|
os.system(f"ffmpeg -i {video_path} -vcodec libx264 public/videos/{new_file_name}.mp4") |
|
|
|
if os.path.exists(f"public/videos/{new_file_name}.mp4"): |
|
result = {"path": f"/videos/{new_file_name}.mp4"} |
|
else: |
|
result = {"path": f"/videos/{file_name}.mp4"} |
|
|
|
|
|
if model_id.startswith("lllyasviel/sd-controlnet-"): |
|
pipe.controlnet.to('cpu') |
|
pipe.controlnet = pipes[model_id]["control"].to(pipes[model_id]["device"]) |
|
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config) |
|
control_image = load_image(data["img_url"]) |
|
|
|
            out_image: Image.Image = pipe(data["text"], num_inference_steps=20, image=control_image).images[0]
|
file_name = str(uuid.uuid4())[:4] |
|
out_image.save(f"public/images/{file_name}.png") |
|
result = {"path": f"/images/{file_name}.png"} |
|
|
|
if model_id.endswith("-control"): |
|
image = load_image(data["img_url"]) |
|
if "scribble" in model_id: |
|
                control = pipe(image, scribble=True)
|
elif "canny" in model_id: |
|
control = pipe(image, low_threshold=100, high_threshold=200) |
|
else: |
|
control = pipe(image) |
|
file_name = str(uuid.uuid4())[:4] |
|
control.save(f"public/images/{file_name}.png") |
|
result = {"path": f"/images/{file_name}.png"} |
|
|
|
|
|
if model_id == "lambdalabs/sd-image-variations-diffusers": |
|
im = load_image(data["img_url"]) |
|
file_name = str(uuid.uuid4())[:4] |
|
            # cache the input image locally; the original wrote the raw request
            # dict to disk, which raises a TypeError
            im.save(f"public/images/{file_name}.png")
|
tform = transforms.Compose([ |
|
transforms.ToTensor(), |
|
transforms.Resize( |
|
(224, 224), |
|
interpolation=transforms.InterpolationMode.BICUBIC, |
|
antialias=False, |
|
), |
|
transforms.Normalize( |
|
[0.48145466, 0.4578275, 0.40821073], |
|
[0.26862954, 0.26130258, 0.27577711]), |
|
]) |
|
inp = tform(im).to(pipes[model_id]["device"]).unsqueeze(0) |
|
out = pipe(inp, guidance_scale=3) |
|
out["images"][0].save(f"public/images/{file_name}.jpg") |
|
result = {"path": f"/images/{file_name}.jpg"} |
|
|
|
|
|
if model_id == "Salesforce/blip-image-captioning-large": |
|
raw_image = load_image(data["img_url"]).convert('RGB') |
|
text = data["text"] |
|
inputs = pipes[model_id]["processor"](raw_image, return_tensors="pt").to(pipes[model_id]["device"]) |
|
out = pipe.generate(**inputs) |
|
caption = pipes[model_id]["processor"].decode(out[0], skip_special_tokens=True) |
|
result = {"generated text": caption} |
|
if model_id == "ydshieh/vit-gpt2-coco-en": |
|
img_url = data["img_url"] |
|
generated_text = pipe(img_url)[0]['generated_text'] |
|
result = {"generated text": generated_text} |
|
if model_id == "nlpconnect/vit-gpt2-image-captioning": |
|
image = load_image(data["img_url"]).convert("RGB") |
|
pixel_values = pipes[model_id]["feature_extractor"](images=image, return_tensors="pt").pixel_values |
|
pixel_values = pixel_values.to(pipes[model_id]["device"]) |
|
generated_ids = pipe.generate(pixel_values, **{"max_length": 200, "num_beams": 1}) |
|
generated_text = pipes[model_id]["tokenizer"].batch_decode(generated_ids, skip_special_tokens=True)[0] |
|
result = {"generated text": generated_text} |
|
|
|
if model_id == "microsoft/trocr-base-printed" or model_id == "microsoft/trocr-base-handwritten": |
|
image = load_image(data["img_url"]).convert("RGB") |
|
pixel_values = pipes[model_id]["processor"](image, return_tensors="pt").pixel_values |
|
pixel_values = pixel_values.to(pipes[model_id]["device"]) |
|
generated_ids = pipe.generate(pixel_values) |
|
generated_text = pipes[model_id]["processor"].batch_decode(generated_ids, skip_special_tokens=True)[0] |
|
result = {"generated text": generated_text} |
|
|
|
|
|
if model_id == "runwayml/stable-diffusion-v1-5": |
|
file_name = str(uuid.uuid4())[:4] |
|
text = data["text"] |
|
out = pipe(prompt=text) |
|
out["images"][0].save(f"public/images/{file_name}.jpg") |
|
result = {"path": f"/images/{file_name}.jpg"} |
|
|
|
|
|
if model_id == "google/owlvit-base-patch32" or model_id == "facebook/detr-resnet-101": |
|
img_url = data["img_url"] |
|
open_types = ["cat", "couch", "person", "car", "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch", "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush", "traffic light", "fire hydrant", "stop sign", "parking meter", "bench", "bird"] |
|
            if model_id == "google/owlvit-base-patch32":
                # zero-shot detection needs the candidate label set
                result = pipe(img_url, candidate_labels=open_types)
            else:
                # DETR is a closed-set detector and takes no candidate labels
                result = pipe(img_url)
|
|
|
|
|
if model_id == "dandelin/vilt-b32-finetuned-vqa": |
|
question = data["text"] |
|
img_url = data["img_url"] |
|
result = pipe(question=question, image=img_url) |
|
|
|
|
|
if model_id == "impira/layoutlm-document-qa": |
|
question = data["text"] |
|
img_url = data["img_url"] |
|
result = pipe(img_url, question) |
|
|
|
|
|
if model_id == "Intel/dpt-large": |
|
output = pipe(data["img_url"]) |
|
image = output['depth'] |
|
name = str(uuid.uuid4())[:4] |
|
image.save(f"public/images/{name}.jpg") |
|
result = {"path": f"/images/{name}.jpg"} |
|
|
|
        # the original condition used `and`, which made this branch unreachable;
        # only the midas checkpoint goes through this feature-extractor path
        if model_id == "Intel/dpt-hybrid-midas":
            image = load_image(data["img_url"])
            inputs = pipes[model_id]["feature_extractor"](images=image, return_tensors="pt").to(pipes[model_id]["device"])
|
with torch.no_grad(): |
|
outputs = pipe(**inputs) |
|
predicted_depth = outputs.predicted_depth |
|
prediction = torch.nn.functional.interpolate( |
|
predicted_depth.unsqueeze(1), |
|
size=image.size[::-1], |
|
mode="bicubic", |
|
align_corners=False, |
|
) |
|
output = prediction.squeeze().cpu().numpy() |
|
formatted = (output * 255 / np.max(output)).astype("uint8") |
|
image = Image.fromarray(formatted) |
|
name = str(uuid.uuid4())[:4] |
|
image.save(f"public/images/{name}.jpg") |
|
result = {"path": f"/images/{name}.jpg"} |
|
|
|
|
|
if model_id == "espnet/kan-bayashi_ljspeech_vits": |
|
text = data["text"] |
|
wav = pipe(text)["wav"] |
|
name = str(uuid.uuid4())[:4] |
|
sf.write(f"public/audios/{name}.wav", wav.cpu().numpy(), pipe.fs, "PCM_16") |
|
result = {"path": f"/audios/{name}.wav"} |
|
|
|
if model_id == "microsoft/speecht5_tts": |
|
text = data["text"] |
|
inputs = pipes[model_id]["processor"](text=text, return_tensors="pt") |
|
embeddings_dataset = pipes[model_id]["embeddings_dataset"] |
|
speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0).to(pipes[model_id]["device"]) |
|
pipes[model_id]["vocoder"].to(pipes[model_id]["device"]) |
|
speech = pipe.generate_speech(inputs["input_ids"].to(pipes[model_id]["device"]), speaker_embeddings, vocoder=pipes[model_id]["vocoder"]) |
|
name = str(uuid.uuid4())[:4] |
|
sf.write(f"public/audios/{name}.wav", speech.cpu().numpy(), samplerate=16000) |
|
result = {"path": f"/audios/{name}.wav"} |
|
|
|
|
|
if model_id == "openai/whisper-base" or model_id == "microsoft/speecht5_asr": |
|
audio_url = data["audio_url"] |
|
            result = {"text": pipe(audio_url)["text"]}
|
|
|
|
|
if model_id == "JorisCos/DCCRNet_Libri1Mix_enhsingle_16k": |
|
audio_url = data["audio_url"] |
|
wav, sr = torchaudio.load(audio_url) |
|
with torch.no_grad(): |
|
result_wav = pipe(wav.to(pipes[model_id]["device"])) |
|
name = str(uuid.uuid4())[:4] |
|
sf.write(f"public/audios/{name}.wav", result_wav.cpu().squeeze().numpy(), sr) |
|
result = {"path": f"/audios/{name}.wav"} |
|
|
|
if model_id == "microsoft/speecht5_vc": |
|
audio_url = data["audio_url"] |
|
wav, sr = torchaudio.load(audio_url) |
|
inputs = pipes[model_id]["processor"](audio=wav, sampling_rate=sr, return_tensors="pt") |
|
embeddings_dataset = pipes[model_id]["embeddings_dataset"] |
|
            speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0).to(pipes[model_id]["device"])
            pipes[model_id]["vocoder"].to(pipes[model_id]["device"])
            # the speech-to-speech processor emits input_values, not input_ids
            speech = pipe.generate_speech(inputs["input_values"].to(pipes[model_id]["device"]), speaker_embeddings, vocoder=pipes[model_id]["vocoder"])
|
name = str(uuid.uuid4())[:4] |
|
sf.write(f"public/audios/{name}.wav", speech.cpu().numpy(), samplerate=16000) |
|
result = {"path": f"/audios/{name}.wav"} |
|
|
|
|
|
if model_id == "facebook/detr-resnet-50-panoptic": |
|
|
segments = pipe(data["img_url"]) |
|
image = load_image(data["img_url"]) |
|
|
|
colors = [] |
|
for i in range(len(segments)): |
|
colors.append((random.randint(100, 255), random.randint(100, 255), random.randint(100, 255), 50)) |
|
|
|
            # pair each segment with its own colour; the original reused a stale
            # loop index and painted every mask with the same colour
            for i, segment in enumerate(segments):
                mask = segment["mask"].convert('L')
                layer = Image.new('RGBA', mask.size, colors[i])
                image.paste(layer, (0, 0), mask)
|
name = str(uuid.uuid4())[:4] |
|
image.save(f"public/images/{name}.jpg") |
|
result = {"path": f"/images/{name}.jpg"} |
|
|
|
if model_id == "facebook/maskformer-swin-base-coco" or model_id == "facebook/maskformer-swin-large-ade": |
|
image = load_image(data["img_url"]) |
|
inputs = pipes[model_id]["feature_extractor"](images=image, return_tensors="pt").to(pipes[model_id]["device"]) |
|
outputs = pipe(**inputs) |
|
result = pipes[model_id]["feature_extractor"].post_process_panoptic_segmentation(outputs, target_sizes=[image.size[::-1]])[0] |
|
predicted_panoptic_map = result["segmentation"].cpu().numpy() |
|
predicted_panoptic_map = Image.fromarray(predicted_panoptic_map.astype(np.uint8)) |
|
name = str(uuid.uuid4())[:4] |
|
predicted_panoptic_map.save(f"public/images/{name}.jpg") |
|
result = {"path": f"/images/{name}.jpg"} |
|
|
|
except Exception as e: |
|
print(e) |
|
traceback.print_exc() |
|
result = {"error": {"message": "Error when running the model inference."}} |
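
    # always move the pipeline back to the CPU and release the per-model
    # lock, whether or not inference succeeded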
|
|
|
if "device" in pipes[model_id]: |
|
try: |
|
pipe.to("cpu") |
|
torch.cuda.empty_cache() |
|
        except Exception:
|
pipe.device = torch.device("cpu") |
|
pipe.model.to("cpu") |
|
torch.cuda.empty_cache() |
|
|
|
pipes[model_id]["using"] = False |
|
|
|
if result is None: |
|
result = {"error": {"message": "model not found"}} |
|
|
|
end = time.time() |
|
during = end - start |
|
print(f"[ complete {model_id} ] {during}s") |
|
print(f"[ result {model_id} ] {result}") |
|
|
|
return result |
|
|