ariG23498 committed on
Commit
755aa6f
1 Parent(s): 2955389
Files changed (2)
  1. app.py +191 -0
  2. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,191 @@
+ import gradio as gr
+ import torch
+ from diffusers import AutoPipelineForInpainting
+ from PIL import Image
+ from transformers import (
+     AutoModelForCausalLM,
+     AutoTokenizer,
+     BlipForConditionalGeneration,
+     BlipProcessor,
+     OwlViTForObjectDetection,
+     OwlViTProcessor,
+     SamModel,
+     SamProcessor,
+ )
+
+
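+ # Moving a finished model to the CPU, deleting it, and clearing the CUDA cache
+ # frees GPU memory before the next stage loads its own model.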
+ def delete_model(model):
+     model.to("cpu")
+     del model
+     torch.cuda.empty_cache()
+
+
+ def run_language_model(edit_prompt, device):
+     language_model_id = "Qwen/Qwen1.5-0.5B-Chat"
+     language_model = AutoModelForCausalLM.from_pretrained(
+         language_model_id, device_map="auto"
+     )
+     tokenizer = AutoTokenizer.from_pretrained(language_model_id)
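+     # Few-shot chat prompt: the three example turns show the model that it
+     # should answer with "<object to replace>, <replacement object>".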
+     messages = [
+         {
+             "role": "system",
+             "content": "Follow the examples and return the expected output",
+         },
+         {"role": "user", "content": "swap mountain and lion"},  # example 1
+         {"role": "assistant", "content": "mountain, lion"},  # example 1
+         {"role": "user", "content": "change the dog with cat"},  # example 2
+         {"role": "assistant", "content": "dog, cat"},  # example 2
+         {"role": "user", "content": "replace the human with a boat"},  # example 3
+         {"role": "assistant", "content": "human, boat"},  # example 3
+         {"role": "user", "content": edit_prompt},
+     ]
+     text = tokenizer.apply_chat_template(
+         messages, tokenize=False, add_generation_prompt=True
+     )
+     model_inputs = tokenizer([text], return_tensors="pt").to(device)
+     generated_ids = language_model.generate(model_inputs.input_ids, max_new_tokens=512)
+     generated_ids = [
+         output_ids[len(input_ids) :]
+         for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
+     ]
+     response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
+     to_replace, replace_with = response.split(", ")
+
+     delete_model(language_model)
+     return (to_replace, replace_with)
+
+
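+ # BLIP produces a caption of the input image; the target object in this caption
+ # is later swapped for the replacement to form the inpainting prompt.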
+ def run_image_captioner(image, device):
+     caption_model_id = "Salesforce/blip-image-captioning-base"
+     caption_model = BlipForConditionalGeneration.from_pretrained(caption_model_id).to(
+         device
+     )
+     caption_processor = BlipProcessor.from_pretrained(caption_model_id)
+     inputs = caption_processor(image, return_tensors="pt").to(device)
+     with torch.no_grad():
+         outputs = caption_model.generate(**inputs, max_new_tokens=200)
+     caption = caption_processor.decode(outputs[0], skip_special_tokens=True)
+
+     delete_model(caption_model)
+     return caption
+
+
+ def run_segmentation(image, object_to_segment, device):
+     # OWL-ViT for object detection
+     owl_vit_model_id = "google/owlvit-base-patch32"
+     processor = OwlViTProcessor.from_pretrained(owl_vit_model_id)
+     od_model = OwlViTForObjectDetection.from_pretrained(owl_vit_model_id).to(device)
+     text_queries = [object_to_segment]
+     inputs = processor(text=text_queries, images=image, return_tensors="pt").to(device)
+     with torch.no_grad():
+         outputs = od_model(**inputs)
+     target_sizes = torch.tensor([image.size]).to(device)
+     results = processor.post_process_object_detection(
+         outputs, threshold=0.1, target_sizes=target_sizes
+     )[0]
+
+     boxes = results["boxes"].tolist()
+
+     delete_model(od_model)
+
+     # SAM for image segmentation
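+     # The boxes predicted by OWL-ViT are passed to SAM as box prompts, so the
+     # returned masks cover only the object that has to be replaced.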
+     sam_model_id = "facebook/sam-vit-base"
+     seg_model = SamModel.from_pretrained(sam_model_id).to(device)
+     processor = SamProcessor.from_pretrained(sam_model_id)
+     input_boxes = [boxes]
+     inputs = processor(image, input_boxes=input_boxes, return_tensors="pt").to(device)
+     with torch.no_grad():
+         outputs = seg_model(**inputs)
+     masks = processor.image_processor.post_process_masks(
+         outputs.pred_masks.cpu(),
+         inputs["original_sizes"].cpu(),
+         inputs["reshaped_input_sizes"].cpu(),
+     )
+
+     delete_model(seg_model)
+     return masks
+
+
+ def run_inpainting(image, replaced_caption, masks, device):
+     pipeline = AutoPipelineForInpainting.from_pretrained(
+         "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
+         torch_dtype=torch.float16,
+         variant="fp16",
+     ).to(device)
+
+     prompt = replaced_caption
+     negative_prompt = """lowres, bad anatomy, bad hands,
+     text, error, missing fingers, extra digit, fewer digits,
+     cropped, worst quality, low quality"""
+
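+     # masks[0][0] holds the mask predictions for the first detected box of the
+     # first image; the first of those masks is used as the inpainting mask.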
+     output = pipeline(
+         prompt=prompt,
+         image=image,
+         mask_image=Image.fromarray(masks[0][0][0, :, :].numpy()),
+         negative_prompt=negative_prompt,
+         guidance_scale=7.5,
+         strength=0.6,
+     ).images[0]
+
+     delete_model(pipeline)
+     return output
+
+
+ def run_open_gen_fill(image, edit_prompt):
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+
+     # Resize the image to (512, 512)
+     image = image.resize((512, 512))
+
+     # Run the language model to extract the objects to be swapped from
+     # the edit prompt
+     to_replace, replace_with = run_language_model(
+         edit_prompt=edit_prompt, device=device
+     )
+
+     # Caption the input image
+     caption = run_image_captioner(image, device=device)
+
+     # Replace the object in the caption with the new object
+     replaced_caption = caption.replace(to_replace, replace_with)
+
+     # Segment the `to_replace` object from the input image
+     masks = run_segmentation(image, to_replace, device=device)
+
+     # Diffusion pipeline for inpainting
+     return run_inpainting(
+         image=image, replaced_caption=replaced_caption, masks=masks, device=device
+     )
+
+
+ def setup_gradio_interface():
+     block = gr.Blocks()
+
+     with block:
+         gr.Markdown("<h1><center>Open Generative Fill</center></h1>")
+
+         with gr.Row():
+             with gr.Column():
+                 input_image_placeholder = gr.Image(type="pil", label="Input Image")
+                 edit_prompt_placeholder = gr.Textbox(label="Enter the editing prompt")
+                 run_button_placeholder = gr.Button(value="Run")
+
+             with gr.Column():
+                 output_image_placeholder = gr.Image(type="pil", label="Output Image")
+
+         run_button_placeholder.click(
+             fn=lambda image, edit_prompt: run_open_gen_fill(
+                 image=image,
+                 edit_prompt=edit_prompt,
+             ),
+             inputs=[input_image_placeholder, edit_prompt_placeholder],
+             outputs=[output_image_placeholder],
+         )
+
+     return block
+
+
+ if __name__ == "__main__":
+     gradio_interface = setup_gradio_interface()
+     gradio_interface.queue(max_size=5)
+     gradio_interface.launch(share=False, show_api=False, show_error=True)
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ gradio==4.18.0
+ accelerate==0.27.0
+ diffusers==0.26.2
+ transformers==4.37.2