abyildirim commited on
Commit
94a0cd2
1 Parent(s): 005f2dd

initial commit

Browse files
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ __pycache__
app.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import torch
4
+ from PIL import Image
5
+
6
+ import constants
7
+ import utils
8
+
9
+ PREDICTOR = None
10
+
11
+
12
+ def inference(image: np.ndarray, text: str, center_crop: bool):
13
+ num_steps = 10
14
+ if not text.lower().startswith("remove the"):
15
+ raise gr.Error("Instruction should start with 'Remove the' !")
16
+
17
+ image = Image.fromarray(image)
18
+ cropped_image, image = utils.preprocess_image(image, center_crop=center_crop)
19
+
20
+ utils.seed_everything()
21
+ prediction = PREDICTOR.predict(image, text, num_steps)
22
+
23
+ print("Num steps:", num_steps)
24
+
25
+ return cropped_image, prediction
26
+
27
+
28
+ if __name__ == "__main__":
29
+ utils.setup_environment()
30
+
31
+ if not PREDICTOR:
32
+ PREDICTOR = utils.get_predictor()
33
+
34
+ sample_image, sample_instruction, sample_step = constants.EXAMPLES[3]
35
+
36
+ gr.Interface(
37
+ fn=inference,
38
+ inputs=[
39
+ gr.Image(type="numpy", value=sample_image, label="Source Image").style(
40
+ height=256
41
+ ),
42
+ gr.Textbox(
43
+ label="Instruction",
44
+ lines=1,
45
+ value=sample_instruction,
46
+ ),
47
+ gr.Checkbox(value=True, label="Center Crop", interactive=False),
48
+ ],
49
+ outputs=[
50
+ gr.Image(type="pil", label="Cropped Image").style(height=256),
51
+ gr.Image(type="pil", label="Output Image").style(height=256),
52
+ ],
53
+ allow_flagging="never",
54
+ examples=constants.EXAMPLES,
55
+ cache_examples=True,
56
+ title=constants.TITLE,
57
+ description=constants.DESCRIPTION,
58
+ ).launch()
constants.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ TITLE = "Inst-Inpaint: Instructing to Remove Objects with Diffusion Models"
2
+
3
+ DESCRIPTION = """
4
+ <p style='text-align: center'>
5
+ <a href='http://instinpaint.abyildirim.com' target='_blank'>Project Page</a> |
6
+ <a href='https://arxiv.org/abs/2304.03246' target='_blank'>Paper</a> |
7
+ <a href='https://github.com/abyildirim/inst-inpaint' target='_blank'>GitHub Repo</a> |
8
+ </p>
9
+ <p style='text-align: center'>
10
+ This demo demonstrates the Inst-Inpaint's abilities for instruction-based image inpainting.
11
+ </p>
12
+ """
13
+
14
+ EXAMPLES = [
15
+ ["examples/kite-boy.png", "Remove the colorful kite", True],
16
+ ["examples/cat-car.jpg", "Remove the car", True],
17
+ ["examples/bus-tree.jpg", "Remove the bus", True],
18
+ ["examples/cups.webp", "Remove the cup at the left", True],
19
+ ["examples/woman-fantasy.jpg", "Remove the woman", True],
20
+ ["examples/clock.png", "Remove the round clock at the center", True],
21
+ ["examples/woman.png", "Remove the woman at the left", True],
22
+ ["examples/men.png", "Remove the man at the right", True],
23
+ ["examples/tree.png", "Remove the tree", True],
24
+ ["examples/birds.png", "Remove the bird at the right of the bird", True]
25
+ ]
examples/birds.png ADDED
examples/bus-tree.jpg ADDED
examples/cat-car.jpg ADDED
examples/clock.png ADDED
examples/cups.webp ADDED
examples/kite-boy.png ADDED
examples/men.png ADDED
examples/tree.png ADDED
examples/woman-fantasy.jpg ADDED
examples/woman.png ADDED
requirements.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ -f https://download.pytorch.org/whl/torch_stable.html
2
+ git+https://github.com/openai/CLIP.git
3
+ torch==1.13.1+cpu
4
+ torchvision==0.14.1+cpu
5
+ pytorch-lightning==1.6.5
6
+ taming-transformers-rom1504==0.0.6
7
+ einops==0.6.0
8
+ kornia==0.6.11
9
+ transformers==4.27.4
10
+ dill==0.3.6
11
+ gradio==3.24.1
12
+ gdown==4.7.1
13
+ torchmetrics==0.11.4
utils.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ import random
4
+ import tarfile
5
+ from typing import Tuple
6
+
7
+ import dill
8
+ import gdown
9
+ import numpy as np
10
+ import torch
11
+ from PIL import Image
12
+ from torchvision.transforms import ToTensor
13
+
14
+ logger = logging.getLogger(__file__)
15
+
16
+ to_tensor = ToTensor()
17
+
18
+
19
+ def preprocess_image(
20
+ image: Image, resize_shape: Tuple[int, int] = (256, 256), center_crop=True
21
+ ):
22
+ processed_image = image
23
+
24
+ if center_crop:
25
+ width, height = image.size
26
+ crop_size = min(width, height)
27
+
28
+ left = (width - crop_size) // 2
29
+ top = (height - crop_size) // 2
30
+ right = (width + crop_size) // 2
31
+ bottom = (height + crop_size) // 2
32
+
33
+ processed_image = image.crop((left, top, right, bottom))
34
+
35
+ processed_image = processed_image.resize(resize_shape)
36
+
37
+ image = to_tensor(processed_image)
38
+ image = image.unsqueeze(0) * 2 - 1
39
+
40
+ return processed_image, image
41
+
42
+
43
+ def download_artifacts(output_path: str):
44
+ logger.error("Downloading the model artifacts...")
45
+ if not os.path.exists(output_path):
46
+ gdown.download(id=os.environ["GDRIVE_ID"], output=output_path, quiet=True)
47
+
48
+
49
+ def extract_artifacts(path: str):
50
+ logger.error("Extracting the model artifacts...")
51
+ if not os.path.exists("model.pkl"):
52
+ with tarfile.open(path) as tar:
53
+ tar.extractall()
54
+
55
+
56
+ def setup_environment():
57
+ os.environ["PYTHONPATH"] = os.getcwd()
58
+
59
+ artifacts_path = "artifacts.tar.gz"
60
+
61
+ download_artifacts(output_path=artifacts_path)
62
+
63
+ extract_artifacts(path=artifacts_path)
64
+
65
+
66
+ def get_predictor():
67
+ logger.error("Loading the predictor...")
68
+ with open("model.pkl", "rb") as fp:
69
+ return dill.load(fp)
70
+
71
+
72
+ def seed_everything(seed: int = 0):
73
+ random.seed(seed)
74
+ os.environ["PYTHONHASHSEED"] = str(seed)
75
+ np.random.seed(seed)
76
+ torch.manual_seed(seed)
77
+ torch.cuda.manual_seed(seed)
78
+ torch.backends.cudnn.deterministic = True