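"""Open Generative Fill: caption an image with BLIP, use a small chat LLM to turn an
edit prompt into (object to replace, replacement caption), segment that object with
OWLv2 + SAM, and inpaint the replacement with SDXL inpainting, served via Gradio."""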
import spaces
import gradio as gr
import torch
from diffusers import AutoPipelineForInpainting
from PIL import Image, ImageFilter
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BlipForConditionalGeneration,
    BlipProcessor,
    Owlv2ForObjectDetection,
    Owlv2Processor,
    SamModel,
    SamProcessor,
)


def delete_model(model):
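    """Move a model off the GPU, drop the reference, and release cached CUDA memory."""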
    model.to("cpu")
    del model
    torch.cuda.empty_cache()

@spaces.GPU()
def run_language_model(edit_prompt, caption, device):
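    """Few-shot prompt Qwen1.5-0.5B-Chat to turn (caption, edit_prompt) into the object
    to replace and the replacement caption, following the "A: ... / B: ..." format shown
    in the example conversations below."""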
    language_model_id = "Qwen/Qwen1.5-0.5B-Chat"
    language_model = AutoModelForCausalLM.from_pretrained(
        language_model_id, device_map="auto"
    )
    tokenizer = AutoTokenizer.from_pretrained(language_model_id)
    messages = [
        {"role": "system", "content": "Follow the examples and return the expected output"},
        {"role": "user", "content": "Caption: a blue sky with fluffy clouds\nQuery: Make the sky stormy"},
        {"role": "assistant", "content": "A: sky\nB: a stormy sky with heavy gray clouds, torrential rain, gloomy, overcast"},
        {"role": "user", "content": "Caption: a cat sleeping on a sofa\nQuery: Change the cat to a dog"},
        {"role": "assistant", "content": "A: cat\nB: a dog sleeping on a sofa, cozy and comfortable, snuggled up in a warm blanket, peaceful"},
        {"role": "user", "content": "Caption: a snowy mountain peak\nQuery: Replace the snow with greenery"},
        {"role": "assistant", "content": "A: snow\nB: a lush green mountain peak in summer, clear blue skies, birds flying overhead, serene and majestic"},
        {"role": "user", "content": "Caption: a vintage car parked by the roadside\nQuery: Change the car to a modern electric vehicle"},
        {"role": "assistant", "content": "A: car\nB: a sleek modern electric vehicle parked by the roadside, cutting-edge design, environmentally friendly, silent and powerful"},
        {"role": "user", "content": "Caption: a wooden bridge over a river\nQuery: Make the bridge stone"},
        {"role": "assistant", "content": "A: bridge\nB: an ancient stone bridge over a river, moss-covered, sturdy and timeless, with clear waters flowing beneath"},
        {"role": "user", "content": "Caption: a bowl of salad on the table\nQuery: Replace salad with soup"},
        {"role": "assistant", "content": "A: bowl\nB: a bowl of steaming hot soup on the table, scrumptious, with garnishing"},
        {"role": "user", "content": "Caption: a book on a desk surrounded by stationery\nQuery: Remove all stationery, add a laptop"},
        {"role": "assistant", "content": "A: stationery\nB: a book on a desk with a laptop next to it, modern study setup, focused and productive, technology and education combined"},
        {"role": "user", "content": "Caption: a cup of coffee on a wooden table\nQuery: Change coffee to tea"},
        {"role": "assistant", "content": "A: cup\nB: a steaming cup of tea on a wooden table, calming and aromatic, with a slice of lemon on the side, inviting"},
        {"role": "user", "content": "Caption: a small pen on a white table\nQuery: Change the pen to an elaborate fountain pen"},
        {"role": "assistant", "content": "A: pen\nB: an elaborate fountain pen on a white table, sleek and elegant, with intricate designs, ready for writing"},
        {"role": "user", "content": "Caption: a plain notebook on a desk\nQuery: Replace the notebook with a journal"},
        {"role": "assistant", "content": "A: notebook\nB: an artistically decorated journal on a desk, vibrant cover, filled with creativity, inspiring and personalized"},
        {"role": "user", "content": f"Caption: {caption}\nQuery: {edit_prompt}"},
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(device)
    with torch.no_grad():
        # Greedy decoding; `temperature` is ignored when do_sample=False, so it is omitted
        generated_ids = language_model.generate(
            model_inputs.input_ids,
            max_new_tokens=512,
            do_sample=False,
        )

        # Keep only the newly generated tokens (drop the echoed prompt)
        generated_ids = [
            output_ids[len(input_ids):]
            for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
        ]
    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

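    # The model replies in the few-shot format "A: <object>\nB: <replacement caption>"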
    output_generation_a, output_generation_b = response.split("\n")
    to_replace = output_generation_a[2:].strip()
    replaced_caption = output_generation_b[2:].strip()

    delete_model(language_model)
    return (to_replace, replaced_caption)

@spaces.GPU()
def run_image_captioner(image, device):
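    """Caption the input image with BLIP (Salesforce/blip-image-captioning-base)."""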
    caption_model_id = "Salesforce/blip-image-captioning-base"
    caption_model = BlipForConditionalGeneration.from_pretrained(caption_model_id).to(
        device
    )
    caption_processor = BlipProcessor.from_pretrained(caption_model_id)
    inputs = caption_processor(image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = caption_model.generate(**inputs, max_new_tokens=200)
    caption = caption_processor.decode(outputs[0], skip_special_tokens=True)

    delete_model(caption_model)
    return caption

@spaces.GPU()
def run_segmentation(image, object_to_segment, device):
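    """Detect `object_to_segment` with OWLv2, segment the detected boxes with SAM,
    and merge the per-box masks into a single mask."""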
    # OWL-V2 for object detection
    owl_v2_model_id = "google/owlv2-base-patch16-ensemble"
    processor = Owlv2Processor.from_pretrained(owl_v2_model_id)
    od_model = Owlv2ForObjectDetection.from_pretrained(owl_v2_model_id).to(device)
    text_queries = [object_to_segment]
    inputs = processor(text=text_queries, images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = od_model(**inputs)
        # post_process_object_detection expects target sizes as (height, width)
        target_sizes = torch.tensor([image.size[::-1]]).to(device)
        results = processor.post_process_object_detection(
            outputs, threshold=0.1, target_sizes=target_sizes
        )[0]

    boxes = results["boxes"].tolist()

    delete_model(od_model)

    # SAM for image segmentation
    sam_model_id = "facebook/sam-vit-base"
    seg_model = SamModel.from_pretrained(sam_model_id).to(device)
    processor = SamProcessor.from_pretrained(sam_model_id)
    input_boxes = [boxes]
    inputs = processor(image, input_boxes=input_boxes, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = seg_model(**inputs)
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )[0]
    # Merge the masks
    masks = torch.max(masks[:, 0, ...], dim=0, keepdim=False).values
    
    delete_model(seg_model)
    return masks

@spaces.GPU()
def run_inpainting(image, replaced_caption, masks, generator, device):
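    """Inpaint the masked region with SDXL inpainting, guided by the replacement caption."""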
    pipeline = AutoPipelineForInpainting.from_pretrained(
        "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
        torch_dtype=torch.float16,
        variant="fp16",
    ).to(device)

    masks = Image.fromarray(masks.numpy())
    # dilation_image = masks.filter(ImageFilter.MaxFilter(3))

    prompt = replaced_caption
    negative_prompt = """lowres, bad anatomy, bad hands,
    text, error, missing fingers, extra digit, fewer digits,
    cropped, worst quality, low quality"""

    output = pipeline(
        prompt=prompt,
        image=image,
        # mask_image=dilation_image,
        mask_image=masks,
        negative_prompt=negative_prompt,
        guidance_scale=7.5,
        strength=1.0,
        generator=generator,
    ).images[0]

    delete_model(pipeline)
    return output


def run_open_gen_fill(image, edit_prompt):
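    """Full pipeline: caption -> extract object and replacement caption -> segment -> inpaint.

    Returns the replaced object, the original caption, the replacement caption,
    the segmentation mask, and the edited image.
    """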
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Resize the image to (512, 512)
    image = image.resize((512, 512))

    # Caption the input image
    caption = run_image_captioner(image, device=device)

    # Run the language model to extract the object for segmentation
    # and get the replaced caption
    to_replace, replaced_caption = run_language_model(
        edit_prompt=edit_prompt, caption=caption, device=device
    )

    # Segment the `to_replace` object from the input image
    masks = run_segmentation(image, to_replace, device=device)

    # Diffusion pipeline for inpainting
    generator = torch.Generator(device).manual_seed(17)
    output = run_inpainting(
        image=image, replaced_caption=replaced_caption, masks=masks, generator=generator, device=device
    )

    return (
        to_replace,
        caption,
        replaced_caption,
        Image.fromarray(masks.numpy()),
        output,
    )


def setup_gradio_interface():
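    """Build the Gradio Blocks UI: image and prompt inputs on the left; intermediate
    outputs (caption, replacement caption, mask) and the final image on the right."""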
    block = gr.Blocks()

    with block:
        gr.Markdown("<h1><center>Open Generative Fill V1<h1><center>")

        with gr.Row():
            with gr.Column():
                input_image_placeholder = gr.Image(type="pil", label="Input Image")
                edit_prompt_placeholder = gr.Textbox(label="Enter the editing prompt")
                run_button_placeholder = gr.Button(value="Run")

            with gr.Column():
                to_replace_placeholder = gr.Textbox(label="to_replace")
                image_caption_placeholder = gr.Textbox(label="Image Caption")
                replace_caption_placeholder = gr.Textbox(label="Replaced Caption")
                segmentation_placeholder = gr.Image(type="pil", label="Segmentation")
                output_image_placeholder = gr.Image(type="pil", label="Output Image")

        run_button_placeholder.click(
            fn=lambda image, edit_prompt: run_open_gen_fill(
                image=image,
                edit_prompt=edit_prompt,
            ),
            inputs=[input_image_placeholder, edit_prompt_placeholder],
            outputs=[
                to_replace_placeholder,
                image_caption_placeholder,
                replace_caption_placeholder,
                segmentation_placeholder,
                output_image_placeholder,
            ],
        )

        gr.Examples(
            examples=[["dog.jpeg", "replace the dog with a tiger"]],
            inputs=[input_image_placeholder, edit_prompt_placeholder],
            outputs=[
                to_replace_placeholder,
                image_caption_placeholder,
                replace_caption_placeholder,
                segmentation_placeholder,
                output_image_placeholder,
            ],
            fn=lambda image, edit_prompt: run_open_gen_fill(
                image=image,
                edit_prompt=edit_prompt,
            ),
            cache_examples=True,
            label="Try this example input!",
        )

    return block


if __name__ == "__main__":
    gradio_interface = setup_gradio_interface()
    # gradio_interface.queue(max_size=10)
    gradio_interface.launch(share=False, show_api=False, show_error=True)