ihsanvp committed
Commit bc05b03
1 Parent(s): 0555d33

initial - v0

.gitignore ADDED
@@ -0,0 +1,3 @@
+ .env
+ __pycache__/
+ flagged/
app.py ADDED
@@ -0,0 +1,31 @@
+ import gradio as gr
+ from PIL import Image
+ from models.segmentation import SamSegmentationModel
+ from models.inpainting import KandinskyInpaintingModel
+ from models.product import ProductBackgroundModifier
+ import torch
+
+ def generate(image: Image.Image, prompt: str):
+     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+     # NOTE: the pipeline is rebuilt (and all weights re-loaded) on every
+     # request; hoisting this out of generate() would avoid that cost.
+     model = ProductBackgroundModifier(
+         segmentation_model=SamSegmentationModel(
+             model_type="vit_h",
+             checkpoint_path="model_checkpoints/sam_vit.pth",
+             device=device,
+         ),
+         inpainting_model=KandinskyInpaintingModel(),
+         device=device,
+     )
+     generated = model.generate(image=image, prompt=prompt)
+     return generated
+
+ gr.Interface(
+     fn=generate,
+     inputs=[
+         gr.Image(type="pil"),
+         gr.Text(),
+     ],
+     outputs=gr.Image(type="pil"),
+ ).launch()
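
Because app.py calls .launch() at module import, importing it (say, to test generate from another script) immediately starts the server. A common pattern, shown here as a sketch rather than what the commit does, is to bind the interface to a name and guard the launch:

# Sketch: the same Interface as above, with launch guarded so that
# `from app import generate` stays side-effect free.
demo = gr.Interface(
    fn=generate,
    inputs=[gr.Image(type="pil"), gr.Text()],
    outputs=gr.Image(type="pil"),
)

if __name__ == "__main__":
    demo.launch()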
model_checkpoints/sam_vit.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e
+ size 2564550879
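
The checkpoint itself is stored with Git LFS; at 2,564,550,879 bytes it matches the size of Meta's published SAM ViT-H weights. A minimal fetch sketch, assuming the intended file is the official sam_vit_h_4b8939.pth release (verify the sha256 above after downloading):

import urllib.request
from pathlib import Path

# Assumed source: the official segment-anything ViT-H checkpoint.
SAM_VIT_H_URL = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"

dest = Path("model_checkpoints/sam_vit.pth")
if not dest.exists():
    dest.parent.mkdir(parents=True, exist_ok=True)
    urllib.request.urlretrieve(SAM_VIT_H_URL, dest)  # ~2.5 GB download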
models/inpainting.py ADDED
@@ -0,0 +1,28 @@
+ import torch
+ from diffusers import AutoPipelineForInpainting
+ from PIL import Image
+
+ class InpaintingModel:
+     """Interface: repaint the white region of mask_image according to prompt."""
+     def generate(self, image: Image.Image, mask_image: Image.Image, prompt: str) -> Image.Image:
+         raise NotImplementedError
+
+ class KandinskyInpaintingModel(InpaintingModel):
+     def __init__(
+         self,
+         device=torch.device("cpu"),
+     ) -> None:
+         super().__init__()
+         self.device = device
+         # float16 weights with CPU offload: this assumes a CUDA device (and
+         # the accelerate package) is available at inference time.
+         self.model = AutoPipelineForInpainting.from_pretrained(
+             "kandinsky-community/kandinsky-2-2-decoder-inpaint",
+             torch_dtype=torch.float16,
+         )
+         self.model.enable_model_cpu_offload()
+         self.negative_prompt = "deformed, ugly, disfigured"
+
+     def generate(self, image: Image.Image, mask_image: Image.Image, prompt: str) -> Image.Image:
+         output = self.model(prompt=prompt, negative_prompt=self.negative_prompt, image=image, mask_image=mask_image)
+         return output.images[0]
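
A smoke test for the wrapper in isolation, assuming a CUDA machine with the diffusers checkpoint cached; product.jpg is a placeholder path, and the all-white mask tells the pipeline to repaint everything:

from PIL import Image
from models.inpainting import KandinskyInpaintingModel

model = KandinskyInpaintingModel()

image = Image.open("product.jpg").convert("RGB").resize((1024, 1024))
mask = Image.new("RGB", image.size, "white")  # white = region to repaint

result = model.generate(image=image, mask_image=mask, prompt="a marble countertop, studio lighting")
result.save("inpainted.png")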
models/product.py ADDED
@@ -0,0 +1,30 @@
+ import torch
+ from torchvision import transforms
+ from torchvision.transforms.functional import to_pil_image
+ from models import segmentation, inpainting
+ from PIL import Image
+
+ class ProductBackgroundModifier:
+     def __init__(
+         self,
+         segmentation_model: segmentation.SegmentationModel,
+         inpainting_model: inpainting.InpaintingModel,
+         device=torch.device("cpu"),
+     ) -> None:
+         self.segmentation_model = segmentation_model
+         self.inpainting_model = inpainting_model
+         self.device = device
+         # Scale the shorter edge to 1024, then crop the central 1024x1024 square SAM expects.
+         self.transform = transforms.Compose([
+             transforms.ToTensor(),
+             transforms.Resize(1024),
+             transforms.CenterCrop((1024, 1024)),
+         ])
+
+     def generate(self, image: Image.Image, prompt: str) -> Image.Image:
+         image_tensor = self.transform(image).to(self.device)
+         mask_image = self.segmentation_model.generate(image_tensor)
+         # Inpaint the same resized/cropped view the mask was computed from,
+         # so image and mask stay pixel-aligned.
+         cropped = to_pil_image(image_tensor.cpu())
+         return self.inpainting_model.generate(image=cropped, mask_image=mask_image, prompt=prompt)
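
A quick shape check of the preprocessing above, worth spelling out because Resize(1024) follows torchvision's smaller-edge convention rather than forcing a square on its own:

import torch
from PIL import Image
from torchvision import transforms

transform = transforms.Compose([
    transforms.ToTensor(),                # PIL -> CHW float tensor in [0, 1]
    transforms.Resize(1024),              # shorter edge -> 1024 px
    transforms.CenterCrop((1024, 1024)),  # central square
])

img = Image.new("RGB", (1600, 900))  # any aspect ratio
print(transform(img).shape)          # torch.Size([3, 1024, 1024])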
models/segmentation.py ADDED
@@ -0,0 +1,47 @@
+ import torch
+ from torchvision.transforms.functional import to_pil_image
+ from segment_anything import SamPredictor, sam_model_registry
+ from PIL import Image
+
+ class SegmentationModel:
+     """Interface: produce an inpainting mask (background white) for image."""
+     def generate(self, image: torch.Tensor) -> Image.Image:
+         raise NotImplementedError
+
+ class SamSegmentationModel(SegmentationModel):
+     def __init__(
+         self,
+         model_type: str,
+         checkpoint_path: str,
+         device=torch.device("cpu"),
+     ) -> None:
+         super().__init__()
+         sam = sam_model_registry[model_type](checkpoint=checkpoint_path)
+         sam.to(device)
+         self.device = device
+         self.model = SamPredictor(sam)
+
+     def generate(self, image: torch.Tensor) -> Image.Image:
+         _, H, W = image.size()
+         # SamPredictor.set_torch_image expects BCHW pixels in [0, 255];
+         # the upstream transform produced [0, 1] floats, so rescale.
+         image = image.unsqueeze(0) * 255.0
+         self.model.set_torch_image(image, original_image_size=(H, W))
+         # One positive click at the image centre; point coords are (x, y).
+         center_point = [W / 2, H / 2]
+         input_point = torch.tensor([[center_point]], device=self.device)
+         input_label = torch.tensor([[1]], device=self.device)
+         masks, scores, _ = self.model.predict_torch(
+             point_coords=input_point,
+             point_labels=input_label,
+             boxes=None,
+             multimask_output=True,
+         )
+         masks = masks.squeeze(0)    # (3, H, W) candidate masks
+         scores = scores.squeeze(0)  # (3,) predicted IoU per candidate
+         bmask = masks[torch.argmax(scores).item()]
+         # Invert: the product becomes black, the background white,
+         # which is the region the inpainting model will repaint.
+         mask_float = 1.0 - bmask.float()
+         final = torch.stack([mask_float, mask_float, mask_float])
+         return to_pil_image(final.cpu())
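
The segmenter can be exercised on its own with the same preprocessing ProductBackgroundModifier applies; a sketch, with product.jpg as a hypothetical input photo:

import torch
from PIL import Image
from torchvision import transforms
from models.segmentation import SamSegmentationModel

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
segmenter = SamSegmentationModel(
    model_type="vit_h",
    checkpoint_path="model_checkpoints/sam_vit.pth",
    device=device,
)

to_square = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(1024),
    transforms.CenterCrop((1024, 1024)),
])
image = to_square(Image.open("product.jpg").convert("RGB")).to(device)

mask = segmenter.generate(image)  # PIL image: product black, background white
mask.save("mask.png")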
requirements.txt ADDED
File without changes
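
requirements.txt was committed empty. Judging only from the imports in this commit, a plausible starting point would be the list below; the entries are assumptions (unpinned), and segment-anything is often installed straight from the facebookresearch GitHub repository rather than PyPI:

gradio
torch
torchvision
diffusers
transformers
accelerate
segment-anything
Pillow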