# app/api/generate.py
import base64
from io import BytesIO
from typing import Annotated, List

import numpy as np
import torch
import torch.nn.functional as F
from fastapi import APIRouter, Depends, HTTPException, Request
from PIL import Image
from sqlalchemy.orm import Session
from torchvision.transforms.functional import normalize

from app.config import DEVICE
from app.core import schemas, crud
from app.core.schemas import TextImage, ImageImage, BackgroundRemoval, ImageVariations
from app.db import get_db
from app.security import get_current_user

router = APIRouter()

def decode_image(image: str) -> Image.Image:
    """Decode a base64-encoded image payload into an RGB PIL image."""
    return Image.open(BytesIO(base64.b64decode(image))).convert("RGB")


def encode_image(image: Image.Image) -> str:
    """PNG-encode a PIL image and return it as a base64 string."""
    buffer = BytesIO()  # avoid shadowing the built-in `bytes`
    image.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")
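
# Example round trip (illustrative values):
#   payload = encode_image(Image.new("RGB", (64, 64)))  # base64 PNG string
#   restored = decode_image(payload)                     # back to an RGB PIL image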

def create_prompt(
    subject,
    medium,
    style,
    artist,
    website,
    resolution,
    additional_details,
    color,
    lightning,  # name mirrors the `lightning` field on the request schemas
):
    """Assemble a comma-separated prompt from the optional modifier fields.

    Returns None when no subject is given, so callers can fall back to
    prompt-free generation.
    """
    if not subject:
        return None
    if medium:
        subject = f"{medium} of {subject}"
    if style:
        subject = f"{subject}, {style}"
    if artist:
        subject = f"{subject}, by {artist}"
    if website:
        subject = f"{subject}, {website}"
    if resolution:
        subject = f"{subject}, {resolution}"
    if additional_details:
        subject = f"{subject}, {additional_details}"
    if color:
        subject = f"{subject}, {color}"
    if lightning:
        subject = f"{subject}, {lightning}"
    return subject
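
# Worked example (hypothetical values):
#   create_prompt("a lighthouse", medium="oil painting", style="impressionist",
#                 artist="Claude Monet", website="artstation", resolution="4k",
#                 additional_details="crashing waves", color="muted palette",
#                 lightning="golden hour")
#   -> "oil painting of a lighthouse, impressionist, by Claude Monet,
#       artstation, 4k, crashing waves, muted palette, golden hour"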
@router.post("/text-image/", response_model=str)
def text_image(
model: Request,
request: TextImage,
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[schemas.User, Depends(get_current_user)],
):
if not current_user.is_active:
raise HTTPException(status_code=403, detail="Forbidden")
generator = torch.manual_seed(request.seed)
prompt = create_prompt(
request.prompt,
medium=request.medium,
style=request.style,
artist=request.artist,
website=request.website,
resolution=request.resolution,
additional_details=request.additional_details,
color=request.color,
lightning=request.lightning,
)
crud.create_prompt(db=db, user_id=current_user.user_id, prompt=prompt)
with torch.inference_mode():
image = model.state.ti_pipe(
prompt,
num_inference_steps=request.num_inference_steps,
guidance_scale=request.guidance_scale,
generator=generator,
negative_prompt=request.negative_prompt,
).images[0]
return encode_image(image)
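
# Minimal client sketch (assumptions: app served locally, bearer-token auth,
# and only the fields shown; the TextImage schema defines the full field set):
#   import requests
#   resp = requests.post(
#       "http://localhost:8000/text-image/",
#       json={"prompt": "a lighthouse", "seed": 42,
#             "num_inference_steps": 25, "guidance_scale": 7.5,
#             "negative_prompt": "blurry"},
#       headers={"Authorization": f"Bearer {token}"},  # token obtained elsewhere
#   )
#   png_b64 = resp.json()  # base64 PNG, decodable with decode_image()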
@router.post("/image-image/", response_model=str)
def image_image(
model: Request,
request: ImageImage,
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[schemas.User, Depends(get_current_user)],
):
if not current_user.is_active:
raise HTTPException(status_code=403, detail="Forbidden")
generator = torch.manual_seed(request.seed)
prompt = create_prompt(
request.prompt,
medium=request.medium,
style=request.style,
artist=request.artist,
website=request.website,
resolution=request.resolution,
additional_details=request.additional_details,
color=request.color,
lightning=request.lightning,
)
image = decode_image(request.image)
crud.create_prompt(db=db, user_id=current_user.user_id, prompt=prompt)
with torch.inference_mode():
image = model.state.ii_pipe(
prompt,
image=image,
num_inference_steps=request.num_inference_steps,
guidance_scale=request.guidance_scale,
image_guidance_scale=request.image_guidance_scale,
generator=generator,
negative_prompt=request.negative_prompt,
).images[0]
return encode_image(image)
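
# Note (assumption): ii_pipe accepts both `guidance_scale` (adherence to the
# text prompt) and `image_guidance_scale` (adherence to the input image), the
# signature of InstructPix2Pix-style diffusers pipelines; the concrete pipeline
# is wired up outside this file.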
@router.post("/background-removal/", response_model=str)
def background_removal(
model: Request,
request: BackgroundRemoval,
current_user: Annotated[schemas.User, Depends(get_current_user)],
):
if not current_user.is_active:
raise HTTPException(status_code=403, detail="Forbidden")
image = decode_image(request.image)
def preprocess_image(im: np.ndarray, model_input_size: list) -> torch.Tensor:
if len(im.shape) < 3:
im = im[:, :, np.newaxis]
# orig_im_size=im.shape[0:2]
im_tensor = torch.tensor(im, dtype=torch.float32).permute(2, 0, 1)
im_tensor = F.interpolate(
torch.unsqueeze(im_tensor, 0), size=model_input_size, mode="bilinear"
)
image = torch.divide(im_tensor, 255.0)
image = normalize(image, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
return image
    def postprocess_image(result: torch.Tensor, im_size: list) -> np.ndarray:
        """Resize the predicted mask to the original size and min-max scale it to uint8."""
        result = torch.squeeze(F.interpolate(result, size=im_size, mode="bilinear"), 0)
        ma = torch.max(result)
        mi = torch.min(result)
        result = (result - mi) / (ma - mi)
        im_array = (result * 255).permute(1, 2, 0).cpu().data.numpy().astype(np.uint8)
        im_array = np.squeeze(im_array)
        return im_array
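    # The min-max scaling above stretches whatever range the model emits to the
    # full 0-255 alpha range, e.g. raw values spanning [0.2, 0.8] map to 0 and
    # 255 at the extremes; the mask is used directly as the paste alpha below.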
    # prepare input
    model_input_size = [1024, 1024]
    orig_im = np.array(image)
    orig_im_size = orig_im.shape[0:2]
    image = preprocess_image(orig_im, model_input_size).to(DEVICE)
    # inference (no gradients needed, matching the other endpoints)
    with torch.inference_mode():
        result = model.state.br_model(image)
    # post process
    result_image = postprocess_image(result[0][0], orig_im_size)
    # composite: paste the original over a transparent canvas, mask as alpha
    pil_im = Image.fromarray(result_image)
    no_bg_image = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
    orig_image = Image.fromarray(orig_im)
    no_bg_image.paste(orig_image, mask=pil_im)
    return encode_image(no_bg_image)
@router.post("/image-variations/", response_model=List[str])
def image_variations(
model: Request,
request: ImageVariations,
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[schemas.User, Depends(get_current_user)],
):
if not current_user.is_active:
raise HTTPException(status_code=403, detail="Forbidden")
prompt = create_prompt(
request.prompt,
medium=request.medium,
style=request.style,
artist=request.artist,
website=request.website,
resolution=request.resolution,
additional_details=request.additional_details,
color=request.color,
lightning=request.lightning,
)
image = decode_image(request.image)
image.resize((256, 256))
if prompt:
prompt = f"best quality, high quality, {prompt}"
crud.create_prompt(db=db, user_id=current_user.user_id, prompt=prompt)
else:
request.scale = 1.0
images = model.state.iv_model.generate(
pil_image=image,
num_samples=request.num_samples,
num_inference_steps=request.num_inference_steps,
seed=request.seed,
prompt=prompt,
negative_prompt=request.negative_prompt,
scale=request.scale,
guidance_scale=request.guidance_scale,
)
images = [encode_image(image) for image in images]
return images
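
# Note (assumption): iv_model.generate(pil_image=..., scale=...) matches the
# IP-Adapter API, where `scale` weights image conditioning against the text
# prompt; forcing scale to 1.0 when no prompt is given makes the input image
# drive generation entirely. The actual model is wired up outside this file.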