ameerazam08 committed on
Commit
6dc8351
1 Parent(s): 987bde8

Upload folder using huggingface_hub

.gitattributes CHANGED
@@ -33,3 +33,11 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ assets/comparison.png filter=lfs diff=lfs merge=lfs -text
37
+ assets/example1.png filter=lfs diff=lfs merge=lfs -text
38
+ assets/example2.png filter=lfs diff=lfs merge=lfs -text
39
+ assets/example3.png filter=lfs diff=lfs merge=lfs -text
40
+ assets/subtraction.png filter=lfs diff=lfs merge=lfs -text
41
+ assets/tree.png filter=lfs diff=lfs merge=lfs -text
42
+ flagged/Style[[:space:]]Image/4f12bf3724d50ac7ab9b87ce0e3fd4e327ed3ba0/tmp50v2kwjw.png filter=lfs diff=lfs merge=lfs -text
43
+ result.png filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,12 +1,131 @@
1
- ---
2
- title: InstantStyle GPU Demo
3
- emoji: 🏢
4
- colorFrom: green
5
- colorTo: gray
6
- sdk: gradio
7
- sdk_version: 4.25.0
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
+ <div align="center">
2
+ <h1>InstantStyle: Free Lunch towards Style-Preserving in Text-to-Image Generation</h1>
3
+
4
+ [**Haofan Wang**](https://haofanwang.github.io/)<sup>*</sup> · [Matteo Spinelli](https://github.com/cubiq) · [**Qixun Wang**](https://github.com/wangqixun) · [**Xu Bai**](https://huggingface.co/baymin0220) · [**Zekui Qin**](https://github.com/ZekuiQin) · [**Anthony Chen**](https://antonioo-c.github.io/)
5
+
6
+ InstantX Team
7
+
8
+ <sup>*</sup>corresponding authors
9
+
10
+ <a href='https://instantstyle.github.io/'><img src='https://img.shields.io/badge/Project-Page-green'></a>
11
+ <a href='https://arxiv.org/abs/2404.02733'><img src='https://img.shields.io/badge/Technique-Report-red'></a>
12
+ [![GitHub](https://img.shields.io/github/stars/InstantStyle/InstantStyle?style=social)](https://github.com/InstantStyle/InstantStyle)
13
+
14
+ </div>
15
+
16
+ InstantStyle is a general framework that employs two straightforward yet potent techniques for achieving an effective disentanglement of style and content from reference images.
17
+
18
+ <img src='assets/pipe.png'>
19
+
20
+ ## Principle
21
+
22
+ Separating Content from Image. Benefiting from the good characterization of CLIP global features, subtracting the content text features from the image features explicitly decouples style from content. Although simple, this strategy is quite effective in mitigating content leakage (a minimal sketch follows the figure below).
23
+ <p align="center">
24
+ <img src="assets/subtraction.png">
25
+ </p>
26
+
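+ The idea can be illustrated with a minimal sketch using the CLIP models from `transformers` (model name and variable names here are illustrative; the released code actually performs the subtraction on the pooled prompt embeddings of the SDXL text encoder via `neg_content_prompt`, see the usage example below):
+
+ ```python
+ # Minimal sketch of content subtraction on CLIP global features (illustrative only).
+ import torch
+ from PIL import Image
+ from transformers import CLIPModel, CLIPProcessor
+
+ clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
+ processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
+
+ style_image = Image.open("./assets/0.jpg")
+ content_prompt = "a rabbit"  # describes the content to be removed from the reference
+
+ with torch.no_grad():
+     image_feat = clip.get_image_features(**processor(images=style_image, return_tensors="pt"))
+     text_feat = clip.get_text_features(**processor(text=[content_prompt], return_tensors="pt", padding=True))
+
+ # Subtracting the content text features from the image features leaves mostly style.
+ style_feat = image_feat - text_feat
+ ```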
27
+ Injecting into Style Blocks Only. Empirically, each layer of a deep network captures different semantic information. The key observation in our work is that there exist two specific attention layers that handle style: we find that `up_blocks.0.attentions.1` captures style (color, material, atmosphere) and `down_blocks.2.attentions.1` captures spatial layout (structure, composition). A simplified selection sketch follows the figure below.
28
+ <p align="center">
29
+ <img src="assets/tree.png">
30
+ </p>
31
+
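+ In code, this amounts to a substring match on the UNet's attention-processor names when the adapter is attached. A simplified sketch of the selection logic (it mirrors `set_ip_adapter` in `ip_adapter/ip_adapter.py` below; the helper name is hypothetical):
+
+ ```python
+ def select_ip_layers(unet, target_blocks):
+     """Return the cross-attention layer names that will receive the image prompt."""
+     selected = []
+     for name in unet.attn_processors.keys():
+         is_cross_attention = not name.endswith("attn1.processor")
+         if is_cross_attention and any(block in name for block in target_blocks):
+             selected.append(name)  # gets an IPAttnProcessor; everything else keeps AttnProcessor
+     return selected
+
+ # e.g. select_ip_layers(pipe.unet, ["up_blocks.0.attentions.1"])                                # style only
+ #      select_ip_layers(pipe.unet, ["up_blocks.0.attentions.1", "down_blocks.2.attentions.1"])  # style + layout
+ ```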
32
+ ## Release
33
+ - [2024/04/03] 🔥 We release the [technical report](https://arxiv.org/abs/2404.02733).
34
+
35
+ ## Download
36
+ Follow [IP-Adapter](https://github.com/tencent-ailab/IP-Adapter?tab=readme-ov-file#download-models) to download pre-trained checkpoints.
37
+
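+ For example, the SDXL weights referenced in the usage example below live in the [h94/IP-Adapter](https://huggingface.co/h94/IP-Adapter) repository and can be fetched with `huggingface_hub` (a sketch, assuming a recent `huggingface_hub`; adjust paths to your layout):
+
+ ```python
+ # Download the IP-Adapter SDXL checkpoint and its CLIP image encoder.
+ from huggingface_hub import hf_hub_download
+
+ for filename in [
+     "sdxl_models/ip-adapter_sdxl.bin",
+     "sdxl_models/image_encoder/config.json",
+     "sdxl_models/image_encoder/model.safetensors",
+ ]:
+     hf_hub_download(repo_id="h94/IP-Adapter", filename=filename, local_dir=".")
+ # Files land under ./sdxl_models/, matching the paths used in the example below.
+ ```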
38
+ ## Demos
39
+
40
+ ### Stylized Synthesis
41
+
42
+ <p align="center">
43
+ <img src="assets/example1.png">
44
+ <img src="assets/example2.png">
45
+ </p>
46
+
47
+ ### Image-based Stylized Synthesis
48
+
49
+ <p align="center">
50
+ <img src="assets/example3.png">
51
+ </p>
52
+
53
+ ### Comparison with Previous Works
54
+
55
+ <p align="center">
56
+ <img src="assets/comparison.png">
57
+ </p>
58
+
59
+ ## Usage
60
+
61
+ Our method is fully compatible with [IP-Adapter](https://github.com/tencent-ailab/IP-Adapter). Note that feature subtraction only works with IP-Adapter variants that use global image embeddings.
62
+
63
+ ```python
64
+ import torch
65
+ from diffusers import StableDiffusionXLPipeline
66
+ from PIL import Image
67
+
68
+ from ip_adapter import IPAdapterXL
69
+
70
+ base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
71
+ image_encoder_path = "sdxl_models/image_encoder"
72
+ ip_ckpt = "sdxl_models/ip-adapter_sdxl.bin"
73
+ device = "cuda"
74
+
75
+ # load SDXL pipeline
76
+ pipe = StableDiffusionXLPipeline.from_pretrained(
77
+ base_model_path,
78
+ torch_dtype=torch.float16,
79
+ add_watermarker=False,
80
+ )
81
+
82
+ # load ip-adapter
83
+ # target_blocks=["blocks"] for original IP-Adapter
84
+ # target_blocks=["up_blocks.0.attentions.1"] for style blocks only
85
+ # target_blocks = ["up_blocks.0.attentions.1", "down_blocks.2.attentions.1"] # for style+layout blocks
86
+ ip_model = IPAdapterXL(pipe, image_encoder_path, ip_ckpt, device, target_blocks=["up_blocks.0.attentions.1"])
87
+
88
+ image = "./assets/0.jpg"
89
+ image = Image.open(image)
90
+ image.resize((512, 512))
91
+
92
+ # generate image variations with only image prompt
93
+ images = ip_model.generate(pil_image=image,
94
+ prompt="a cat, masterpiece, best quality, high quality",
95
+ negative_prompt= "text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry",
96
+ scale=1.0,
97
+ guidance_scale=5,
98
+ num_samples=1,
99
+ num_inference_steps=30,
100
+ seed=42,
101
+ #neg_content_prompt="a rabbit",
102
+ #neg_content_scale=0.5,
103
+ )
104
+
105
+ images[0].save("result.png")
106
+ ```
107
+
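+ To additionally suppress content leakage from the reference image, enable the commented `neg_content_prompt` / `neg_content_scale` arguments: the pooled text embedding of the described content is scaled and subtracted from the global image embedding before projection (the prompt values here are only an illustration):
+
+ ```python
+ # Same call as above, but with content subtraction enabled.
+ images = ip_model.generate(pil_image=image,
+                            prompt="a dog, masterpiece, best quality, high quality",
+                            negative_prompt="text, watermark, lowres, low quality, worst quality",
+                            scale=1.0,
+                            guidance_scale=5,
+                            num_samples=1,
+                            num_inference_steps=30,
+                            seed=42,
+                            neg_content_prompt="a cat",  # content seen in the reference image
+                            neg_content_scale=0.5,       # strength of the subtraction
+                            )
+
+ images[0].save("result_subtracted.png")
+ ```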
108
+ We will support the diffusers API soon.
109
+
110
+ ## TODO
111
+ - Support the diffusers API.
112
+ - Support InstantID.
113
+
114
+ ## Sponsor Us
115
+ If you find this project useful, you can buy us a coffee via GitHub Sponsors! We support [PayPal](https://ko-fi.com/instantx) and [WeChat Pay](https://tinyurl.com/instantx-pay).
116
+
117
+ ## Cite
118
+ If you find InstantStyle useful for your research and applications, please cite us using this BibTeX:
119
+
120
+ ```bibtex
121
+ @misc{wang2024instantstyle,
122
+ title={InstantStyle: Free Lunch towards Style-Preserving in Text-to-Image Generation},
123
+ author={Haofan Wang and Qixun Wang and Xu Bai and Zekui Qin and Anthony Chen},
124
+ year={2024},
125
+ eprint={2404.02733},
126
+ archivePrefix={arXiv},
127
+ primaryClass={cs.CV}
128
+ }
129
+ ```
130
+
131
+ For any questions, please feel free to contact us at haofanwang.ai@gmail.com.
app.py ADDED
@@ -0,0 +1,98 @@
1
+
2
+ import os
3
+ download_repo_loc = "./models/image_encoder/"
4
+ os.system("pip install -U peft")
5
+ # os.system(f"wget -O {download_repo_loc}config.json https://huggingface.co/h94/IP-Adapter/resolve/main/sdxl_models/image_encoder/config.json?download=true")
6
+ # os.system(f"wget -O {download_repo_loc}model.safetensors https://huggingface.co/h94/IP-Adapter/resolve/main/sdxl_models/image_encoder/model.safetensors?download=true")
7
+ # os.system(f"wget -O {download_repo_loc}pytorch_model.bin https://huggingface.co/h94/IP-Adapter/resolve/main/sdxl_models/image_encoder/pytorch_model.bin?download=true")
8
+
9
+ import spaces
10
+ import gradio as gr
11
+ import torch
12
+ from diffusers import StableDiffusionXLPipeline
13
+ from PIL import Image
14
+ from ip_adapter import IPAdapterXL
15
+ base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
16
+ device = "cuda"
17
+
18
+ image_encoder_path = download_repo_loc  # "sdxl_models/image_encoder"
19
+ ip_ckpt = "./models/ip-adapter_sdxl.bin"
20
+ # load SDXL pipeline
21
+ pipe = StableDiffusionXLPipeline.from_pretrained(
22
+ base_model_path,
23
+ torch_dtype=torch.float16,
24
+ add_watermarker=False,
25
+ )
26
+
27
+
28
+ # generate image variations with only image prompt
29
+ @spaces.GPU(enable_queue=True)
30
+ def create_image(image_pil,target,prompt,n_prompt,scale, guidance_scale,num_samples,num_inference_steps,seed):
31
+ # load ip-adapter
32
+ if target =="Load original IP-Adapter":
33
+ # target_blocks=["blocks"] for original IP-Adapter
34
+ ip_model = IPAdapterXL(pipe, image_encoder_path, ip_ckpt, device, target_blocks=["blocks"])
35
+ elif target=="Load only style blocks":
36
+ # target_blocks=["up_blocks.0.attentions.1"] for style blocks only
37
+ ip_model = IPAdapterXL(pipe, image_encoder_path, ip_ckpt, device, target_blocks=["up_blocks.0.attentions.1"])
38
+ elif target == "Load style+layout block":
39
+ # target_blocks = ["up_blocks.0.attentions.1", "down_blocks.2.attentions.1"] # for style+layout blocks
40
+ ip_model = IPAdapterXL(pipe, image_encoder_path, ip_ckpt, device, target_blocks=["up_blocks.0.attentions.1", "down_blocks.2.attentions.1"])
41
+
42
+
43
+ image_pil=image_pil.resize((512, 512))
44
+ images = ip_model.generate(pil_image=image_pil,
45
+ prompt=prompt,
46
+ negative_prompt=n_prompt,
47
+ scale=scale,
48
+ guidance_scale=guidance_scale,
49
+ num_samples=num_samples,
50
+ num_inference_steps=num_inference_steps,
51
+ seed=seed,
52
+ #neg_content_prompt="a rabbit",
53
+ #neg_content_scale=0.5,
54
+ )
55
+
56
+ # images[0].save("result.png")
57
+ del ip_model
58
+
59
+ return images
60
+
61
+
62
+ DESCRIPTION = """
63
+ # InstantStyle GPU Demo
64
+ **Demo by ameer azam - [Twitter](https://twitter.com/Ameerazam18) - [GitHub](https://github.com/AMEERAZAM08) - [Hugging Face](https://huggingface.co/ameerazam08)**
65
+ This is a demo of [InstantStyle](https://github.com/InstantStyle/InstantStyle) by the InstantX team, running on SDXL.
66
+ Pick a style reference image, choose which attention blocks receive the image prompt (original IP-Adapter, style only, or style + layout), and generate.
67
+ """
68
+
69
+ block = gr.Blocks(css="footer {visibility: hidden}").queue()
70
+ with block:
71
+ with gr.Row():
72
+
73
+ with gr.Column():
74
+ gr.Markdown("## <h1 align='center'>InstantStyle: Free Lunch towards Style-Preserving in Text-to-Image Generation </h1>")
75
+ gr.Markdown(DESCRIPTION)
76
+ with gr.Tabs():
77
+ with gr.Row():
78
+ with gr.Column():
79
+ image_pil = gr.Image(label="Style Image", type='pil')
80
+ target = gr.Dropdown(["Load original IP-Adapter","Load only style blocks","Load style+layout block"], label="IP-Adapter Mode", info="Which attention blocks receive the image prompt?")
81
+ prompt = gr.Textbox(label="Prompt",value="a cat, masterpiece, best quality, high quality")
82
+ n_prompt = gr.Textbox(label="Neg Prompt",value="text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry")
83
+ scale = gr.Slider(minimum=0,maximum=2.0, step=0.01,value=1.0, label="scale")
84
+ guidance_scale = gr.Slider(minimum=1,maximum=15.0, step=0.01,value=5.0, label="guidance_scale")
85
+ num_samples= gr.Slider(minimum=1,maximum=3.0, step=1.0,value=1.0, label="num_samples")
86
+ num_inference_steps = gr.Slider(minimum=5,maximum=50.0, step=1.0,value=30, label="num_inference_steps")
87
+ seed = gr.Slider(minimum=-1000000,maximum=1000000,value=1, step=1, label="Seed Value")
88
+ generate_button = gr.Button("Generate Image")
89
+ with gr.Column():
90
+ generated_image = gr.Gallery(label="Generated Image")
91
+
92
+ generate_button.click(fn=create_image, inputs=[image_pil,target,prompt,n_prompt,scale, guidance_scale,num_samples,num_inference_steps,seed],
93
+ outputs=[generated_image])
94
+
95
+ block.launch(max_threads=10)
96
+
97
+
98
+
assets/0.jpg ADDED
assets/2.jpg ADDED
assets/3.jpg ADDED
assets/comparison.png ADDED

Git LFS Details

  • SHA256: c7d24b8b9c919e656c706f1880c92f2e06eb992dfb87ee38e2f5e9ac93321867
  • Pointer size: 132 Bytes
  • Size of remote file: 8.16 MB
assets/example1.png ADDED

Git LFS Details

  • SHA256: e1d147d2d68e56952fe7478aeb44bd5aa4cd04e4db8bc6c26bf38af2f45d90fc
  • Pointer size: 132 Bytes
  • Size of remote file: 4.28 MB
assets/example2.png ADDED

Git LFS Details

  • SHA256: de146f0e9f538417ba0d9f220cd145803b42473d4e8c5c809e64b0c8118cd453
  • Pointer size: 132 Bytes
  • Size of remote file: 3.55 MB
assets/example3.png ADDED

Git LFS Details

  • SHA256: 25cd77e50bc38c58df8f6de6f73f63d29b7f2903835e2b982fedc3bb6a7c937a
  • Pointer size: 132 Bytes
  • Size of remote file: 4.41 MB
assets/pipe.png ADDED
assets/subtraction.png ADDED

Git LFS Details

  • SHA256: 150d5d512e138a56c70161a17dff34a1a1b70c886f2a62d0b1237c9dee277189
  • Pointer size: 132 Bytes
  • Size of remote file: 1.5 MB
assets/tree.png ADDED

Git LFS Details

  • SHA256: 137ea0afd370ad9ff9d8827d4045f7f00c10c9378fd7feb5f70e0a91ce95c5e3
  • Pointer size: 132 Bytes
  • Size of remote file: 3.39 MB
flagged/Style Image/4f12bf3724d50ac7ab9b87ce0e3fd4e327ed3ba0/tmp50v2kwjw.png ADDED

Git LFS Details

  • SHA256: e10e03b7177d9725b44dce4ca20568f266b5b1481215ca34a81bbc9a3ad6cd71
  • Pointer size: 132 Bytes
  • Size of remote file: 1.14 MB
flagged/log.csv ADDED
@@ -0,0 +1,2 @@
1
+ Style Image,Prompt,Negative Prompt,Scale,guidance_scale,num_samples,num_inference_steps,Seed Value,Processed Image,flag,username,timestamp
2
+ /home/rnd/Documents/Ameer/InstantStyle/flagged/Style Image/4f12bf3724d50ac7ab9b87ce0e3fd4e327ed3ba0/tmp50v2kwjw.png,dfgdfgdf,"text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry",1,5,1,30,1,,,,2024-04-05 00:34:42.130755
ip_adapter/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ from .ip_adapter import IPAdapter, IPAdapterPlus, IPAdapterPlusXL, IPAdapterXL, IPAdapterFull
2
+
3
+ __all__ = [
4
+ "IPAdapter",
5
+ "IPAdapterPlus",
6
+ "IPAdapterPlusXL",
7
+ "IPAdapterXL",
8
+ "IPAdapterFull",
9
+ ]
ip_adapter/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (313 Bytes).
 
ip_adapter/__pycache__/attention_processor.cpython-310.pyc ADDED
Binary file (9.93 kB).
 
ip_adapter/__pycache__/ip_adapter.cpython-310.pyc ADDED
Binary file (11.4 kB).
 
ip_adapter/__pycache__/resampler.cpython-310.pyc ADDED
Binary file (4.26 kB).
 
ip_adapter/__pycache__/utils.cpython-310.pyc ADDED
Binary file (2.83 kB).
 
ip_adapter/attention_processor.py ADDED
@@ -0,0 +1,558 @@
1
+ # modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+
7
+ class AttnProcessor(nn.Module):
8
+ r"""
9
+ Default processor for performing attention-related computations.
10
+ """
11
+
12
+ def __init__(
13
+ self,
14
+ hidden_size=None,
15
+ cross_attention_dim=None,
16
+ ):
17
+ super().__init__()
18
+
19
+ def __call__(
20
+ self,
21
+ attn,
22
+ hidden_states,
23
+ encoder_hidden_states=None,
24
+ attention_mask=None,
25
+ temb=None,
26
+ ):
27
+ residual = hidden_states
28
+
29
+ if attn.spatial_norm is not None:
30
+ hidden_states = attn.spatial_norm(hidden_states, temb)
31
+
32
+ input_ndim = hidden_states.ndim
33
+
34
+ if input_ndim == 4:
35
+ batch_size, channel, height, width = hidden_states.shape
36
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
37
+
38
+ batch_size, sequence_length, _ = (
39
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
40
+ )
41
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
42
+
43
+ if attn.group_norm is not None:
44
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
45
+
46
+ query = attn.to_q(hidden_states)
47
+
48
+ if encoder_hidden_states is None:
49
+ encoder_hidden_states = hidden_states
50
+ elif attn.norm_cross:
51
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
52
+
53
+ key = attn.to_k(encoder_hidden_states)
54
+ value = attn.to_v(encoder_hidden_states)
55
+
56
+ query = attn.head_to_batch_dim(query)
57
+ key = attn.head_to_batch_dim(key)
58
+ value = attn.head_to_batch_dim(value)
59
+
60
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
61
+ hidden_states = torch.bmm(attention_probs, value)
62
+ hidden_states = attn.batch_to_head_dim(hidden_states)
63
+
64
+ # linear proj
65
+ hidden_states = attn.to_out[0](hidden_states)
66
+ # dropout
67
+ hidden_states = attn.to_out[1](hidden_states)
68
+
69
+ if input_ndim == 4:
70
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
71
+
72
+ if attn.residual_connection:
73
+ hidden_states = hidden_states + residual
74
+
75
+ hidden_states = hidden_states / attn.rescale_output_factor
76
+
77
+ return hidden_states
78
+
79
+
80
+ class IPAttnProcessor(nn.Module):
81
+ r"""
82
+ Attention processor for IP-Adapter.
83
+ Args:
84
+ hidden_size (`int`):
85
+ The hidden size of the attention layer.
86
+ cross_attention_dim (`int`):
87
+ The number of channels in the `encoder_hidden_states`.
88
+ scale (`float`, defaults to 1.0):
89
+ the weight scale of image prompt.
90
+ num_tokens (`int`, defaults to 4; use 16 for IP-Adapter Plus):
91
+ The context length of the image features.
92
+ """
93
+
94
+ def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
95
+ super().__init__()
96
+
97
+ self.hidden_size = hidden_size
98
+ self.cross_attention_dim = cross_attention_dim
99
+ self.scale = scale
100
+ self.num_tokens = num_tokens
101
+
102
+ self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
103
+ self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
104
+
105
+ def __call__(
106
+ self,
107
+ attn,
108
+ hidden_states,
109
+ encoder_hidden_states=None,
110
+ attention_mask=None,
111
+ temb=None,
112
+ ):
113
+ residual = hidden_states
114
+
115
+ if attn.spatial_norm is not None:
116
+ hidden_states = attn.spatial_norm(hidden_states, temb)
117
+
118
+ input_ndim = hidden_states.ndim
119
+
120
+ if input_ndim == 4:
121
+ batch_size, channel, height, width = hidden_states.shape
122
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
123
+
124
+ batch_size, sequence_length, _ = (
125
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
126
+ )
127
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
128
+
129
+ if attn.group_norm is not None:
130
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
131
+
132
+ query = attn.to_q(hidden_states)
133
+
134
+ if encoder_hidden_states is None:
135
+ encoder_hidden_states = hidden_states
136
+ else:
137
+ # get encoder_hidden_states, ip_hidden_states
138
+ end_pos = encoder_hidden_states.shape[1] - self.num_tokens
139
+ encoder_hidden_states, ip_hidden_states = (
140
+ encoder_hidden_states[:, :end_pos, :],
141
+ encoder_hidden_states[:, end_pos:, :],
142
+ )
143
+ if attn.norm_cross:
144
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
145
+
146
+ key = attn.to_k(encoder_hidden_states)
147
+ value = attn.to_v(encoder_hidden_states)
148
+
149
+ query = attn.head_to_batch_dim(query)
150
+ key = attn.head_to_batch_dim(key)
151
+ value = attn.head_to_batch_dim(value)
152
+
153
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
154
+ hidden_states = torch.bmm(attention_probs, value)
155
+ hidden_states = attn.batch_to_head_dim(hidden_states)
156
+
157
+ # for ip-adapter
158
+ ip_key = self.to_k_ip(ip_hidden_states)
159
+ ip_value = self.to_v_ip(ip_hidden_states)
160
+
161
+ ip_key = attn.head_to_batch_dim(ip_key)
162
+ ip_value = attn.head_to_batch_dim(ip_value)
163
+
164
+ ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
165
+ self.attn_map = ip_attention_probs
166
+ ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
167
+ ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
168
+
169
+ hidden_states = hidden_states + self.scale * ip_hidden_states
170
+
171
+ # linear proj
172
+ hidden_states = attn.to_out[0](hidden_states)
173
+ # dropout
174
+ hidden_states = attn.to_out[1](hidden_states)
175
+
176
+ if input_ndim == 4:
177
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
178
+
179
+ if attn.residual_connection:
180
+ hidden_states = hidden_states + residual
181
+
182
+ hidden_states = hidden_states / attn.rescale_output_factor
183
+
184
+ return hidden_states
185
+
186
+
187
+ class AttnProcessor2_0(torch.nn.Module):
188
+ r"""
189
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
190
+ """
191
+
192
+ def __init__(
193
+ self,
194
+ hidden_size=None,
195
+ cross_attention_dim=None,
196
+ ):
197
+ super().__init__()
198
+ if not hasattr(F, "scaled_dot_product_attention"):
199
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
200
+
201
+ def __call__(
202
+ self,
203
+ attn,
204
+ hidden_states,
205
+ encoder_hidden_states=None,
206
+ attention_mask=None,
207
+ temb=None,
208
+ ):
209
+ residual = hidden_states
210
+
211
+ if attn.spatial_norm is not None:
212
+ hidden_states = attn.spatial_norm(hidden_states, temb)
213
+
214
+ input_ndim = hidden_states.ndim
215
+
216
+ if input_ndim == 4:
217
+ batch_size, channel, height, width = hidden_states.shape
218
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
219
+
220
+ batch_size, sequence_length, _ = (
221
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
222
+ )
223
+
224
+ if attention_mask is not None:
225
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
226
+ # scaled_dot_product_attention expects attention_mask shape to be
227
+ # (batch, heads, source_length, target_length)
228
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
229
+
230
+ if attn.group_norm is not None:
231
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
232
+
233
+ query = attn.to_q(hidden_states)
234
+
235
+ if encoder_hidden_states is None:
236
+ encoder_hidden_states = hidden_states
237
+ elif attn.norm_cross:
238
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
239
+
240
+ key = attn.to_k(encoder_hidden_states)
241
+ value = attn.to_v(encoder_hidden_states)
242
+
243
+ inner_dim = key.shape[-1]
244
+ head_dim = inner_dim // attn.heads
245
+
246
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
247
+
248
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
249
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
250
+
251
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
252
+ # TODO: add support for attn.scale when we move to Torch 2.1
253
+ hidden_states = F.scaled_dot_product_attention(
254
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
255
+ )
256
+
257
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
258
+ hidden_states = hidden_states.to(query.dtype)
259
+
260
+ # linear proj
261
+ hidden_states = attn.to_out[0](hidden_states)
262
+ # dropout
263
+ hidden_states = attn.to_out[1](hidden_states)
264
+
265
+ if input_ndim == 4:
266
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
267
+
268
+ if attn.residual_connection:
269
+ hidden_states = hidden_states + residual
270
+
271
+ hidden_states = hidden_states / attn.rescale_output_factor
272
+
273
+ return hidden_states
274
+
275
+
276
+ class IPAttnProcessor2_0(torch.nn.Module):
277
+ r"""
278
+ Attention processor for IP-Adapter (PyTorch 2.0).
279
+ Args:
280
+ hidden_size (`int`):
281
+ The hidden size of the attention layer.
282
+ cross_attention_dim (`int`):
283
+ The number of channels in the `encoder_hidden_states`.
284
+ scale (`float`, defaults to 1.0):
285
+ the weight scale of image prompt.
286
+ num_tokens (`int`, defaults to 4; use 16 for IP-Adapter Plus):
287
+ The context length of the image features.
288
+ """
289
+
290
+ def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
291
+ super().__init__()
292
+
293
+ if not hasattr(F, "scaled_dot_product_attention"):
294
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
295
+
296
+ self.hidden_size = hidden_size
297
+ self.cross_attention_dim = cross_attention_dim
298
+ self.scale = scale
299
+ self.num_tokens = num_tokens
300
+
301
+ self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
302
+ self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
303
+
304
+ def __call__(
305
+ self,
306
+ attn,
307
+ hidden_states,
308
+ encoder_hidden_states=None,
309
+ attention_mask=None,
310
+ temb=None,
311
+ ):
312
+ residual = hidden_states
313
+
314
+ if attn.spatial_norm is not None:
315
+ hidden_states = attn.spatial_norm(hidden_states, temb)
316
+
317
+ input_ndim = hidden_states.ndim
318
+
319
+ if input_ndim == 4:
320
+ batch_size, channel, height, width = hidden_states.shape
321
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
322
+
323
+ batch_size, sequence_length, _ = (
324
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
325
+ )
326
+
327
+ if attention_mask is not None:
328
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
329
+ # scaled_dot_product_attention expects attention_mask shape to be
330
+ # (batch, heads, source_length, target_length)
331
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
332
+
333
+ if attn.group_norm is not None:
334
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
335
+
336
+ query = attn.to_q(hidden_states)
337
+
338
+ if encoder_hidden_states is None:
339
+ encoder_hidden_states = hidden_states
340
+ else:
341
+ # get encoder_hidden_states, ip_hidden_states
342
+ end_pos = encoder_hidden_states.shape[1] - self.num_tokens
343
+ encoder_hidden_states, ip_hidden_states = (
344
+ encoder_hidden_states[:, :end_pos, :],
345
+ encoder_hidden_states[:, end_pos:, :],
346
+ )
347
+ if attn.norm_cross:
348
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
349
+
350
+ key = attn.to_k(encoder_hidden_states)
351
+ value = attn.to_v(encoder_hidden_states)
352
+
353
+ inner_dim = key.shape[-1]
354
+ head_dim = inner_dim // attn.heads
355
+
356
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
357
+
358
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
359
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
360
+
361
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
362
+ # TODO: add support for attn.scale when we move to Torch 2.1
363
+ hidden_states = F.scaled_dot_product_attention(
364
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
365
+ )
366
+
367
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
368
+ hidden_states = hidden_states.to(query.dtype)
369
+
370
+ # for ip-adapter
371
+ ip_key = self.to_k_ip(ip_hidden_states)
372
+ ip_value = self.to_v_ip(ip_hidden_states)
373
+
374
+ ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
375
+ ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
376
+
377
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
378
+ # TODO: add support for attn.scale when we move to Torch 2.1
379
+ ip_hidden_states = F.scaled_dot_product_attention(
380
+ query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
381
+ )
382
+ with torch.no_grad():
383
+ self.attn_map = query @ ip_key.transpose(-2, -1).softmax(dim=-1)
384
+ #print(self.attn_map.shape)
385
+
386
+ ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
387
+ ip_hidden_states = ip_hidden_states.to(query.dtype)
388
+
389
+ hidden_states = hidden_states + self.scale * ip_hidden_states
390
+
391
+ # linear proj
392
+ hidden_states = attn.to_out[0](hidden_states)
393
+ # dropout
394
+ hidden_states = attn.to_out[1](hidden_states)
395
+
396
+ if input_ndim == 4:
397
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
398
+
399
+ if attn.residual_connection:
400
+ hidden_states = hidden_states + residual
401
+
402
+ hidden_states = hidden_states / attn.rescale_output_factor
403
+
404
+ return hidden_states
405
+
406
+
407
+ ## for controlnet
408
+ class CNAttnProcessor:
409
+ r"""
410
+ Default processor for performing attention-related computations.
411
+ """
412
+
413
+ def __init__(self, num_tokens=4):
414
+ self.num_tokens = num_tokens
415
+
416
+ def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
417
+ residual = hidden_states
418
+
419
+ if attn.spatial_norm is not None:
420
+ hidden_states = attn.spatial_norm(hidden_states, temb)
421
+
422
+ input_ndim = hidden_states.ndim
423
+
424
+ if input_ndim == 4:
425
+ batch_size, channel, height, width = hidden_states.shape
426
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
427
+
428
+ batch_size, sequence_length, _ = (
429
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
430
+ )
431
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
432
+
433
+ if attn.group_norm is not None:
434
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
435
+
436
+ query = attn.to_q(hidden_states)
437
+
438
+ if encoder_hidden_states is None:
439
+ encoder_hidden_states = hidden_states
440
+ else:
441
+ end_pos = encoder_hidden_states.shape[1] - self.num_tokens
442
+ encoder_hidden_states = encoder_hidden_states[:, :end_pos] # only use text
443
+ if attn.norm_cross:
444
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
445
+
446
+ key = attn.to_k(encoder_hidden_states)
447
+ value = attn.to_v(encoder_hidden_states)
448
+
449
+ query = attn.head_to_batch_dim(query)
450
+ key = attn.head_to_batch_dim(key)
451
+ value = attn.head_to_batch_dim(value)
452
+
453
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
454
+ hidden_states = torch.bmm(attention_probs, value)
455
+ hidden_states = attn.batch_to_head_dim(hidden_states)
456
+
457
+ # linear proj
458
+ hidden_states = attn.to_out[0](hidden_states)
459
+ # dropout
460
+ hidden_states = attn.to_out[1](hidden_states)
461
+
462
+ if input_ndim == 4:
463
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
464
+
465
+ if attn.residual_connection:
466
+ hidden_states = hidden_states + residual
467
+
468
+ hidden_states = hidden_states / attn.rescale_output_factor
469
+
470
+ return hidden_states
471
+
472
+
473
+ class CNAttnProcessor2_0:
474
+ r"""
475
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
476
+ """
477
+
478
+ def __init__(self, num_tokens=4):
479
+ if not hasattr(F, "scaled_dot_product_attention"):
480
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
481
+ self.num_tokens = num_tokens
482
+
483
+ def __call__(
484
+ self,
485
+ attn,
486
+ hidden_states,
487
+ encoder_hidden_states=None,
488
+ attention_mask=None,
489
+ temb=None,
490
+ ):
491
+ residual = hidden_states
492
+
493
+ if attn.spatial_norm is not None:
494
+ hidden_states = attn.spatial_norm(hidden_states, temb)
495
+
496
+ input_ndim = hidden_states.ndim
497
+
498
+ if input_ndim == 4:
499
+ batch_size, channel, height, width = hidden_states.shape
500
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
501
+
502
+ batch_size, sequence_length, _ = (
503
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
504
+ )
505
+
506
+ if attention_mask is not None:
507
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
508
+ # scaled_dot_product_attention expects attention_mask shape to be
509
+ # (batch, heads, source_length, target_length)
510
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
511
+
512
+ if attn.group_norm is not None:
513
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
514
+
515
+ query = attn.to_q(hidden_states)
516
+
517
+ if encoder_hidden_states is None:
518
+ encoder_hidden_states = hidden_states
519
+ else:
520
+ end_pos = encoder_hidden_states.shape[1] - self.num_tokens
521
+ encoder_hidden_states = encoder_hidden_states[:, :end_pos] # only use text
522
+ if attn.norm_cross:
523
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
524
+
525
+ key = attn.to_k(encoder_hidden_states)
526
+ value = attn.to_v(encoder_hidden_states)
527
+
528
+ inner_dim = key.shape[-1]
529
+ head_dim = inner_dim // attn.heads
530
+
531
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
532
+
533
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
534
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
535
+
536
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
537
+ # TODO: add support for attn.scale when we move to Torch 2.1
538
+ hidden_states = F.scaled_dot_product_attention(
539
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
540
+ )
541
+
542
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
543
+ hidden_states = hidden_states.to(query.dtype)
544
+
545
+ # linear proj
546
+ hidden_states = attn.to_out[0](hidden_states)
547
+ # dropout
548
+ hidden_states = attn.to_out[1](hidden_states)
549
+
550
+ if input_ndim == 4:
551
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
552
+
553
+ if attn.residual_connection:
554
+ hidden_states = hidden_states + residual
555
+
556
+ hidden_states = hidden_states / attn.rescale_output_factor
557
+
558
+ return hidden_states
ip_adapter/ip_adapter.py ADDED
@@ -0,0 +1,471 @@
1
+ import os
2
+ from typing import List
3
+
4
+ import torch
5
+ from diffusers import StableDiffusionPipeline
6
+ from diffusers.pipelines.controlnet import MultiControlNetModel
7
+ from PIL import Image
8
+ from safetensors import safe_open
9
+ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
10
+
11
+ from .utils import is_torch2_available, get_generator
12
+
13
+ if is_torch2_available():
14
+ from .attention_processor import (
15
+ AttnProcessor2_0 as AttnProcessor,
16
+ )
17
+ from .attention_processor import (
18
+ CNAttnProcessor2_0 as CNAttnProcessor,
19
+ )
20
+ from .attention_processor import (
21
+ IPAttnProcessor2_0 as IPAttnProcessor,
22
+ )
23
+ else:
24
+ from .attention_processor import AttnProcessor, CNAttnProcessor, IPAttnProcessor
25
+ from .resampler import Resampler
26
+
27
+
28
+ class ImageProjModel(torch.nn.Module):
29
+ """Projection Model"""
30
+
31
+ def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
32
+ super().__init__()
33
+
34
+ self.generator = None
35
+ self.cross_attention_dim = cross_attention_dim
36
+ self.clip_extra_context_tokens = clip_extra_context_tokens
37
+ self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
38
+ self.norm = torch.nn.LayerNorm(cross_attention_dim)
39
+
40
+ def forward(self, image_embeds):
41
+ embeds = image_embeds
42
+ clip_extra_context_tokens = self.proj(embeds).reshape(
43
+ -1, self.clip_extra_context_tokens, self.cross_attention_dim
44
+ )
45
+ clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
46
+ return clip_extra_context_tokens
47
+
48
+
49
+ class MLPProjModel(torch.nn.Module):
50
+ """SD model with image prompt"""
51
+ def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024):
52
+ super().__init__()
53
+
54
+ self.proj = torch.nn.Sequential(
55
+ torch.nn.Linear(clip_embeddings_dim, clip_embeddings_dim),
56
+ torch.nn.GELU(),
57
+ torch.nn.Linear(clip_embeddings_dim, cross_attention_dim),
58
+ torch.nn.LayerNorm(cross_attention_dim)
59
+ )
60
+
61
+ def forward(self, image_embeds):
62
+ clip_extra_context_tokens = self.proj(image_embeds)
63
+ return clip_extra_context_tokens
64
+
65
+
66
+ class IPAdapter:
67
+ def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens=4, target_blocks=["blocks"]):
68
+ self.device = device
69
+ self.image_encoder_path = image_encoder_path
70
+ self.ip_ckpt = ip_ckpt
71
+ self.num_tokens = num_tokens
72
+ self.target_blocks = target_blocks
73
+
74
+ self.pipe = sd_pipe.to(self.device)
75
+ self.set_ip_adapter()
76
+
77
+ # load image encoder
78
+ self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(
79
+ self.device, dtype=torch.float16
80
+ )
81
+ self.clip_image_processor = CLIPImageProcessor()
82
+ # image proj model
83
+ self.image_proj_model = self.init_proj()
84
+
85
+ self.load_ip_adapter()
86
+
87
+ def init_proj(self):
88
+ image_proj_model = ImageProjModel(
89
+ cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
90
+ clip_embeddings_dim=self.image_encoder.config.projection_dim,
91
+ clip_extra_context_tokens=self.num_tokens,
92
+ ).to(self.device, dtype=torch.float16)
93
+ return image_proj_model
94
+
95
+ def set_ip_adapter(self):
96
+ unet = self.pipe.unet
97
+ attn_procs = {}
98
+ for name in unet.attn_processors.keys():
99
+ cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
100
+ if name.startswith("mid_block"):
101
+ hidden_size = unet.config.block_out_channels[-1]
102
+ elif name.startswith("up_blocks"):
103
+ block_id = int(name[len("up_blocks.")])
104
+ hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
105
+ elif name.startswith("down_blocks"):
106
+ block_id = int(name[len("down_blocks.")])
107
+ hidden_size = unet.config.block_out_channels[block_id]
108
+ if cross_attention_dim is None:
109
+ attn_procs[name] = AttnProcessor()
110
+ else:
111
+ selected = False
112
+ for block_name in self.target_blocks:
113
+ if block_name in name:
114
+ selected = True
115
+ break
116
+ if selected:
117
+ attn_procs[name] = IPAttnProcessor(
118
+ hidden_size=hidden_size,
119
+ cross_attention_dim=cross_attention_dim,
120
+ scale=1.0,
121
+ num_tokens=self.num_tokens,
122
+ ).to(self.device, dtype=torch.float16)
123
+ else:
124
+ attn_procs[name] = AttnProcessor(
125
+ hidden_size=hidden_size,
126
+ cross_attention_dim=cross_attention_dim,
127
+ ).to(self.device, dtype=torch.float16)
128
+ unet.set_attn_processor(attn_procs)
129
+ if hasattr(self.pipe, "controlnet"):
130
+ if isinstance(self.pipe.controlnet, MultiControlNetModel):
131
+ for controlnet in self.pipe.controlnet.nets:
132
+ controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens))
133
+ else:
134
+ self.pipe.controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens))
135
+
136
+ def load_ip_adapter(self):
137
+ if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors":
138
+ state_dict = {"image_proj": {}, "ip_adapter": {}}
139
+ with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f:
140
+ for key in f.keys():
141
+ if key.startswith("image_proj."):
142
+ state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
143
+ elif key.startswith("ip_adapter."):
144
+ state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
145
+ else:
146
+ state_dict = torch.load(self.ip_ckpt, map_location="cpu")
147
+ self.image_proj_model.load_state_dict(state_dict["image_proj"])
148
+ ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
149
+ ip_layers.load_state_dict(state_dict["ip_adapter"], strict=False)
150
+
151
+ @torch.inference_mode()
152
+ def get_image_embeds(self, pil_image=None, clip_image_embeds=None, content_prompt_embeds=None):
153
+ if pil_image is not None:
154
+ if isinstance(pil_image, Image.Image):
155
+ pil_image = [pil_image]
156
+ clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
157
+ clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds
158
+ else:
159
+ clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float16)
160
+
161
+ if content_prompt_embeds is not None:
162
+ clip_image_embeds = clip_image_embeds - content_prompt_embeds
163
+
164
+ image_prompt_embeds = self.image_proj_model(clip_image_embeds)
165
+ uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds))
166
+ return image_prompt_embeds, uncond_image_prompt_embeds
167
+
168
+ def set_scale(self, scale):
169
+ for attn_processor in self.pipe.unet.attn_processors.values():
170
+ if isinstance(attn_processor, IPAttnProcessor):
171
+ attn_processor.scale = scale
172
+
173
+ def generate(
174
+ self,
175
+ pil_image=None,
176
+ clip_image_embeds=None,
177
+ prompt=None,
178
+ negative_prompt=None,
179
+ scale=1.0,
180
+ num_samples=4,
181
+ seed=None,
182
+ guidance_scale=7.5,
183
+ num_inference_steps=30,
184
+ neg_content_prompt=None,
185
+ neg_content_scale=1.0,
186
+ **kwargs,
187
+ ):
188
+ self.set_scale(scale)
189
+
190
+ if pil_image is not None:
191
+ num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)
192
+ else:
193
+ num_prompts = clip_image_embeds.size(0)
194
+
195
+ if prompt is None:
196
+ prompt = "best quality, high quality"
197
+ if negative_prompt is None:
198
+ negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
199
+
200
+ if not isinstance(prompt, List):
201
+ prompt = [prompt] * num_prompts
202
+ if not isinstance(negative_prompt, List):
203
+ negative_prompt = [negative_prompt] * num_prompts
204
+
205
+ if neg_content_prompt is not None:
206
+ with torch.inference_mode():
207
+ (
208
+ prompt_embeds_, # torch.Size([1, 77, 2048])
209
+ negative_prompt_embeds_,
210
+ pooled_prompt_embeds_, # torch.Size([1, 1280])
211
+ negative_pooled_prompt_embeds_,
212
+ ) = self.pipe.encode_prompt(
213
+ neg_content_prompt,
214
+ num_images_per_prompt=num_samples,
215
+ do_classifier_free_guidance=True,
216
+ negative_prompt=negative_prompt,
217
+ )
218
+ pooled_prompt_embeds_ *= neg_content_scale
219
+ else:
220
+ pooled_prompt_embeds_ = None
221
+
222
+ image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(
223
+ pil_image=pil_image, clip_image_embeds=clip_image_embeds, content_prompt_embeds=pooled_prompt_embeds_
224
+ )
225
+ bs_embed, seq_len, _ = image_prompt_embeds.shape
226
+ image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
227
+ image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
228
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
229
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
230
+
231
+ with torch.inference_mode():
232
+ prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt(
233
+ prompt,
234
+ device=self.device,
235
+ num_images_per_prompt=num_samples,
236
+ do_classifier_free_guidance=True,
237
+ negative_prompt=negative_prompt,
238
+ )
239
+ prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
240
+ negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
241
+
242
+ generator = get_generator(seed, self.device)
243
+
244
+ images = self.pipe(
245
+ prompt_embeds=prompt_embeds,
246
+ negative_prompt_embeds=negative_prompt_embeds,
247
+ guidance_scale=guidance_scale,
248
+ num_inference_steps=num_inference_steps,
249
+ generator=generator,
250
+ **kwargs,
251
+ ).images
252
+
253
+ return images
254
+
255
+
256
+ class IPAdapterXL(IPAdapter):
257
+ """SDXL"""
258
+
259
+ def generate(
260
+ self,
261
+ pil_image,
262
+ prompt=None,
263
+ negative_prompt=None,
264
+ scale=1.0,
265
+ num_samples=4,
266
+ seed=None,
267
+ num_inference_steps=30,
268
+ neg_content_prompt=None,
269
+ neg_content_scale=1.0,
270
+ **kwargs,
271
+ ):
272
+ self.set_scale(scale)
273
+
274
+ num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)
275
+
276
+ if prompt is None:
277
+ prompt = "best quality, high quality"
278
+ if negative_prompt is None:
279
+ negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
280
+
281
+ if not isinstance(prompt, List):
282
+ prompt = [prompt] * num_prompts
283
+ if not isinstance(negative_prompt, List):
284
+ negative_prompt = [negative_prompt] * num_prompts
285
+
286
+ if neg_content_prompt is not None:
287
+ with torch.inference_mode():
288
+ (
289
+ prompt_embeds_, # torch.Size([1, 77, 2048])
290
+ negative_prompt_embeds_,
291
+ pooled_prompt_embeds_, # torch.Size([1, 1280])
292
+ negative_pooled_prompt_embeds_,
293
+ ) = self.pipe.encode_prompt(
294
+ neg_content_prompt,
295
+ num_images_per_prompt=num_samples,
296
+ do_classifier_free_guidance=True,
297
+ negative_prompt=negative_prompt,
298
+ )
299
+ pooled_prompt_embeds_ *= neg_content_scale
300
+ else:
301
+ pooled_prompt_embeds_ = None
302
+
303
+ image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image, content_prompt_embeds=pooled_prompt_embeds_)
304
+ bs_embed, seq_len, _ = image_prompt_embeds.shape
305
+ image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
306
+ image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
307
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
308
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
309
+
310
+ with torch.inference_mode():
311
+ (
312
+ prompt_embeds,
313
+ negative_prompt_embeds,
314
+ pooled_prompt_embeds,
315
+ negative_pooled_prompt_embeds,
316
+ ) = self.pipe.encode_prompt(
317
+ prompt,
318
+ num_images_per_prompt=num_samples,
319
+ do_classifier_free_guidance=True,
320
+ negative_prompt=negative_prompt,
321
+ )
322
+ prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
323
+ negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
324
+
325
+ self.generator = get_generator(seed, self.device)
326
+
327
+ images = self.pipe(
328
+ prompt_embeds=prompt_embeds,
329
+ negative_prompt_embeds=negative_prompt_embeds,
330
+ pooled_prompt_embeds=pooled_prompt_embeds,
331
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
332
+ num_inference_steps=num_inference_steps,
333
+ generator=self.generator,
334
+ **kwargs,
335
+ ).images
336
+
337
+ return images
338
+
339
+
340
+ class IPAdapterPlus(IPAdapter):
341
+ """IP-Adapter with fine-grained features"""
342
+
343
+ def init_proj(self):
344
+ image_proj_model = Resampler(
345
+ dim=self.pipe.unet.config.cross_attention_dim,
346
+ depth=4,
347
+ dim_head=64,
348
+ heads=12,
349
+ num_queries=self.num_tokens,
350
+ embedding_dim=self.image_encoder.config.hidden_size,
351
+ output_dim=self.pipe.unet.config.cross_attention_dim,
352
+ ff_mult=4,
353
+ ).to(self.device, dtype=torch.float16)
354
+ return image_proj_model
355
+
356
+ @torch.inference_mode()
357
+ def get_image_embeds(self, pil_image=None, clip_image_embeds=None):
358
+ if isinstance(pil_image, Image.Image):
359
+ pil_image = [pil_image]
360
+ clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
361
+ clip_image = clip_image.to(self.device, dtype=torch.float16)
362
+ clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
363
+ image_prompt_embeds = self.image_proj_model(clip_image_embeds)
364
+ uncond_clip_image_embeds = self.image_encoder(
365
+ torch.zeros_like(clip_image), output_hidden_states=True
366
+ ).hidden_states[-2]
367
+ uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
368
+ return image_prompt_embeds, uncond_image_prompt_embeds
369
+
370
+
371
+ class IPAdapterFull(IPAdapterPlus):
372
+ """IP-Adapter with full features"""
373
+
374
+ def init_proj(self):
375
+ image_proj_model = MLPProjModel(
376
+ cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
377
+ clip_embeddings_dim=self.image_encoder.config.hidden_size,
378
+ ).to(self.device, dtype=torch.float16)
379
+ return image_proj_model
380
+
381
+
382
+ class IPAdapterPlusXL(IPAdapter):
383
+ """SDXL"""
384
+
385
+ def init_proj(self):
386
+ image_proj_model = Resampler(
387
+ dim=1280,
388
+ depth=4,
389
+ dim_head=64,
390
+ heads=20,
391
+ num_queries=self.num_tokens,
392
+ embedding_dim=self.image_encoder.config.hidden_size,
393
+ output_dim=self.pipe.unet.config.cross_attention_dim,
394
+ ff_mult=4,
395
+ ).to(self.device, dtype=torch.float16)
396
+ return image_proj_model
397
+
398
+ @torch.inference_mode()
399
+ def get_image_embeds(self, pil_image):
400
+ if isinstance(pil_image, Image.Image):
401
+ pil_image = [pil_image]
402
+ clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
403
+ clip_image = clip_image.to(self.device, dtype=torch.float16)
404
+ clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
405
+ image_prompt_embeds = self.image_proj_model(clip_image_embeds)
406
+ uncond_clip_image_embeds = self.image_encoder(
407
+ torch.zeros_like(clip_image), output_hidden_states=True
408
+ ).hidden_states[-2]
409
+ uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
410
+ return image_prompt_embeds, uncond_image_prompt_embeds
411
+
412
+ def generate(
413
+ self,
414
+ pil_image,
415
+ prompt=None,
416
+ negative_prompt=None,
417
+ scale=1.0,
418
+ num_samples=4,
419
+ seed=None,
420
+ num_inference_steps=30,
421
+ **kwargs,
422
+ ):
423
+ self.set_scale(scale)
424
+
425
+ num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)
426
+
427
+ if prompt is None:
428
+ prompt = "best quality, high quality"
429
+ if negative_prompt is None:
430
+ negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
431
+
432
+ if not isinstance(prompt, List):
433
+ prompt = [prompt] * num_prompts
434
+ if not isinstance(negative_prompt, List):
435
+ negative_prompt = [negative_prompt] * num_prompts
436
+
437
+ image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image)
438
+ bs_embed, seq_len, _ = image_prompt_embeds.shape
439
+ image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
440
+ image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
441
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
442
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
443
+
444
+ with torch.inference_mode():
445
+ (
446
+ prompt_embeds,
447
+ negative_prompt_embeds,
448
+ pooled_prompt_embeds,
449
+ negative_pooled_prompt_embeds,
450
+ ) = self.pipe.encode_prompt(
451
+ prompt,
452
+ num_images_per_prompt=num_samples,
453
+ do_classifier_free_guidance=True,
454
+ negative_prompt=negative_prompt,
455
+ )
456
+ prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
457
+ negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
458
+
459
+ generator = get_generator(seed, self.device)
460
+
461
+ images = self.pipe(
462
+ prompt_embeds=prompt_embeds,
463
+ negative_prompt_embeds=negative_prompt_embeds,
464
+ pooled_prompt_embeds=pooled_prompt_embeds,
465
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
466
+ num_inference_steps=num_inference_steps,
467
+ generator=generator,
468
+ **kwargs,
469
+ ).images
470
+
471
+ return images
ip_adapter/resampler.py ADDED
@@ -0,0 +1,158 @@
1
+ # modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
2
+ # and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py
3
+
4
+ import math
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from einops import rearrange
9
+ from einops.layers.torch import Rearrange
10
+
11
+
12
+ # FFN
13
+ def FeedForward(dim, mult=4):
14
+ inner_dim = int(dim * mult)
15
+ return nn.Sequential(
16
+ nn.LayerNorm(dim),
17
+ nn.Linear(dim, inner_dim, bias=False),
18
+ nn.GELU(),
19
+ nn.Linear(inner_dim, dim, bias=False),
20
+ )
21
+
22
+
23
+ def reshape_tensor(x, heads):
24
+ bs, length, width = x.shape
25
+ # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
26
+ x = x.view(bs, length, heads, -1)
27
+ # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
28
+ x = x.transpose(1, 2)
29
+ # keep shape (bs, n_heads, length, dim_per_head); the heads are not folded into the batch dimension
30
+ x = x.reshape(bs, heads, length, -1)
31
+ return x
32
+
33
+
34
+ class PerceiverAttention(nn.Module):
35
+ def __init__(self, *, dim, dim_head=64, heads=8):
36
+ super().__init__()
37
+ self.scale = dim_head**-0.5
38
+ self.dim_head = dim_head
39
+ self.heads = heads
40
+ inner_dim = dim_head * heads
41
+
42
+ self.norm1 = nn.LayerNorm(dim)
43
+ self.norm2 = nn.LayerNorm(dim)
44
+
45
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
46
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
47
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
48
+
49
+ def forward(self, x, latents):
50
+ """
51
+ Args:
52
+ x (torch.Tensor): image features
53
+ shape (b, n1, D)
54
+ latent (torch.Tensor): latent features
55
+ shape (b, n2, D)
56
+ """
57
+ x = self.norm1(x)
58
+ latents = self.norm2(latents)
59
+
60
+ b, l, _ = latents.shape
61
+
62
+ q = self.to_q(latents)
63
+ kv_input = torch.cat((x, latents), dim=-2)
64
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
65
+
66
+ q = reshape_tensor(q, self.heads)
67
+ k = reshape_tensor(k, self.heads)
68
+ v = reshape_tensor(v, self.heads)
69
+
70
+ # attention
71
+ scale = 1 / math.sqrt(math.sqrt(self.dim_head))
72
+ weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
73
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
74
+ out = weight @ v
75
+
76
+ out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
77
+
78
+ return self.to_out(out)
79
+
80
+
81
+ class Resampler(nn.Module):
82
+ def __init__(
83
+ self,
84
+ dim=1024,
85
+ depth=8,
86
+ dim_head=64,
87
+ heads=16,
88
+ num_queries=8,
89
+ embedding_dim=768,
90
+ output_dim=1024,
91
+ ff_mult=4,
92
+ max_seq_len: int = 257, # CLIP tokens + CLS token
93
+ apply_pos_emb: bool = False,
94
+ num_latents_mean_pooled: int = 0, # number of latents derived from mean pooled representation of the sequence
95
+ ):
96
+ super().__init__()
97
+ self.pos_emb = nn.Embedding(max_seq_len, embedding_dim) if apply_pos_emb else None
98
+
99
+ self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
100
+
101
+ self.proj_in = nn.Linear(embedding_dim, dim)
102
+
103
+ self.proj_out = nn.Linear(dim, output_dim)
104
+ self.norm_out = nn.LayerNorm(output_dim)
105
+
106
+ self.to_latents_from_mean_pooled_seq = (
107
+ nn.Sequential(
108
+ nn.LayerNorm(dim),
109
+ nn.Linear(dim, dim * num_latents_mean_pooled),
110
+ Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled),
111
+ )
112
+ if num_latents_mean_pooled > 0
113
+ else None
114
+ )
115
+
116
+ self.layers = nn.ModuleList([])
117
+ for _ in range(depth):
118
+ self.layers.append(
119
+ nn.ModuleList(
120
+ [
121
+ PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
122
+ FeedForward(dim=dim, mult=ff_mult),
123
+ ]
124
+ )
125
+ )
126
+
127
+ def forward(self, x):
128
+ if self.pos_emb is not None:
129
+ n, device = x.shape[1], x.device
130
+ pos_emb = self.pos_emb(torch.arange(n, device=device))
131
+ x = x + pos_emb
132
+
133
+ latents = self.latents.repeat(x.size(0), 1, 1)
134
+
135
+ x = self.proj_in(x)
136
+
137
+ if self.to_latents_from_mean_pooled_seq:
138
+ meanpooled_seq = masked_mean(x, dim=1, mask=torch.ones(x.shape[:2], device=x.device, dtype=torch.bool))
139
+ meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
140
+ latents = torch.cat((meanpooled_latents, latents), dim=-2)
141
+
142
+ for attn, ff in self.layers:
143
+ latents = attn(x, latents) + latents
144
+ latents = ff(latents) + latents
145
+
146
+ latents = self.proj_out(latents)
147
+ return self.norm_out(latents)
148
+
149
+
150
+ def masked_mean(t, *, dim, mask=None):
151
+ if mask is None:
152
+ return t.mean(dim=dim)
153
+
154
+ denom = mask.sum(dim=dim, keepdim=True)
155
+ mask = rearrange(mask, "b n -> b n 1")
156
+ masked_t = t.masked_fill(~mask, 0.0)
157
+
158
+ return masked_t.sum(dim=dim) / denom.clamp(min=1e-5)
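
As a quick sanity check of the `Resampler` added above (it maps a sequence of image features of width `embedding_dim` to `num_queries` learned tokens of width `output_dim`), here is a hypothetical shape test; the constructor values below are illustrative, not the ones the repo necessarily uses:

import torch
from ip_adapter.resampler import Resampler  # module path as added in this commit

# dummy CLIP penultimate-layer features: batch of 2, 257 tokens, width 768
image_feats = torch.randn(2, 257, 768)

resampler = Resampler(
    dim=1024,
    depth=4,          # smaller than the default 8, just to keep the check fast
    dim_head=64,
    heads=12,
    num_queries=16,
    embedding_dim=768,
    output_dim=2048,  # e.g. a cross-attention width
    ff_mult=4,
)

tokens = resampler(image_feats)
print(tokens.shape)  # torch.Size([2, 16, 2048])
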
ip_adapter/utils.py ADDED
@@ -0,0 +1,93 @@
+ import torch
+ import torch.nn.functional as F
+ import numpy as np
+ from PIL import Image
+
+ attn_maps = {}
+ def hook_fn(name):
+     def forward_hook(module, input, output):
+         if hasattr(module.processor, "attn_map"):
+             attn_maps[name] = module.processor.attn_map
+             del module.processor.attn_map
+
+     return forward_hook
+
+ def register_cross_attention_hook(unet):
+     for name, module in unet.named_modules():
+         if name.split('.')[-1].startswith('attn2'):
+             module.register_forward_hook(hook_fn(name))
+
+     return unet
+
+ def upscale(attn_map, target_size):
+     attn_map = torch.mean(attn_map, dim=0)
+     attn_map = attn_map.permute(1, 0)
+     temp_size = None
+
+     for i in range(0, 5):
+         scale = 2 ** i
+         if (target_size[0] // scale) * (target_size[1] // scale) == attn_map.shape[1] * 64:
+             temp_size = (target_size[0] // (scale * 8), target_size[1] // (scale * 8))
+             break
+
+     assert temp_size is not None, "temp_size cannot be None"
+
+     attn_map = attn_map.view(attn_map.shape[0], *temp_size)
+
+     attn_map = F.interpolate(
+         attn_map.unsqueeze(0).to(dtype=torch.float32),
+         size=target_size,
+         mode='bilinear',
+         align_corners=False
+     )[0]
+
+     attn_map = torch.softmax(attn_map, dim=0)
+     return attn_map
+ def get_net_attn_map(image_size, batch_size=2, instance_or_negative=False, detach=True):
+
+     idx = 0 if instance_or_negative else 1
+     net_attn_maps = []
+
+     for name, attn_map in attn_maps.items():
+         attn_map = attn_map.cpu() if detach else attn_map
+         attn_map = torch.chunk(attn_map, batch_size)[idx].squeeze()
+         attn_map = upscale(attn_map, image_size)
+         net_attn_maps.append(attn_map)
+
+     net_attn_maps = torch.mean(torch.stack(net_attn_maps, dim=0), dim=0)
+
+     return net_attn_maps
+
+ def attnmaps2images(net_attn_maps):
+
+     # total_attn_scores = 0
+     images = []
+
+     for attn_map in net_attn_maps:
+         attn_map = attn_map.cpu().numpy()
+         # total_attn_scores += attn_map.mean().item()
+
+         normalized_attn_map = (attn_map - np.min(attn_map)) / (np.max(attn_map) - np.min(attn_map)) * 255
+         normalized_attn_map = normalized_attn_map.astype(np.uint8)
+         # print("norm: ", normalized_attn_map.shape)
+         image = Image.fromarray(normalized_attn_map)
+
+         # image = fix_save_attn_map(attn_map)
+         images.append(image)
+
+     # print(total_attn_scores)
+     return images
+ def is_torch2_available():
+     return hasattr(F, "scaled_dot_product_attention")
+
+ def get_generator(seed, device):
+
+     if seed is not None:
+         if isinstance(seed, list):
+             generator = [torch.Generator(device).manual_seed(seed_item) for seed_item in seed]
+         else:
+             generator = torch.Generator(device).manual_seed(seed)
+     else:
+         generator = None
+
+     return generator
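
Two of these helpers are called directly from the pipeline wrapper; a short illustrative snippet is below. Note that the attention-map helpers (`register_cross_attention_hook`, `get_net_attn_map`, `attnmaps2images`) additionally rely on attention processors that expose an `attn_map` attribute, which is not shown in this file.

import torch
from ip_adapter.utils import get_generator, is_torch2_available

device = "cuda" if torch.cuda.is_available() else "cpu"

# a single seed gives one torch.Generator, a list of seeds gives one generator per sample
gen_single = get_generator(42, device)
gen_batch = get_generator([0, 1, 2, 3], device)
print(type(gen_single), len(gen_batch))

# elsewhere in ip_adapter, the attention processor class is chosen based on this check
print("torch 2.x SDPA available:", is_torch2_available())
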
models/image_encoder/config.json ADDED
@@ -0,0 +1,81 @@
+ {
+   "architectures": [
+     "CLIPVisionModelWithProjection"
+   ],
+   "_name_or_path": "",
+   "add_cross_attention": false,
+   "architectures": null,
+   "attention_dropout": 0.0,
+   "bad_words_ids": null,
+   "begin_suppress_tokens": null,
+   "bos_token_id": null,
+   "chunk_size_feed_forward": 0,
+   "cross_attention_hidden_size": null,
+   "decoder_start_token_id": null,
+   "diversity_penalty": 0.0,
+   "do_sample": false,
+   "dropout": 0.0,
+   "early_stopping": false,
+   "encoder_no_repeat_ngram_size": 0,
+   "eos_token_id": null,
+   "exponential_decay_length_penalty": null,
+   "finetuning_task": null,
+   "forced_bos_token_id": null,
+   "forced_eos_token_id": null,
+   "hidden_act": "gelu",
+   "hidden_size": 1664,
+   "id2label": {
+     "0": "LABEL_0",
+     "1": "LABEL_1"
+   },
+   "image_size": 224,
+   "initializer_factor": 1.0,
+   "initializer_range": 0.02,
+   "intermediate_size": 8192,
+   "is_decoder": false,
+   "is_encoder_decoder": false,
+   "label2id": {
+     "LABEL_0": 0,
+     "LABEL_1": 1
+   },
+   "layer_norm_eps": 1e-05,
+   "length_penalty": 1.0,
+   "max_length": 20,
+   "min_length": 0,
+   "model_type": "clip_vision_model",
+   "no_repeat_ngram_size": 0,
+   "num_attention_heads": 16,
+   "num_beam_groups": 1,
+   "num_beams": 1,
+   "num_channels": 3,
+   "num_hidden_layers": 48,
+   "num_return_sequences": 1,
+   "output_attentions": false,
+   "output_hidden_states": false,
+   "output_scores": false,
+   "pad_token_id": null,
+   "patch_size": 14,
+   "prefix": null,
+   "problem_type": null,
+   "pruned_heads": {},
+   "remove_invalid_values": false,
+   "repetition_penalty": 1.0,
+   "return_dict": true,
+   "return_dict_in_generate": false,
+   "sep_token_id": null,
+   "suppress_tokens": null,
+   "task_specific_params": null,
+   "temperature": 1.0,
+   "tf_legacy_loss": false,
+   "tie_encoder_decoder": false,
+   "tie_word_embeddings": true,
+   "tokenizer_class": null,
+   "top_k": 50,
+   "top_p": 1.0,
+   "torch_dtype": null,
+   "torchscript": false,
+   "transformers_version": "4.24.0",
+   "typical_p": 1.0,
+   "use_bfloat16": false,
+   "projection_dim": 1280
+ }
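
Judging by the hyperparameters (hidden_size 1664, 48 layers, patch size 14, projection_dim 1280), this config appears to describe an OpenCLIP ViT-bigG/14 vision tower used as the image encoder. A minimal loading sketch, assuming the folder layout created by this commit:

from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

# path relative to the repo root, as uploaded here
image_encoder = CLIPVisionModelWithProjection.from_pretrained("models/image_encoder")
processor = CLIPImageProcessor()  # default CLIP preprocessing; the actual wrapper may configure its own

print(image_encoder.config.hidden_size, image_encoder.config.projection_dim)  # 1664 1280
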
models/image_encoder/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:657723e09f46a7c3957df651601029f66b1748afb12b419816330f16ed45d64d
+ size 3689912664
models/image_encoder/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2999562fbc02f9dc0d9c0acb7cf0970ec3a9b2a578d7d05afe82191d606d2d80
+ size 3690112753
models/ip-adapter_sdxl.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7525f2731e9e86d1368e0b68467615d55dda459691965bdd7d37fa3d7fd84c12
+ size 702585097
result.png ADDED

Git LFS Details

  • SHA256: 4156f5f6670a8d53e9400621f5eb75b2e1d56a8c565f1870f93c68a29c4812bc
  • Pointer size: 132 Bytes
  • Size of remote file: 1.86 MB