# zero123plus-v1.2 / handler.py
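"""Custom Hugging Face Inference Endpoints handler for Zero123++ (zero123plus) multi-view generation.

The handler accepts a base64-encoded PNG under the request's "inputs" key, runs the
Zero123++ diffusers pipeline to synthesize additional views of the pictured object, and
returns base64-encoded PNG crops keyed by view name ("front", "right", "back", "left").
"""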
from typing import Dict, List, Any
import os
import torch
from PIL import Image
import dotenv
import base64
import io
from diffusers import DiffusionPipeline # pyright: ignore[reportPrivateImportUsage]
dotenv.load_dotenv()
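
# Decode a base64-encoded PNG string into a fully loaded PIL image.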
def convert_b64_to_image(from_str: str) -> Image.Image:
print(">>> call convert_b64_to_image", flush=True)
try:
data: bytes = base64.b64decode(from_str)
with io.BytesIO(data) as bio:
imgfile = Image.open(bio, formats=["PNG"])
imgfile.load()
return imgfile
except Exception as e:
print(e, flush=True)
raise e
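
# Encode a PIL image as a base64 PNG string.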
def convert_image_to_b64(from_img: Image.Image) -> str:
print(">>> call convert_image_to_b64", flush=True)
try:
with io.BytesIO() as buffer:
from_img.save(buffer, format="PNG")
byte_data: bytes = buffer.getvalue()
return base64.b64encode(byte_data).decode("utf-8")
except Exception as e:
print(e, flush=True)
raise e
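
# Wraps the Zero123++ custom diffusers pipeline ("sudo-ai/zero123plus-pipeline"). The
# pipeline weights are downloaded and moved to the GPU once at construction time and
# then reused for every request.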
class HFMultiViewGen:
def __init__(self,
hf_token: str,
mv_model: str = "maple-shaft/zero123plus-v1.2",
mv_custom_pipeline: str = "sudo-ai/zero123plus-pipeline",
gen_custom_pipeline: str = "",
repo_dir: str = "/repository",
debug: bool = False):
self.debug = debug
self.hf_token = hf_token
self.mv_model = mv_model
self.mv_custom_pipeline = mv_custom_pipeline
self.repo_dir = repo_dir
print(f"torch.cuda.is_available() = {torch.cuda.is_available()}")
torch.cuda.synchronize()
print("GPU SYNC OK", flush=True)
        self.pipe = DiffusionPipeline.from_pretrained(
            self.mv_model,
            cache_dir=self.repo_dir,
            token=self.hf_token,
            custom_pipeline=self.mv_custom_pipeline,
            torch_dtype=torch.float16
        ).to("cuda")
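    # Generate novel views of the input image: the original image is returned as "front",
    # alongside three crops of the tiled pipeline output ("right", "back", "left").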
def generate_multiview(self, initial: Image.Image) -> dict[str, Image.Image]:
print(">>> generate_multiview", flush=True)
print("allocated second pipe to gpu", flush=True)
# --- prepare image properly ---
img = initial.convert("RGB")
print("converted the image to RGB", flush=True)
        mv_result: List[Image.Image] = self.pipe(
image=img,
width=640,
height=960,
num_inference_steps=28,
guidance_scale=4.0,
num_images_per_prompt=1
).images # pyright: ignore[reportCallIssue]
print("mv_result", repr(mv_result), flush=True)
        # The pipeline returns a single 640x960 PNG that tiles six 320x320 views in a
        # 2-column x 3-row grid; crop out the tiles we need.
        tile_w = 320  # output width (640) / 2 columns
        tile_h = 320  # output height (960) / 3 rows
        right_tile = (tile_w, 0, tile_w * 2, tile_h)
        back_tile = (tile_w, tile_h, tile_w * 2, tile_h * 2)
        left_tile = (0, tile_h * 2, tile_w, tile_h * 3)
ret = {
"front": img,
"right": mv_result[0].crop(right_tile),
"back": mv_result[0].crop(back_tile),
"left": mv_result[0].crop(left_tile)
}
return ret
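
# Entry point expected by Hugging Face Inference Endpoints: handler.py must expose an
# EndpointHandler class whose __call__ receives the parsed request payload as a dict.
# Requires HUGGINGFACE_TOKEN in the environment, plus HF_HUB_CACHE when no path is given.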
class EndpointHandler:
def __init__(self, path=""):
self.hf_token = os.environ["HUGGINGFACE_TOKEN"]
self.repo_dir = os.environ["HF_HUB_CACHE"] if not path else path
self.hf_gen = HFMultiViewGen(hf_token=self.hf_token, repo_dir=self.repo_dir)
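    # Convert each generated view to a base64 PNG string so the response is JSON-serializable.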
def convert(self, fromval: dict[str, Image.Image]) -> dict[str, str]:
ret: dict[str, str] = {}
for k,v in fromval.items():
ret[k] = convert_image_to_b64(v)
return ret
    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
print("Entered __call__!!! ", repr(data), flush=True)
ret: dict[str, Any] = {}
try:
img_str = data['inputs']
print(f"Initial image: {img_str}", flush=True)
img: Image.Image = convert_b64_to_image(img_str)
print("Converted to image", repr(img), flush=True)
mv: dict[str, Image.Image] = self.hf_gen.generate_multiview(initial=img)
print(f"Mv Image: {mv}", flush=True)
mv_str: Dict[str,str] = self.convert(mv)
ret["output"] = mv_str
return ret
except Exception as e:
            print(e, flush=True)
raise e
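
# --- Optional local smoke test (not used by Inference Endpoints) ---
# A minimal sketch for exercising the handler outside the endpoint runtime. It assumes
# HUGGINGFACE_TOKEN and HF_HUB_CACHE are set, a CUDA device is available, and that an
# "input.png" file exists in the working directory; these names are illustrative choices,
# not part of the deployed handler.
if __name__ == "__main__":
    with open("input.png", "rb") as f:
        payload = base64.b64encode(f.read()).decode("utf-8")
    handler = EndpointHandler()
    result = handler({"inputs": payload})
    # Write each returned view back out as a PNG for visual inspection.
    for name, view_b64 in result["output"].items():
        convert_b64_to_image(view_b64).save(f"view_{name}.png")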