File size: 870 Bytes
bc05b03
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import functools

import gradio as gr
import torch
from PIL import Image

from models.inpainting import KandingskyInpaintingModel
from models.product import ProductBackgroundModifier
from models.segmentation import SamSegmentationModel

@functools.lru_cache(maxsize=1)
def _load_model() -> ProductBackgroundModifier:
    """Build the segmentation + inpainting pipeline once and reuse it.

    The pipeline wraps a SAM ``vit_h`` segmentation model (created from the
    ``model_checkpoints/sam_vit.pth`` checkpoint path) and a Kandinsky
    inpainting model. Constructing these is presumably expensive (weight
    loading / device placement), so the result is memoized for the lifetime
    of the process instead of being rebuilt on every request.
    """
    # Prefer the first GPU when available; otherwise run on CPU.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    return ProductBackgroundModifier(
        segmentation_model=SamSegmentationModel(
            model_type="vit_h",
            checkpoint_path="model_checkpoints/sam_vit.pth",
            device=device,
        ),
        inpainting_model=KandingskyInpaintingModel(),
        device=device,
    )


def generate(image: Image.Image, prompt: str):
    """Regenerate the background of a product photo according to *prompt*.

    Gradio callback for the interface below. The model pipeline is built
    lazily on the first call and reused for all later calls.

    Args:
        image: Uploaded product photo. Gradio passes ``None`` when the user
            submits without uploading an image.
        prompt: Text describing the desired background.

    Returns:
        The result of ``ProductBackgroundModifier.generate``. The output
        component is declared with ``type="pil"``, so this is presumably a
        PIL image — confirm against ``models.product``.

    Raises:
        gr.Error: If no image was provided, so the UI shows a readable
            message instead of an internal traceback.
    """
    if image is None:
        raise gr.Error("Please upload a product image.")
    return _load_model().generate(image=image, prompt=prompt)

# Web UI: a product image plus a text prompt in, the edited image out.
demo = gr.Interface(
    fn=generate,
    inputs=[
        gr.Image(type="pil"),
        gr.Text(),
    ],
    outputs=gr.Image(type="pil"),
)

# Start the web server only when run as a script, so importing this module
# (from tests, another app, etc.) does not launch a server as a side effect.
if __name__ == "__main__":
    demo.launch()