radames's picture
static
ec46b80
raw
history blame
12 kB
import io
import os
from pathlib import Path
import uvicorn
from fastapi import FastAPI, BackgroundTasks, HTTPException, UploadFile, Form, Depends, status, Request
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
from fastapi_utils.tasks import repeat_every
import numpy as np
import torch
from torch import autocast
from diffusers import StableDiffusionPipeline, StableDiffusionInpaintPipeline
from diffusers.models import AutoencoderKL
from PIL import Image
import gradio as gr
import skimage
import skimage.measure
from utils import *
import boto3
import magic
import sqlite3
import requests
import shortuuid
import re
import time
AWS_ACCESS_KEY_ID = os.getenv('AWS_ACCESS_KEY_ID')
AWS_SECRET_KEY = os.getenv('AWS_SECRET_KEY')
AWS_S3_BUCKET_NAME = os.getenv('AWS_S3_BUCKET_NAME')
LIVEBLOCKS_SECRET = os.environ.get("LIVEBLOCKS_SECRET")
# Falls back to the literal True when API_TOKEN is unset.
# NOTE(review): HF_TOKEN is not referenced in the code visible here — confirm
# whether it is still needed.
HF_TOKEN = os.environ.get("API_TOKEN") or True
# MIME types accepted by /api/uploadfile, mapped to a file extension.
FILE_TYPES = {
    'image/png': 'png',
    'image/jpeg': 'jpg',
}
DB_PATH = Path("rooms.db")
app = FastAPI()

# First start: create rooms.db and load the schema from schema.sql.
if not DB_PATH.exists():
    print("Creating database")
    print("DB_PATH", DB_PATH)
    db = sqlite3.connect(DB_PATH)
    try:
        with open(Path("schema.sql"), "r") as f:
            db.executescript(f.read())
        db.commit()
    except BaseException:
        # sqlite3.connect() may already have created rooms.db; if it is left
        # behind half-initialised, the exists() check above would skip the
        # schema on every later start. Remove it so the next start retries.
        db.close()
        if DB_PATH.exists():
            DB_PATH.unlink()
        raise
    db.close()
def get_db():
    """Yield a SQLite connection to ``DB_PATH`` and always close it afterwards.

    Used as a FastAPI dependency (``Depends(get_db)``) and also iterated
    directly with ``for db in get_db()``. Rows are returned as
    ``sqlite3.Row`` so columns can be read by name.

    If an exception is thrown into the generator (e.g. by FastAPI when the
    endpoint fails), uncommitted changes are rolled back and the exception is
    re-raised so the failure is not silently swallowed.
    """
    # presumably check_same_thread=False because the dependency and the
    # endpoint may run on different threads — confirm.
    db = sqlite3.connect(DB_PATH, check_same_thread=False)
    db.row_factory = sqlite3.Row
    try:
        yield db
    except Exception:
        db.rollback()
        raise
    finally:
        db.close()
s3 = boto3.client(service_name='s3',
                  aws_access_key_id=AWS_ACCESS_KEY_ID,
                  aws_secret_access_key=AWS_SECRET_KEY)
# Newer Pillow releases expose resampling filters on Image.Resampling; older
# ones only have the module-level constant. Only a missing attribute is
# expected here, so catch exactly that.
try:
    SAMPLING_MODE = Image.Resampling.LANCZOS
except AttributeError:
    SAMPLING_MODE = Image.LANCZOS
blocks = gr.Blocks().queue()
# Cache of loaded pipelines, filled lazily by get_model().
model = {}
# Fallback mask used by run_outpaint when the input image has no transparent
# pixels. copy() loads the pixels so the file opened by Image.open is closed
# immediately instead of staying open for the life of the process.
with Image.open("mask.png") as _mask_file:
    STATIC_MASK = _mask_file.copy()
def get_model():
    """Return the Stable Diffusion inpainting pipeline, loading it once.

    The pipeline is stored in the module-level ``model`` dict under
    ``"inpaint"``; later calls return the cached object. On first use it
    loads ``runwayml/stable-diffusion-inpainting`` (fp16 revision, float16
    weights) with the ``stabilityai/sd-vae-ft-ema`` VAE and moves it to CUDA.
    """
    if "inpaint" in model:
        return model["inpaint"]
    ema_vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema")
    pipeline = StableDiffusionInpaintPipeline.from_pretrained(
        "runwayml/stable-diffusion-inpainting",
        revision="fp16",
        torch_dtype=torch.float16,
        vae=ema_vae,
    )
    model["inpaint"] = pipeline.to("cuda")
    return model["inpaint"]
# Load the inpainting pipeline at import time, so the first request does not
# pay the from_pretrained + .to("cuda") cost; get_model() caches the result.
get_model()
def _snap_mask_to_latent_grid(mask):
    """Max-pool *mask* over 8x8 tiles, then upsample so each tile is uniform.

    Any non-zero pixel marks its whole 8x8 tile. The 8-pixel tile presumably
    matches the VAE's 8x latent downsampling — confirm.
    """
    mask = skimage.measure.block_reduce(mask, (8, 8), np.max)
    return mask.repeat(8, axis=0).repeat(8, axis=1)


async def run_outpaint(
    input_image,
    prompt_text,
    strength,
    guidance,
    step,
    fill_mode,
    room_id,
    image_key
):
    """Inpaint the transparent part of *input_image* and upload the result.

    Args:
        input_image: RGBA PIL image (the Gradio input uses image_mode="RGBA").
            Its alpha channel selects what to keep: alpha 0 is repainted.
        prompt_text: text prompt for the pipeline.
        strength, guidance, step: forwarded as ``strength``,
            ``guidance_scale`` and ``num_inference_steps``.
        fill_mode: key into ``functbl`` (one of the UI "Init mode" choices)
            choosing how the masked area is pre-filled before inference.
        room_id, image_key: forwarded to ``upload_file`` to build the S3 key.

    Returns:
        ``{"is_nsfw": bool, "image": dict}``. ``image`` is the ``upload_file``
        result, or ``{}`` when the output was flagged NSFW and not uploaded.
    """
    inpaint = get_model()
    sel_buffer = np.array(input_image)
    img = sel_buffer[:, :, 0:3]
    mask = sel_buffer[:, :, -1]
    nmask = 255 - mask
    process_size = 512
    if nmask.sum() < 1:
        # Fully opaque input: nothing to fill, so use the fixed fallback mask.
        print("inpaiting with fixed Mask")
        mask = np.array(STATIC_MASK)[:, :, 0]
    elif mask.sum() > 0:
        print("inpainting")
    else:
        # Fully transparent input: every pixel gets generated.
        print("text2image")
        print("inpainting")
    img, mask = functbl[fill_mode](img, mask)
    init_image = Image.fromarray(img)
    # Invert so the regions to repaint are white in the pipeline's mask.
    mask_image = Image.fromarray(_snap_mask_to_latent_grid(255 - mask))
    with autocast("cuda"):
        output = inpaint(
            prompt=prompt_text,
            image=init_image.resize(
                (process_size, process_size), resample=SAMPLING_MODE
            ),
            mask_image=mask_image.resize((process_size, process_size)),
            strength=strength,
            num_inference_steps=step,
            guidance_scale=guidance,
        )
    image = output["images"][0]
    # The NSFW flags can be absent (e.g. no safety checker); treat that as
    # "not flagged" rather than crashing on a missing key / None.
    nsfw_flags = output.get("nsfw_content_detected")
    is_nsfw = bool(nsfw_flags[0]) if nsfw_flags else False
    image_url = {}
    if not is_nsfw:
        image_url = await upload_file(image, prompt_text, room_id, image_key)
    params = {
        "is_nsfw": is_nsfw,
        "image": image_url
    }
    return params
# Gradio UI. The widgets below are the inputs of the "Proceed" button, which
# calls run_outpaint and shows the dict it returns in a JSON component.
# The order of `inputs` must match run_outpaint's positional parameters.
with blocks as demo:
    with gr.Row():
        with gr.Column(scale=3, min_width=270):
            sd_prompt = gr.Textbox(
                label="Prompt", placeholder="input your prompt here", lines=4
            )
        with gr.Column(scale=2, min_width=150):
            sd_strength = gr.Slider(
                label="Strength", minimum=0.0, maximum=1.0, value=0.75, step=0.01
            )
        with gr.Column(scale=1, min_width=150):
            # precision=0 rounds the step count to an integer.
            sd_step = gr.Number(label="Step", value=50, precision=0)
            sd_guidance = gr.Number(label="Guidance", value=7.5)
    with gr.Row():
        with gr.Column(scale=4, min_width=600):
            # Pre-fill strategy; the value is used as a key into functbl.
            init_mode = gr.Radio(
                label="Init mode",
                choices=[
                    "patchmatch",
                    "edge_pad",
                    "cv2_ns",
                    "cv2_telea",
                    "gaussian",
                    "perlin",
                ],
                value="patchmatch",
                type="value",
            )
    # RGBA so the alpha channel can carry the inpainting mask.
    model_input = gr.Image(label="Input", type="pil", image_mode="RGBA")
    room_id = gr.Textbox(label="Room ID")
    image_key = gr.Textbox(label="image_key")
    proceed_button = gr.Button("Proceed", elem_id="proceed")
    params = gr.JSON()
    proceed_button.click(
        fn=run_outpaint,
        inputs=[
            model_input,
            sd_prompt,
            sd_strength,
            sd_guidance,
            sd_step,
            init_mode,
            room_id,
            image_key
        ],
        outputs=[params],
    )
# NOTE(review): presumably turns off Gradio's dev-mode behaviour — confirm.
blocks.config['dev_mode'] = False
# Serve the Gradio app under /gradio on the FastAPI app. The API URL points
# at port 7860, the port uvicorn is started on in __main__.
app = gr.mount_gradio_app(app, blocks, "/gradio",
                          gradio_api_url="http://0.0.0.0:7860/gradio/")
def generateAuthToken(timeout: float = 10.0):
    """Fetch a Liveblocks token authorised with the server secret.

    Args:
        timeout: seconds to wait for Liveblocks before giving up. Without a
            timeout, a stalled connection would block the caller forever.

    Returns:
        The ``"token"`` field of the JSON response.

    Raises:
        Exception: ``(status_code, body_text)`` for any non-200 response.
        requests.exceptions.RequestException: on network errors/timeouts.
    """
    response = requests.get(
        "https://liveblocks.io/api/authorize",
        headers={"Authorization": f"Bearer {LIVEBLOCKS_SECRET}"},
        timeout=timeout,
    )
    if response.status_code != 200:
        raise Exception(response.status_code, response.text)
    return response.json()["token"]
def get_room_count(room_id: str, jwtToken: str = '', timeout: float = 10.0):
    """Return how many users Liveblocks reports in *room_id*.

    Args:
        room_id: Liveblocks room identifier.
        jwtToken: bearer token, e.g. from ``generateAuthToken()``.
        timeout: seconds to wait for Liveblocks before giving up.

    Returns:
        ``len(response["data"])``, or 0 when the response has no ``"data"``.

    Raises:
        Exception: for any non-200 response. The status code and body are
            included to make failures diagnosable.
    """
    response = requests.get(
        f"https://liveblocks.net/api/v1/room/{room_id}/users",
        headers={"Authorization": f"Bearer {jwtToken}",
                 "Content-Type": "application/json"},
        timeout=timeout,
    )
    if response.status_code != 200:
        raise Exception("Error getting room count",
                        response.status_code, response.text)
    res = response.json()
    return len(res["data"]) if "data" in res else 0
@app.on_event("startup")
@repeat_every(seconds=120)
def sync_rooms():
    """Refresh ``rooms.users_count`` from Liveblocks for every stored room.

    Registered as a startup hook; ``repeat_every`` re-runs it every 120 s.
    This is a plain ``def`` on purpose. The body only makes blocking HTTP
    (``requests``) and SQLite calls, and fastapi_utils' ``repeat_every`` runs
    non-coroutine functions via a thread pool, so the event loop serving
    requests is not stalled during a sync.

    Errors are printed, never raised. A room whose count cannot be fetched is
    reported and skipped; the remaining rooms are still updated.
    """
    try:
        jwt_token = generateAuthToken()
        for db in get_db():
            rooms = db.execute("SELECT room_id FROM rooms").fetchall()
            for row in rooms:
                room_id = row["room_id"]
                try:
                    users_count = get_room_count(room_id, jwt_token)
                except Exception as e:
                    print(e)
                    print(f"Room {room_id} update failed")
                    continue
                db.execute(
                    "UPDATE rooms SET users_count = ? WHERE room_id = ?",
                    (users_count, room_id))
            db.commit()
    except Exception as e:
        print(e)
        print("Rooms update failed")
@app.get('/api/rooms')
async def get_rooms(db: sqlite3.Connection = Depends(get_db)):
    """List every row of the ``rooms`` table (as ``sqlite3.Row`` objects)."""
    cursor = db.execute("SELECT * FROM rooms")
    return cursor.fetchall()
@app.post('/api/auth')
async def autorize(request: Request, db: sqlite3.Connection = Depends(get_db)):
    """Authorise an anonymous user for a Liveblocks room.

    Expects a JSON body ``{"room": "<room id>"}``. Forwards an authorize
    request for that room to Liveblocks, using a fresh random user id and the
    name "Anon", and returns Liveblocks' JSON response.

    Raises:
        HTTPException(400): when the body has no non-empty string ``room``.
        Exception: ``(status_code, body_text)`` for a non-200 Liveblocks reply.
    """
    from urllib.parse import quote

    data = await request.json()
    room = data.get("room") if isinstance(data, dict) else None
    if not isinstance(room, str) or not room:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Request body must contain a non-empty 'room'",
        )
    payload = {
        "userId": str(shortuuid.uuid()),
        "userInfo": {
            "name": "Anon"
        }}
    # `room` is client-controlled and this request carries the server secret:
    # percent-encode it (including "/") so it stays one path segment.
    response = requests.post(
        f"https://api.liveblocks.io/v2/rooms/{quote(room, safe='')}/authorize",
        headers={"Authorization": f"Bearer {LIVEBLOCKS_SECRET}"},
        json=payload,
        timeout=10,
    )
    if response.status_code != 200:
        raise Exception(response.status_code, response.text)
    # Room user counts are refreshed by the periodic sync_rooms task. It is
    # not called here: sync_rooms is wrapped by repeat_every, so calling it
    # without await is a no-op (an un-awaited coroutine), and awaiting it
    # would start another endless 120 s refresh loop on every login.
    return response.json()
def slugify(value, max_bytes=400):
    """Turn free text into a lowercase, hyphen-separated slug for S3 keys.

    Characters other than word characters, whitespace and hyphens are
    removed. The text is stripped and lower-cased, and runs of whitespace or
    hyphens collapse to a single ``-``.

    Args:
        value: text to slugify (e.g. a prompt).
        max_bytes: maximum length of the result in UTF-8 bytes.

    Returns:
        The slug, truncated to at most *max_bytes* UTF-8 bytes without
        splitting a multi-byte character.
    """
    value = re.sub(r'[^\w\s-]', '', value).strip().lower()
    out = re.sub(r'[-\s]+', '-', value)
    # Truncate on the encoded form: \w also matches non-ASCII letters, and a
    # character-based cut could leave e.g. 400 CJK characters (~1200 bytes),
    # pushing the S3 key past its 1024-byte limit. errors="ignore" drops a
    # multi-byte character cut in half at the boundary.
    return out.encode("utf-8")[:max_bytes].decode("utf-8", "ignore")
async def upload_file(image: Image.Image, prompt: str, room_id: str, image_key: str):
    """Upload *image* as a JPEG to S3 and return its CloudFront URL.

    The object key is ``"<room_id>/<unix time>-<uuid>-<image_key>-<slug>.jpg"``,
    where the slug is ``slugify(prompt)``. An empty/blank ``room_id`` falls
    back to ``"uploads"``.

    Args:
        image: PIL image; converted to RGB before JPEG encoding.
        prompt: text used to build the filename slug (``None`` treated as "").
        room_id: S3 key prefix (``None``/blank → ``"uploads"``).
        image_key: extra filename component (``None`` treated as "").

    Returns:
        ``{"url": <CloudFront URL>, "filename": <filename without prefix>}``.
    """
    room_id = (room_id or "").strip() or "uploads"
    image_key = (image_key or "").strip()
    image = image.convert('RGB')
    uid = shortuuid.uuid()
    date = int(time.time())
    prompt_slug = slugify(prompt or "")
    filename = f"{date}-{uid}-{image_key}-{prompt_slug}.jpg"
    key = f"{room_id}/{filename}"
    # The context manager frees the buffer even if encoding or upload fails.
    with io.BytesIO() as temp_file:
        image.save(temp_file, format="JPEG")
        temp_file.seek(0)
        s3.upload_fileobj(Fileobj=temp_file, Bucket=AWS_S3_BUCKET_NAME, Key=key,
                          ExtraArgs={"ContentType": "image/jpeg",
                                     "CacheControl": "max-age=31536000"})
    return {"url": f'https://d26smi9133w0oo.cloudfront.net/{key}',
            "filename": filename}
@app.post('/api/uploadfile')
async def create_upload_file(file: UploadFile):
    """Validate a user-uploaded image and store it under ``community/`` in S3.

    Checks, in order:
    * size is greater than 0 and below 20E+06 bytes (20 MB);
    * libmagic-detected MIME type is one of ``FILE_TYPES``;
    * the client filename reduces to a usable basename.

    The object is stored with the *detected* MIME type as ``ContentType``,
    not the one the client claimed. A client could otherwise upload a valid
    PNG labelled e.g. ``text/html``, and the CDN would serve it as such.

    Returns:
        ``{"url": <CloudFront URL>, "filename": <stored filename>}``.

    Raises:
        HTTPException(400): on any failed check above.
    """
    contents = await file.read()
    file_size = len(contents)
    if not 0 < file_size < 20E+06:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail='Supported file size is less than 20 MB'
        )
    file_type = magic.from_buffer(contents, mime=True)
    mime = file_type.lower()
    if mime not in FILE_TYPES:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f'Unsupported file type {file_type}. Supported types are {FILE_TYPES}'
        )
    # Keep only the last path component of the client-supplied name, so it
    # cannot add "/"-separated segments to the S3 key.
    filename = os.path.basename((file.filename or "").replace("\\", "/"))
    if filename in ("", ".", ".."):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail='Missing or invalid file name'
        )
    with io.BytesIO(contents) as temp_file:
        s3.upload_fileobj(Fileobj=temp_file, Bucket=AWS_S3_BUCKET_NAME,
                          Key="community/" + filename,
                          ExtraArgs={"ContentType": mime,
                                     "CacheControl": "max-age=31536000"})
    return {"url": f'https://d26smi9133w0oo.cloudfront.net/community/{filename}', "filename": filename}
origins = ["*"]
# Serve the front-end from ../static (resolved against the working
# directory). html=True enables Starlette's HTML mode for directory paths.
app.mount("/", StaticFiles(directory="../static", html=True), name="static")
# NOTE(review): wildcard origins combined with allow_credentials=True lets any
# site make credentialed requests — confirm this is intended.
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
if __name__ == "__main__":
    # Direct launch: serve the combined FastAPI + Gradio app on port 7860,
    # the port the Gradio mount's gradio_api_url points at.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
        log_level="debug",
        reload=False,
    )