"""Gradio demo: modify a product photo's background according to a text prompt."""

import functools

import gradio as gr
import torch
from PIL import Image

from models.inpainting import KandingskyInpaintingModel
from models.product import ProductBackgroundModifier
from models.segmentation import SamSegmentationModel

# SAM backbone variant and the checkpoint file its weights are loaded from.
SAM_MODEL_TYPE = "vit_h"
SAM_CHECKPOINT_PATH = "model_checkpoints/sam_vit.pth"


@functools.lru_cache(maxsize=1)
def _load_model() -> ProductBackgroundModifier:
    """Build the segmentation + inpainting pipeline on first use and reuse it.

    Construction reads the SAM checkpoint from disk and instantiates the
    inpainting model, which is presumably expensive, so it is done once per
    process rather than once per request.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    return ProductBackgroundModifier(
        segmentation_model=SamSegmentationModel(
            model_type=SAM_MODEL_TYPE,
            checkpoint_path=SAM_CHECKPOINT_PATH,
            device=device,
        ),
        inpainting_model=KandingskyInpaintingModel(),
        device=device,
    )


def generate(image: Image.Image, prompt: str):
    """Run the background-modification pipeline on one image.

    Args:
        image: Input product photo, supplied by Gradio as a PIL image
            (``gr.Image(type="pil")``); ``None`` if the user submitted none.
        prompt: Text prompt forwarded to the pipeline.

    Returns:
        Whatever ``ProductBackgroundModifier.generate`` returns; the output
        component is ``gr.Image(type="pil")``, so presumably a PIL image.

    Raises:
        gr.Error: If no image was provided, shown to the user in the UI.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    model = _load_model()
    return model.generate(image=image, prompt=prompt)


demo = gr.Interface(
    fn=generate,
    inputs=[
        gr.Image(type="pil"),
        gr.Text(),
    ],
    outputs=gr.Image(type="pil"),
)

# Only start the server when run as a script, so the module stays importable.
if __name__ == "__main__":
    demo.launch()