sandrawang1031 commited on
Commit
eca813c
1 Parent(s): dce3dbb
Files changed (4) hide show
  1. .gitignore +138 -0
  2. app.py +45 -0
  3. model.py +147 -0
  4. requirements.txt +6 -0
.gitignore ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ pip-wheel-metadata/
24
+ share/python-wheels/
25
+ *.egg-info/
26
+ .installed.cfg
27
+ *.egg
28
+ MANIFEST
29
+ .docker/
30
+
31
+ # PyInstaller
32
+ # Usually these files are written by a python script from a template
33
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
34
+ *.manifest
35
+ *.spec
36
+
37
+ # Installer logs
38
+ pip-log.txt
39
+ pip-delete-this-directory.txt
40
+
41
+ # Unit test / coverage reports
42
+ htmlcov/
43
+ .tox/
44
+ .nox/
45
+ .coverage
46
+ .coverage.*
47
+ .cache
48
+ nosetests.xml
49
+ coverage.xml
50
+ *.cover
51
+ *.py,cover
52
+ .hypothesis/
53
+ .pytest_cache/
54
+
55
+ # Translations
56
+ *.mo
57
+ *.pot
58
+
59
+ # Django stuff:
60
+ *.log
61
+ local_settings.py
62
+ db.sqlite3
63
+ db.sqlite3-journal
64
+
65
+ # Flask stuff:
66
+ instance/
67
+ .webassets-cache
68
+
69
+ # Scrapy stuff:
70
+ .scrapy
71
+
72
+ # Sphinx documentation
73
+ docs/_build/
74
+
75
+ # PyBuilder
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pipenv
86
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
87
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
88
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
89
+ # install all needed dependencies.
90
+ #Pipfile.lock
91
+
92
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
93
+ __pypackages__/
94
+
95
+ # Celery stuff
96
+ celerybeat-schedule
97
+ celerybeat.pid
98
+
99
+ # SageMath parsed files
100
+ *.sage.py
101
+
102
+ # Environments
103
+ .env
104
+ # direnv
105
+ .envrc
106
+ .venv
107
+ env/
108
+ venv/
109
+ ENV/
110
+ env.bak/
111
+ venv.bak/
112
+ docker-compose-interpreter-local.yml
113
+
114
+ # Spyder project settings
115
+ .spyderproject
116
+ .spyproject
117
+
118
+ # Rope project settings
119
+ .ropeproject
120
+
121
+ # mkdocs documentation
122
+ /site
123
+
124
+ # mypy
125
+ .mypy_cache/
126
+ .dmypy.json
127
+ dmypy.json
128
+
129
+ # Pyre type checker
130
+ .pyre/
131
+
132
+ # ide
133
+ .idea/
134
+ .vscode/
135
+
136
+ # macos
137
+ .DS_Store
138
+ .envrc
app.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+
4
+ from model import VirtualStagingToolV2
5
+
6
+
7
+ def predict(image, style, color_preference):
8
+ init_image = image.convert("RGB").resize((512, 512))
9
+ # mask = dict["mask"].convert("RGB").resize((512, 512))
10
+
11
+ vs_tool = VirtualStagingToolV2(diffusion_version="stabilityai/stable-diffusion-2-inpainting")
12
+ output_images, transparent_mask_image = vs_tool.virtual_stage(
13
+ image=init_image, style=style, color_preference=color_preference, number_images=1)
14
+ return output_images[0], transparent_mask_image, gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)
15
+
16
+
17
+ image_blocks = gr.Blocks()
18
+ with image_blocks as demo:
19
+ with gr.Group():
20
+ with gr.Box():
21
+ with gr.Row():
22
+ with gr.Column():
23
+ image = gr.Image(source='upload', elem_id="image_upload",
24
+ type="pil", label="Upload",
25
+ ).style(height=400)
26
+ with gr.Row(elem_id="prompt-container").style(mobile_collapse=False, equal_height=True):
27
+ style = gr.Dropdown(
28
+ ["Mordern", "Coastal", "French country"],
29
+ label="Design theme", elem_id="input-color"
30
+ )
31
+
32
+ color_preference = gr.Textbox(placeholder='Enter color preference',
33
+ label="Color preference", elem_id="input-color")
34
+ btn = gr.Button("Inpaint!").style(
35
+ margin=False,
36
+ rounded=(False, True, True, False),
37
+ full_width=False,
38
+ )
39
+ with gr.Column():
40
+ mask_image = gr.Image(label="Mask image", elem_id="mask-img").style(height=400)
41
+ image_out = gr.Image(label="Output", elem_id="output-img").style(height=400)
42
+
43
+ btn.click(fn=predict, inputs=[image, style, color_preference], outputs=[image_out, mask_image])
44
+
45
+ image_blocks.launch()
model.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import defaultdict
2
+ import matplotlib.pyplot as plt
3
+ import matplotlib.patches as mpatches
4
+ from matplotlib import cm
5
+
6
+ from PIL import Image
7
+
8
+ import torch
9
+ from transformers import AutoImageProcessor, UperNetForSemanticSegmentation
10
+ from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation
11
+ from diffusers import StableDiffusionInpaintPipeline
12
+
13
+
14
+ class VirtualStagingToolV2():
15
+
16
+ def __init__(self,
17
+ segmentation_version='openmmlab/upernet-convnext-tiny',
18
+ diffusion_version="stabilityai/stable-diffusion-2-inpainting"
19
+ ):
20
+
21
+ self.segmentation_version = segmentation_version
22
+ self.diffusion_version = diffusion_version
23
+
24
+ self.feature_extractor = AutoImageProcessor.from_pretrained(self.segmentation_version)
25
+ self.segmentation_model = UperNetForSemanticSegmentation.from_pretrained(self.segmentation_version)
26
+
27
+ self.diffution_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
28
+ self.diffusion_version,
29
+ torch_dtype=torch.float32,
30
+ )
31
+ self.diffution_pipeline = self.diffution_pipeline.to("cpu")
32
+
33
+ def _predict(self, image):
34
+ inputs = self.feature_extractor(images=image, return_tensors="pt")
35
+ outputs = self.segmentation_model(**inputs)
36
+ prediction = \
37
+ self.feature_extractor.post_process_semantic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
38
+ return prediction
39
+
40
+ def _save_mask(self, img, prediction_array, mask_items=[]):
41
+ mask = np.zeros_like(prediction_array, dtype=np.uint8)
42
+
43
+ mask[np.isin(prediction_array, mask_items)] = 0
44
+ mask[~np.isin(prediction_array, mask_items)] = 255
45
+
46
+ # # # Create a PIL Image object from the mask
47
+ mask_image = Image.fromarray(mask, mode='L')
48
+ # display(mask_image)
49
+
50
+ # mask_image = mask_image.resize((512, 512))
51
+ # mask_image.save(".tmp/mask_1.png", "PNG")
52
+ # img = img.resize((512, 512))
53
+ # img.save(".tmp/input_1.png", "PNG")
54
+ return mask_image
55
+
56
+ def _save_transparent_mask(self, img, prediction_array, mask_items=[]):
57
+ mask = np.array(img)
58
+ mask[~np.isin(prediction_array, mask_items), :] = 255
59
+ mask_image = Image.fromarray(mask).convert('RGBA')
60
+
61
+ # Set the transparency of the pixels corresponding to object 1 to 0 (fully transparent)
62
+ mask_data = mask_image.getdata()
63
+ mask_data = [(r, g, b, 0) if r == 255 else (r, g, b, 255) for (r, g, b, a) in mask_data]
64
+ mask_image.putdata(mask_data)
65
+
66
+ return mask_image
67
+
68
+ def get_mask(self, image_path=None, image=None):
69
+ if image_path:
70
+ image = Image.open(image_path)
71
+ else:
72
+ if not image:
73
+ raise ValueError("no image provided")
74
+
75
+ # display(image)
76
+ prediction = self._predict(image)
77
+
78
+ label_ids = np.unique(prediction)
79
+
80
+ mask_items = [0, 3, 5, 8, 14]
81
+
82
+ if 1 in label_ids or 25 in label_ids:
83
+ mask_items = [1, 2, 4, 25, 32]
84
+ room = 'backyard'
85
+ elif 73 in label_ids or 50 in label_ids or 61 in label_ids:
86
+ mask_items = [0, 3, 5, 8, 14, 50, 61, 71, 118, 124, 129
87
+ ]
88
+ room = 'kitchen'
89
+ elif 37 in label_ids or 65 in label_ids or (27 in label_ids and 47 in label_ids and 70 in label_ids):
90
+ mask_items = [0, 3, 5, 8, 14, 27, 65]
91
+ room = 'bathroom'
92
+ elif 7 in label_ids:
93
+ room = 'bedroom'
94
+ elif 23 in label_ids or 49 in label_ids:
95
+ room = 'living room'
96
+
97
+ label_ids_without_mask = [i for i in label_ids if i not in mask_items]
98
+
99
+ items = [self.segmentation_model.config.id2label[i] for i in label_ids_without_mask]
100
+
101
+ mask_image = self._save_mask(image, prediction, mask_items)
102
+ transparent_mask_image = self._save_transparent_mask(image, prediction, mask_items)
103
+ return mask_image, transparent_mask_image, image, items, room
104
+
105
+ def _edit_image(self, init_image, mask_image, prompt, # height, width,
106
+ number_images=1):
107
+
108
+ init_image = init_image.resize((512, 512)).convert("RGB")
109
+ mask_image = mask_image.resize((512, 512)).convert("RGB")
110
+
111
+ display(init_image)
112
+ display(mask_image)
113
+
114
+ output_images = self.diffution_pipeline(
115
+ prompt=prompt, image=init_image, mask_image=mask_image,
116
+ # width=width, height=height,
117
+ num_images_per_prompt=number_images).images
118
+ # display(output_image)
119
+ return output_images
120
+
121
+ def virtual_stage(self, image_path=None, image=None, style=None, color_preference=None, number_images=1):
122
+ mask_image, transparent_mask_image, init_image, items, room = self.get_mask(image_path, image)
123
+ if not style:
124
+ raise ValueError('style not provided.')
125
+ if not color_preference:
126
+ raise ValueError('color_preference not provided.')
127
+
128
+ if room == 'kitchen':
129
+ items = [i for i in items if i in ['kitchen island', 'cabinet', 'shelf', 'counter', 'countertop', 'stool']]
130
+ elif room == 'bedroom':
131
+ items = [i for i in items if i in ['bed', 'table', 'chest of drawers', 'desk', 'armchair', 'wardrobe']]
132
+ elif room == 'bathroom':
133
+ items = [i for i in items if
134
+ i in ['shower', 'bathtub', 'chest of drawers', 'counter', 'countertop', 'sink']]
135
+
136
+ items = ', '.join(items)
137
+ prompt = f'{items}, high resolution, in the {style} style {room} in {color_preference}'
138
+ print(prompt)
139
+
140
+ output_images = self._edit_image(init_image, mask_image, prompt, number_images)
141
+
142
+ final_output_images = []
143
+ for output_image in output_images:
144
+ display(output_image)
145
+ output_image = output_image.resize(init_image.size)
146
+ final_output_images.append(output_image)
147
+ return final_output_images, transparent_mask_image
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ transformers==4.29.0
2
+ torch==1.11.0
3
+ diffusers==0.16.1
4
+ accelerate==0.19.0
5
+ matplotlib==3.6.2
6
+ pillow==9.2.0