from typing import Any

import base64
import io
import os

import dotenv
import torch
from PIL import Image
from diffusers import DiffusionPipeline

dotenv.load_dotenv()


def convert_b64_to_image(from_str: str) -> Image.Image:
    """Decode a base64-encoded PNG string into a PIL image."""
    print(">>> call convert_b64_to_image", flush=True)
    try:
        data: bytes = base64.b64decode(from_str)
        with io.BytesIO(data) as bio:
            imgfile = Image.open(bio, formats=["PNG"])
            # Force a full read now, while the underlying buffer is still open.
            imgfile.load()
            return imgfile
    except Exception as e:
        print(e, flush=True)
        raise


def convert_image_to_b64(from_img: Image.Image) -> str:
    """Encode a PIL image as a base64 PNG string."""
    print(">>> call convert_image_to_b64", flush=True)
    try:
        with io.BytesIO() as buffer:
            from_img.save(buffer, format="PNG")
            byte_data: bytes = buffer.getvalue()
            return base64.b64encode(byte_data).decode("utf-8")
    except Exception as e:
        print(e, flush=True)
        raise


class HFMultiViewGen:
    """Wraps a Zero123++-style diffusers pipeline that renders multiple views of an object image."""

    def __init__(self,
                 hf_token: str,
                 mv_model: str = "maple-shaft/zero123plus-v1.2",
                 mv_custom_pipeline: str = "sudo-ai/zero123plus-pipeline",
                 gen_custom_pipeline: str = "",
                 repo_dir: str = "/repository",
                 debug: bool = False):
        self.debug = debug
        self.hf_token = hf_token
        self.mv_model = mv_model
        self.mv_custom_pipeline = mv_custom_pipeline
        self.gen_custom_pipeline = gen_custom_pipeline  # accepted for parity; not used yet
        self.repo_dir = repo_dir

        print(f"torch.cuda.is_available() = {torch.cuda.is_available()}", flush=True)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
            print("GPU SYNC OK", flush=True)

        self.pipe = DiffusionPipeline.from_pretrained(
            self.mv_model,
            cache_dir=self.repo_dir,
            token=self.hf_token,
            custom_pipeline=self.mv_custom_pipeline,
            torch_dtype=torch.float16  # from_pretrained expects torch_dtype, not dtype
        ).to("cuda")

    def generate_multiview(self, initial: Image.Image) -> dict[str, Image.Image]:
        print(">>> generate_multiview", flush=True)

        # The pipeline conditions on an RGB image; drop any alpha channel first.
        img = initial.convert("RGB")
        print("converted the image to RGB", flush=True)

        mv_result: list[Image.Image] = self.pipe(
            image=img,
            width=640,
            height=960,
            num_inference_steps=28,
            guidance_scale=4.0,
            num_images_per_prompt=1
        ).images

        print("mv_result", repr(mv_result), flush=True)

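        # NOTE (assumption): Zero123++ v1.2 returns its six generated views tiled
        # into one 640x960 image: a 2-column by 3-row grid of 320x320 tiles, laid
        # out row-major with azimuth increasing around the object. The crops below
        # pick the three tiles this handler labels "right", "back", and "left";
        # the exact camera angles behind each tile depend on the checkpoint.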
        tile_w = 320
        tile_h = 320
        right_tile = (tile_w, 0, tile_w * 2, tile_h)
        back_tile = (tile_w, tile_h, tile_w * 2, tile_h * 2)
        left_tile = (0, tile_h * 2, tile_w, tile_h * 3)
        ret = {
            "front": img,
            "right": mv_result[0].crop(right_tile),
            "back": mv_result[0].crop(back_tile),
            "left": mv_result[0].crop(left_tile)
        }

        return ret


class EndpointHandler:
    """Hugging Face Inference Endpoints entry point: base64 PNG in, base64 PNGs out."""

    def __init__(self, path=""):
        self.hf_token = os.environ["HUGGINGFACE_TOKEN"]
        # Prefer the explicit path passed by the endpoint runtime; fall back to HF_HUB_CACHE.
        self.repo_dir = path or os.environ["HF_HUB_CACHE"]
        self.hf_gen = HFMultiViewGen(hf_token=self.hf_token, repo_dir=self.repo_dir)

    def convert(self, fromval: dict[str, Image.Image]) -> dict[str, str]:
        return {k: convert_image_to_b64(v) for k, v in fromval.items()}

    def __call__(self, data: dict[str, Any]):
        print("Entered __call__!!! ", repr(data), flush=True)
        ret: dict[str, Any] = {}
        try:
            img_str = data["inputs"]
            # Log only a prefix: the full base64 payload can be megabytes long.
            print(f"Initial image (first 64 chars): {img_str[:64]}", flush=True)
            img: Image.Image = convert_b64_to_image(img_str)
            print("Converted to image", repr(img), flush=True)
            mv: dict[str, Image.Image] = self.hf_gen.generate_multiview(initial=img)
            print(f"Mv Image: {mv}", flush=True)
            mv_str: dict[str, str] = self.convert(mv)
            ret["output"] = mv_str
            return ret
        except Exception as e:
            print(e, flush=True)
            raise
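

# ---------------------------------------------------------------------------
# Minimal local smoke test (a sketch, not part of the endpoint contract).
# Assumptions: HUGGINGFACE_TOKEN and HF_HUB_CACHE are set in the environment
# (or a .env file), a CUDA GPU is present, and "input.png" is a hypothetical
# test image sitting next to this file.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    with open("input.png", "rb") as f:
        payload = base64.b64encode(f.read()).decode("utf-8")
    handler = EndpointHandler()
    result = handler({"inputs": payload})
    # Round-trip each returned view back into a PIL image and save it to disk.
    for view_name, b64_png in result["output"].items():
        convert_b64_to_image(b64_png).save(f"view_{view_name}.png")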