"""Image-generation API routes.

Endpoints: text-to-image, instruct image-to-image, background removal,
and image variations.  All images travel over the wire as base64-encoded
PNG strings.
"""

import base64
from io import BytesIO
from typing import Annotated, List

import numpy as np
import torch
import torch.nn.functional as F
from fastapi import APIRouter, Depends, HTTPException, Request
from PIL import Image
from sqlalchemy.orm import Session
from torchvision.transforms.functional import normalize

from app.config import DEVICE
from app.core import crud, schemas
from app.core.schemas import (
    BackgroundRemoval,
    ImageImage,
    ImageVariations,
    TextImage,
)
from app.db import get_db
from app.security import get_current_user

router = APIRouter()


def decode_image(image):
    """Decode a base64-encoded image payload into an RGB PIL image."""
    return Image.open(BytesIO(base64.b64decode(image))).convert("RGB")


def encode_image(image):
    """Encode a PIL image as a base64 PNG string.

    The buffer was previously named ``bytes`` (shadowing the builtin) and
    the raw ``bytes`` from ``b64encode`` was returned even though every
    route declares ``response_model=str``; decode to ASCII so the return
    type matches the declared response model.
    """
    buffer = BytesIO()
    image.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode("ascii")


def create_prompt(
    subject,
    medium,
    style,
    artist,
    website,
    resolution,
    additional_details,
    color,
    lightning,
):
    """Assemble a comma-separated generation prompt from its components.

    Returns ``None`` when ``subject`` is falsy.  ``medium`` prefixes the
    subject ("<medium> of <subject>"); every other non-empty component is
    appended in order, with ``artist`` rendered as "by <artist>".

    NOTE(review): the parameter is named ``lightning`` to match the request
    schemas — presumably "lighting" was intended; kept for compatibility.
    """
    if not subject:
        return None
    prompt = f"{medium} of {subject}" if medium else subject
    extras = [
        style,
        f"by {artist}" if artist else None,
        website,
        resolution,
        additional_details,
        color,
        lightning,
    ]
    return ", ".join([prompt] + [part for part in extras if part])


@router.post("/text-image/", response_model=str)
def text_image(
    model: Request,
    request: TextImage,
    db: Annotated[Session, Depends(get_db)],
    current_user: Annotated[schemas.User, Depends(get_current_user)],
):
    """Generate an image from a text prompt; return it base64-encoded.

    Raises:
        HTTPException: 403 when the current user is inactive.
    """
    if not current_user.is_active:
        raise HTTPException(status_code=403, detail="Forbidden")
    # Seed the global torch RNG; manual_seed returns the seeded Generator.
    generator = torch.manual_seed(request.seed)
    prompt = create_prompt(
        request.prompt,
        medium=request.medium,
        style=request.style,
        artist=request.artist,
        website=request.website,
        resolution=request.resolution,
        additional_details=request.additional_details,
        color=request.color,
        lightning=request.lightning,
    )
    crud.create_prompt(db=db, user_id=current_user.user_id, prompt=prompt)
    # NOTE(review): the text-to-image pipeline is read from request state
    # (model.state.ti_pipe) — presumably attached by middleware or app
    # lifespan; confirm against the application setup.
    with torch.inference_mode():
        image = model.state.ti_pipe(
            prompt,
            num_inference_steps=request.num_inference_steps,
            guidance_scale=request.guidance_scale,
            generator=generator,
            negative_prompt=request.negative_prompt,
        ).images[0]
    return encode_image(image)


@router.post("/image-image/", response_model=str)
def image_image(
    model: Request,
    request: ImageImage,
    db: Annotated[Session, Depends(get_db)],
    current_user: Annotated[schemas.User, Depends(get_current_user)],
):
    """Transform an input image guided by a text prompt; return base64 PNG.

    Raises:
        HTTPException: 403 when the current user is inactive.
    """
    if not current_user.is_active:
        raise HTTPException(status_code=403, detail="Forbidden")
    generator = torch.manual_seed(request.seed)
    prompt = create_prompt(
        request.prompt,
        medium=request.medium,
        style=request.style,
        artist=request.artist,
        website=request.website,
        resolution=request.resolution,
        additional_details=request.additional_details,
        color=request.color,
        lightning=request.lightning,
    )
    image = decode_image(request.image)
    crud.create_prompt(db=db, user_id=current_user.user_id, prompt=prompt)
    # NOTE(review): ii_pipe is read from request state — presumably an
    # InstructPix2Pix-style pipeline (it takes image_guidance_scale); confirm.
    with torch.inference_mode():
        image = model.state.ii_pipe(
            prompt,
            image=image,
            num_inference_steps=request.num_inference_steps,
            guidance_scale=request.guidance_scale,
            image_guidance_scale=request.image_guidance_scale,
            generator=generator,
            negative_prompt=request.negative_prompt,
        ).images[0]
    return encode_image(image)


@router.post("/background-removal/", response_model=str)
def background_removal(
    model: Request,
    request: BackgroundRemoval,
    current_user: Annotated[schemas.User, Depends(get_current_user)],
):
    """Remove the background of an image; return an RGBA base64 PNG.

    The segmentation model (model.state.br_model) produces a soft mask that
    is used as the alpha channel when pasting the original image onto a
    fully transparent canvas.

    Raises:
        HTTPException: 403 when the current user is inactive.
    """
    if not current_user.is_active:
        raise HTTPException(status_code=403, detail="Forbidden")
    image = decode_image(request.image)

    def preprocess_image(im: np.ndarray, model_input_size: list) -> torch.Tensor:
        """HWC uint8 array -> normalized 1xCxHxW float tensor at model size."""
        if len(im.shape) < 3:
            # Grayscale input: add a channel axis.
            im = im[:, :, np.newaxis]
        im_tensor = torch.tensor(im, dtype=torch.float32).permute(2, 0, 1)
        im_tensor = F.interpolate(
            torch.unsqueeze(im_tensor, 0), size=model_input_size, mode="bilinear"
        )
        image = torch.divide(im_tensor, 255.0)
        # Shift to roughly [-0.5, 0.5] (mean 0.5, std 1.0 per channel).
        image = normalize(image, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
        return image

    def postprocess_image(result: torch.Tensor, im_size: list) -> np.ndarray:
        """Model output -> min-max normalized uint8 mask at original size."""
        result = torch.squeeze(F.interpolate(result, size=im_size, mode="bilinear"), 0)
        ma = torch.max(result)
        mi = torch.min(result)
        result = (result - mi) / (ma - mi)
        im_array = (result * 255).permute(1, 2, 0).cpu().data.numpy().astype(np.uint8)
        im_array = np.squeeze(im_array)
        return im_array

    # Prepare input at the model's fixed working resolution.
    model_input_size = [1024, 1024]
    orig_im = np.array(image)
    orig_im_size = orig_im.shape[0:2]
    image = preprocess_image(orig_im, model_input_size).to(DEVICE)

    # Inference.
    result = model.state.br_model(image)

    # Post-process back to the original resolution.
    result_image = postprocess_image(result[0][0], orig_im_size)

    # Composite: paste the original over transparency using the mask as alpha.
    pil_im = Image.fromarray(result_image)
    no_bg_image = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
    orig_image = Image.fromarray(orig_im)
    no_bg_image.paste(orig_image, mask=pil_im)
    return encode_image(no_bg_image)


@router.post("/image-variations/", response_model=List[str])
def image_variations(
    model: Request,
    request: ImageVariations,
    db: Annotated[Session, Depends(get_db)],
    current_user: Annotated[schemas.User, Depends(get_current_user)],
):
    """Generate variations of an input image; return a list of base64 PNGs.

    Raises:
        HTTPException: 403 when the current user is inactive.
    """
    if not current_user.is_active:
        raise HTTPException(status_code=403, detail="Forbidden")
    prompt = create_prompt(
        request.prompt,
        medium=request.medium,
        style=request.style,
        artist=request.artist,
        website=request.website,
        resolution=request.resolution,
        additional_details=request.additional_details,
        color=request.color,
        lightning=request.lightning,
    )
    image = decode_image(request.image)
    # BUG FIX: PIL's resize() returns a new image; the original code
    # discarded the result, so the resize never took effect.
    image = image.resize((256, 256))
    if prompt:
        prompt = f"best quality, high quality, {prompt}"
        crud.create_prompt(db=db, user_id=current_user.user_id, prompt=prompt)
    else:
        # Without a prompt, fall back to full image conditioning.
        request.scale = 1.0
    images = model.state.iv_model.generate(
        pil_image=image,
        num_samples=request.num_samples,
        num_inference_steps=request.num_inference_steps,
        seed=request.seed,
        prompt=prompt,
        negative_prompt=request.negative_prompt,
        scale=request.scale,
        guidance_scale=request.guidance_scale,
    )
    return [encode_image(image) for image in images]