RunningYou commited on
Commit
ffab2fd
1 Parent(s): b3556d1

inpainting

Browse files
.idea/$CACHE_FILE$ ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="ProjectInspectionProfilesVisibleTreeState">
4
+ <entry key="Project Default">
5
+ <profile-state>
6
+ <expanded-state>
7
+ <State>
8
+ <id />
9
+ </State>
10
+ <State>
11
+ <id>General</id>
12
+ </State>
13
+ </expanded-state>
14
+ <selected-state>
15
+ <State>
16
+ <id>AngularJS</id>
17
+ </State>
18
+ </selected-state>
19
+ </profile-state>
20
+ </entry>
21
+ </component>
22
+ </project>
.idea/dictionaries ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="ProjectDictionaryState">
4
+ <dictionary name="running_you" />
5
+ </component>
6
+ </project>
.idea/encodings.xml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="Encoding" defaultCharsetForPropertiesFiles="UTF-8">
4
+ <file url="PROJECT" charset="UTF-8" />
5
+ </component>
6
+ </project>
.idea/inspectionProfiles/profiles_settings.xml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ <component name="InspectionProjectProfileManager">
2
+ <settings>
3
+ <option name="PROJECT_PROFILE" />
4
+ </settings>
5
+ </component>
.idea/misc.xml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="JavaScriptSettings">
4
+ <option name="languageLevel" value="ES6" />
5
+ </component>
6
+ </project>
.idea/modules.xml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="ProjectModuleManager">
4
+ <modules>
5
+ <module fileurl="file://$PROJECT_DIR$/.idea/new_inpainting.iml" filepath="$PROJECT_DIR$/.idea/new_inpainting.iml" />
6
+ </modules>
7
+ </component>
8
+ </project>
.idea/new_inpainting.iml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <module type="PYTHON_MODULE" version="4">
3
+ <component name="NewModuleRootManager">
4
+ <content url="file://$MODULE_DIR$" />
5
+ <orderEntry type="jdk" jdkName="Python 3.7 (modelscope)" jdkType="Python SDK" />
6
+ <orderEntry type="sourceFolder" forTests="false" />
7
+ </component>
8
+ </module>
.idea/workspace.xml ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="ProjectId" id="2Id7cZROeBLwSmWyYcqsSXcYO5h" />
4
+ <component name="PropertiesComponent">
5
+ <property name="settings.editor.selected.configurable" value="editor.preferences.import" />
6
+ </component>
7
+ <component name="PyConsoleOptionsProvider">
8
+ <option name="myPythonConsoleState">
9
+ <console-settings sdk-home="$USER_HOME$/anaconda3/bin/python" working-directory="$PROJECT_DIR$/..">
10
+ <option name="mySdkHome" value="$USER_HOME$/anaconda3/bin/python" />
11
+ <option name="myWorkingDirectory" value="$PROJECT_DIR$/.." />
12
+ </console-settings>
13
+ </option>
14
+ </component>
15
+ <component name="VcsContentAnnotationSettings">
16
+ <option name="myLimit" value="2678400000" />
17
+ </component>
18
+ </project>
app.py CHANGED
@@ -1,8 +1,94 @@
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
 
6
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- iface.launch()
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import imageio
2
+ from PIL import Image
3
+ import math
4
+ import streamlit as st
5
+ import numpy as np
6
+ import torch
7
+ import PIL
8
+ import cv2
9
+ import mediapipe as mp
10
  import gradio as gr
11
+ from diffusers import StableDiffusionInpaintPipeline
12
 
13
+ YOUR_TOKEN = st.secrets['USER_TOKEN']
 
14
 
15
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
16
+ model_path = "runwayml/stable-diffusion-inpainting"
17
 
18
+ if torch.cuda.is_available():
19
+ pipe = StableDiffusionInpaintPipeline.from_pretrained(model_path, revision="fp16", torch_dtype=torch.float16,
20
+ use_auth_token=YOUR_TOKEN).to(device)
21
+ else:
22
+ pipe = StableDiffusionInpaintPipeline.from_pretrained(model_path, use_auth_token=YOUR_TOKEN).to(device)
23
+
24
+
25
+ def image_grid(imgs, rows, cols):
26
+ assert len(imgs) == rows * cols
27
+
28
+ w, h = imgs[0].size
29
+ grid = PIL.Image.new('RGB', size=(cols * w, rows * h))
30
+ grid_w, grid_h = grid.size
31
+
32
+ for i, img in enumerate(imgs):
33
+ grid.paste(img, box=(i % cols * w, i // cols * h))
34
+ return grid
35
+
36
+
37
+ def mediapipe_segmentation(image_file, mask_file):
38
+ mp_drawing = mp.solutions.drawing_utils
39
+ mp_selfie_segmentation = mp.solutions.selfie_segmentation
40
+
41
+ # For static images:
42
+ BG_COLOR = (0, 0, 0) # gray
43
+ MASK_COLOR = (255, 255, 255) # white
44
+ with mp_selfie_segmentation.SelfieSegmentation(model_selection=0) as selfie_segmentation:
45
+ image = cv2.imread(image_file)
46
+ image_height, image_width, _ = image.shape
47
+ # Convert the BGR image to RGB before processing.
48
+ results = selfie_segmentation.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
49
+
50
+ # blurred_image = cv2.GaussianBlur(image,(55,55),0)
51
+ # condition = np.stack((results.segmentation_mask,) * 3, axis=-1) > 0.1
52
+ # output_image = np.where(condition, image, blurred_image)
53
+
54
+ # Draw selfie segmentation on the background image.
55
+ # To improve segmentation around boundaries, consider applying a joint
56
+ # bilateral filter to "results.segmentation_mask" with "image".
57
+ condition = np.stack((results.segmentation_mask,) * 3, axis=-1) > 0.1
58
+ # Generate solid color images for showing the output selfie segmentation mask.
59
+ fg_image = np.zeros(image.shape, dtype=np.uint8)
60
+ fg_image[:] = MASK_COLOR
61
+ bg_image = np.zeros(image.shape, dtype=np.uint8)
62
+ bg_image[:] = BG_COLOR
63
+ output_image = np.where(condition, fg_image, bg_image)
64
+ cv2.imwrite(mask_file, output_image)
65
+
66
+
67
+ def image_inpainting(prompt, image_path, mask_image_path, num_samples=4):
68
+ image = PIL.Image.open(image_path).convert("RGB").resize((512, 512))
69
+ mask_image = PIL.Image.open(mask_image_path).convert("RGB").resize((512, 512))
70
+
71
+ guidance_scale = 7.5
72
+ generator = torch.Generator(device=device).manual_seed(0) # change the seed to get different results
73
+
74
+ images = pipe(prompt=prompt, image=image, mask_image=mask_image, guidance_scale=guidance_scale, generator=generator,
75
+ num_images_per_prompt=num_samples).images
76
+
77
+ # insert initial image in the list so we can compare side by side
78
+ # images.insert(0, image)
79
+ return image_grid(images, 2, math.ceil(num_samples/2))
80
+
81
+
82
+ def predict1(dict, prompt):
83
+ dict['image'].save('image.png')
84
+ dict['mask'].save('mask.png')
85
+ mediapipe_segmentation('image.png', 'm_mask.png')
86
+
87
+ image = image_inpainting(prompt, image_path='image.png', mask_image_path='m_mask.png')
88
+ return image
89
+
90
+
91
+ title = "Person Matting & Stable Diffusion In-Painting"
92
+ description = "Inpainting Stable Diffusion <br/><b>mediapipe + Stable Diffusion<b/><br/>"
93
+ gr.Interface(predict1, inputs=[gr.Image(source='upload', tool='sketch', type='pil'), gr.Textbox(label='prompt')],
94
+ outputs='image', title=title, description=description).launch(max_threads=True)
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ diffusers>=0.6.0
2
+ transformers
3
+ mediapipe
4
+ scipy
5
+ ftfy
6
+ imageio
7
+ torch
8
+ streamlit