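"""Image generation routes.

Text-to-image, image-to-image, background removal, and image variations.
Handlers read their pipelines from request state and return images as
base64-encoded PNG data.
"""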
import base64
from io import BytesIO
from typing import Annotated

import numpy as np
import torch
import torch.nn.functional as F
from fastapi import APIRouter, Depends, HTTPException, Request
from PIL import Image
from sqlalchemy.orm import Session
from torchvision.transforms.functional import normalize

from app.config import DEVICE
from app.core import schemas, crud
from app.core.schemas import TextImage, ImageImage, BackgroundRemoval, ImageVariations
from app.db import get_db
from app.security import get_current_user

router = APIRouter()


def decode_image(image):
    """Decode a base64-encoded image string into an RGB PIL image."""
    return Image.open(BytesIO(base64.b64decode(image))).convert("RGB")


def encode_image(image):
    """Encode a PIL image as base64-encoded PNG bytes."""
    buffer = BytesIO()  # named to avoid shadowing the built-in `bytes`
    image.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue())
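

# create_prompt assembles a single comma-separated prompt from optional
# style fields. Illustrative call (values are made up, not from the app):
#
#   create_prompt("a fox", medium="oil painting", style="impressionist",
#                 artist="Monet", website="artstation", resolution="4k",
#                 additional_details="detailed fur", color="warm colors",
#                 lightning="soft light")
#   # -> "oil painting of a fox, impressionist, by Monet, artstation, 4k,
#   #     detailed fur, warm colors, soft light"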
def create_prompt(
    subject,
    medium,
    style,
    artist,
    website,
    resolution,
    additional_details,
    color,
    lightning,
):
    if not subject:
        return None
    if medium:
        subject = f"{medium} of {subject}"
    if style:
        subject = f"{subject}, {style}"
    if artist:
        subject = f"{subject}, by {artist}"
    if website:
        subject = f"{subject}, {website}"
    if resolution:
        subject = f"{subject}, {resolution}"
    if additional_details:
        subject = f"{subject}, {additional_details}"
    if color:
        subject = f"{subject}, {color}"
    if lightning:
        subject = f"{subject}, {lightning}"
    return subject


# NOTE: the route paths in the decorators below are assumptions; adjust
# them to the app's actual URLs.
@router.post("/text-image")
def text_image(
    model: Request,
    request: TextImage,
    db: Annotated[Session, Depends(get_db)],
    current_user: Annotated[schemas.User, Depends(get_current_user)],
):
    if not current_user.is_active:
        raise HTTPException(status_code=403, detail="Forbidden")
    generator = torch.manual_seed(request.seed)
    prompt = create_prompt(
        request.prompt,
        medium=request.medium,
        style=request.style,
        artist=request.artist,
        website=request.website,
        resolution=request.resolution,
        additional_details=request.additional_details,
        color=request.color,
        lightning=request.lightning,
    )
    crud.create_prompt(db=db, user_id=current_user.user_id, prompt=prompt)
    image = model.state.ti_pipe(  # text-to-image pipeline from request state
        prompt,
        num_inference_steps=request.num_inference_steps,
        guidance_scale=request.guidance_scale,
        generator=generator,
        negative_prompt=request.negative_prompt,
    ).images[0]
    return encode_image(image)
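

# Illustrative request body for the handler above (example values only;
# the field names mirror how TextImage is used in the handler):
#
#   {
#       "prompt": "a fox",
#       "medium": "oil painting",
#       "seed": 42,
#       "num_inference_steps": 25,
#       "guidance_scale": 7.5,
#       "negative_prompt": "blurry"
#   }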


@router.post("/image-image")  # path assumed
def image_image(
    model: Request,
    request: ImageImage,
    db: Annotated[Session, Depends(get_db)],
    current_user: Annotated[schemas.User, Depends(get_current_user)],
):
    if not current_user.is_active:
        raise HTTPException(status_code=403, detail="Forbidden")
    generator = torch.manual_seed(request.seed)
    prompt = create_prompt(
        request.prompt,
        medium=request.medium,
        style=request.style,
        artist=request.artist,
        website=request.website,
        resolution=request.resolution,
        additional_details=request.additional_details,
        color=request.color,
        lightning=request.lightning,
    )
    image = decode_image(request.image)
    crud.create_prompt(db=db, user_id=current_user.user_id, prompt=prompt)
    image = model.state.ii_pipe(  # image-to-image pipeline from request state
        prompt,
        image=image,
        num_inference_steps=request.num_inference_steps,
        guidance_scale=request.guidance_scale,
        image_guidance_scale=request.image_guidance_scale,
        generator=generator,
        negative_prompt=request.negative_prompt,
    ).images[0]
    return encode_image(image)


@router.post("/background-removal")  # path assumed
def background_removal(
    model: Request,
    request: BackgroundRemoval,
    current_user: Annotated[schemas.User, Depends(get_current_user)],
):
    if not current_user.is_active:
        raise HTTPException(status_code=403, detail="Forbidden")
    image = decode_image(request.image)

    def preprocess_image(im: np.ndarray, model_input_size: list) -> torch.Tensor:
        # Promote grayscale arrays to HxWx1 so the channel permute works.
        if len(im.shape) < 3:
            im = im[:, :, np.newaxis]
        im_tensor = torch.tensor(im, dtype=torch.float32).permute(2, 0, 1)
        im_tensor = F.interpolate(
            torch.unsqueeze(im_tensor, 0), size=model_input_size, mode="bilinear"
        )
        image = torch.divide(im_tensor, 255.0)
        image = normalize(image, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
        return image

    def postprocess_image(result: torch.Tensor, im_size: list) -> np.ndarray:
        result = torch.squeeze(F.interpolate(result, size=im_size, mode="bilinear"), 0)
        # Min-max scale the predicted matte to [0, 1], then to uint8 [0, 255].
        ma = torch.max(result)
        mi = torch.min(result)
        result = (result - mi) / (ma - mi)
        im_array = (result * 255).permute(1, 2, 0).detach().cpu().numpy().astype(np.uint8)
        im_array = np.squeeze(im_array)
        return im_array

    # Prepare input at the segmentation model's expected resolution.
    model_input_size = [1024, 1024]
    orig_im = np.array(image)
    orig_im_size = orig_im.shape[0:2]
    image = preprocess_image(orig_im, model_input_size).to(DEVICE)
    # Inference.
    result = model.state.br_model(image)
    # Post-process the matte back to the original size.
    result_image = postprocess_image(result[0][0], orig_im_size)
    # Composite: paste the original onto a transparent canvas, using the
    # predicted matte as the alpha mask.
    pil_im = Image.fromarray(result_image)
    no_bg_image = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
    orig_image = Image.fromarray(orig_im)
    no_bg_image.paste(orig_image, mask=pil_im)
    return encode_image(no_bg_image)
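

# Note on the compositing step above: Image.fromarray on a uint8 HxW array
# yields a grayscale "L" image, which PIL's paste accepts directly as a
# soft alpha mask, so the matte doubles as the output's transparency.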


@router.post("/image-variations")  # path assumed
def image_variations(
    model: Request,
    request: ImageVariations,
    db: Annotated[Session, Depends(get_db)],
    current_user: Annotated[schemas.User, Depends(get_current_user)],
):
    if not current_user.is_active:
        raise HTTPException(status_code=403, detail="Forbidden")
    prompt = create_prompt(
        request.prompt,
        medium=request.medium,
        style=request.style,
        artist=request.artist,
        website=request.website,
        resolution=request.resolution,
        additional_details=request.additional_details,
        color=request.color,
        lightning=request.lightning,
    )
    image = decode_image(request.image)
    # PIL's resize returns a new image rather than resizing in place.
    image = image.resize((256, 256))
    if prompt:
        prompt = f"best quality, high quality, {prompt}"
        crud.create_prompt(db=db, user_id=current_user.user_id, prompt=prompt)
    else:
        # With no text prompt, force the image-conditioning scale to 1.0.
        request.scale = 1.0
    images = model.state.iv_model.generate(
        pil_image=image,
        num_samples=request.num_samples,
        num_inference_steps=request.num_inference_steps,
        seed=request.seed,
        prompt=prompt,
        negative_prompt=request.negative_prompt,
        scale=request.scale,
        guidance_scale=request.guidance_scale,
    )
    return [encode_image(image) for image in images]
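

# Minimal wiring sketch (assumed layout; this module only defines `router`,
# and the import path below is hypothetical):
#
#   from fastapi import FastAPI
#   from app.routes.images import router
#
#   app = FastAPI()
#   app.include_router(router)
#
# The handlers also expect the pipelines (ti_pipe, ii_pipe, br_model,
# iv_model) to be present on request state, e.g. attached by middleware or
# a lifespan hook not shown in this file.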