InPaintAPI / app.py
kael558's picture
Update app.py
6e223a1
raw
history blame contribute delete
No virus
2.18 kB
import gradio as gr
from io import BytesIO
import requests
import PIL
from PIL import Image
import numpy as np
import os
import uuid
import torch
from torch import autocast
import cv2
from matplotlib import pyplot as plt
from torchvision import transforms
from diffusers import DiffusionPipeline
"""
auth_token = os.environ.get("API_TOKEN") or True
device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-inpainting", dtype=torch.float16, revision="fp16", use_auth_token=auth_token).to(device)
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
transforms.Resize((512, 512)),
])
def predict(dict, prompt=""):
init_image = dict["image"].convert("RGB").resize((512, 512))
mask = dict["mask"].convert("RGB").resize((512, 512))
output = pipe(prompt = prompt, image=init_image, mask_image=mask,guidance_scale=7.5)
return output.images[0], gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)
"""
def add_text(text, image, image_process_mode, request: gr.Request):
    """Trim a submitted message, tag it with an image marker, and log it.

    Args:
        text: Message typed into the textbox.
        image: Uploaded image, or ``None`` when nothing was attached.
        image_process_mode: Selected preprocessing mode; only echoed in the log.
        request: Gradio request object; accepted but not used here.

    Returns:
        None. The processed message is only printed to stdout.
    """
    if image is None:
        # Text-only message: hard cut-off at 1536 characters.
        print(text[:1536])
        return

    # Messages with an image get a tighter cut-off (1200 characters).
    message = text[:1200]
    print(image)
    # Make sure the prompt carries exactly one image placeholder.
    if "<image>" not in message:
        message += "\n<image>"
    print((message, image, image_process_mode))
# Minimal UI: a text box, an image upload and a Send button, all forwarding
# their values to ``add_text``. The outputs list is empty, so nothing is
# written back into the page.
image_blocks = gr.Blocks()
with image_blocks as demo:
    # Hidden selector; its value ("Default" unless changed) is still passed
    # to the handler as ``image_process_mode``.
    image_process_mode = gr.Radio(
        ["Crop", "Resize", "Pad", "Default"],
        value="Default",
        label="Preprocess for non-square image",
        visible=False,
    )
    textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)
    imagebox = gr.Image(type="pil")
    # Must be interactive: with interactive=False the button renders disabled
    # and nothing in this app re-enables it, so the handler could never run.
    submit_btn = gr.Button(value="Send", variant="primary", interactive=True)
    submit_inputs = [textbox, imagebox, image_process_mode]
    # Honour the placeholder: pressing ENTER in the textbox submits as well.
    textbox.submit(
        add_text,
        submit_inputs,
        [],
    )
    submit_btn.click(
        add_text,
        submit_inputs,
        [],
    )
image_blocks.launch()