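# Source page metadata (captured from the hosting site's file viewer):
#   author: ihsanvp
#   commit: "initial - v0" (bc05b03)
#   viewer links: raw | history | blame | contribute | delete
#   scan status: No virus
#   size: 870 Bytes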
import gradio as gr
from PIL import Image
from models.segmentation import SamSegmentationModel
from models.inpainting import KandingskyInpaintingModel
from models.product import ProductBackgroundModifier
import torch
# Pipelines already built, keyed by the device string they were created for.
# Building one loads the SAM checkpoint and the inpainting model, so it is
# done lazily once per device instead of on every request.
_MODEL_CACHE = {}


def _get_model(device):
    """Return the ProductBackgroundModifier for *device*, building it on first use."""
    key = str(device)
    if key not in _MODEL_CACHE:
        _MODEL_CACHE[key] = ProductBackgroundModifier(
            segmentation_model=SamSegmentationModel(
                model_type="vit_h",
                checkpoint_path="model_checkpoints/sam_vit.pth",
                device=device,
            ),
            inpainting_model=KandingskyInpaintingModel(),
            device=device,
        )
    return _MODEL_CACHE[key]


def generate(image: Image.Image, prompt: str):
    """Run the product-background pipeline on *image* guided by *prompt*.

    The pipeline runs on ``cuda:0`` when CUDA is available and on the CPU
    otherwise. The underlying models are constructed on the first call and
    reused afterwards.

    Args:
        image: Input picture, as delivered by the ``gr.Image(type="pil")``
            input component.
        prompt: Text passed through to the pipeline's ``generate`` call.

    Returns:
        Whatever ``ProductBackgroundModifier.generate`` returns; the UI wires
        it to a ``gr.Image(type="pil")`` output, so presumably a PIL image
        (confirm against ``models.product``).
    """
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model = _get_model(device)
    return model.generate(image=image, prompt=prompt)
# Web UI: one image input and one text input feed `generate`; the result is
# shown as an image. The server is started as soon as this module runs.
demo = gr.Interface(
    fn=generate,
    inputs=[gr.Image(type="pil"), gr.Text()],
    outputs=gr.Image(type="pil"),
)
demo.launch()