yuxi-liu-wired committed
Commit 97d6b59
Parent(s): 1b1d8c3
README tutorial.

Files changed:
- README.md +252 -4
- examples/example.ipynb +2 -13
- examples/requirements.txt +9 -0
- examples/style_embedding_tsne.png +0 -0
README.md
CHANGED

The previous stub (YAML frontmatter with only the `model_hub_mixin` and `pytorch_model_hub_mixin` tags, plus a few placeholder lines) is replaced by the full model card below.
---
license: mit
language:
- en
base_model:
- openai/clip-vit-large-patch14
tags:
- art
- style
- clip
- image
- embedding
- vit
- model_hub_mixin
- pytorch_model_hub_mixin
---

## Measuring Style Similarity in Diffusion Models

Cloned from [learn2phoenix/CSD](https://github.com/learn2phoenix/CSD?tab=readme-ov-file).

Their model (`csd-vit-l.pth`) was downloaded from their [Google Drive](https://drive.google.com/file/d/1FX0xs8p-C7Ob-h5Y4cUhTeOepHzXv_46/view?usp=sharing).

The original git repo is in the `CSD` folder.

## Model architecture

The model, CSD ("contrastive style descriptor"), is initialized from the image encoder of [openai/clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14). Let $f$ be the function implemented by that image encoder: a vision Transformer that takes an image and converts it into a $1024$-dimensional real-valued vector. In CLIP, this is followed by a single $1024 \times 768$ projection matrix that converts it into the CLIP embedding.

Now remove the projection matrix. This gives $g: \text{Image} \to \mathbb{R}^{1024}$; the output of $g$ is the `feature vector`. Then add two new $1024 \times 768$ projection matrices: the output of one is the `style vector`, and the output of the other is the `content vector`. All parameters of the resulting model were then finetuned, using [tadeephuy/GradientReversal](https://github.com/tadeephuy/GradientReversal) for content-style disentanglement, resulting in the final model.
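
As a quick shape sketch of the two heads described above (random tensors stand in for the checkpoint weights; the variable names are illustrative, not part of the released code):

```python
import torch
import torch.nn as nn

batch = 4
feature = torch.randn(batch, 1024)   # g(image): backbone output with the CLIP projection removed
W_style = torch.randn(1024, 768)     # style projection head
W_content = torch.randn(1024, 768)   # content projection head

style_vector = nn.functional.normalize(feature @ W_style, dim=1, p=2)      # (4, 768)
content_vector = nn.functional.normalize(feature @ W_content, dim=1, p=2)  # (4, 768)
```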

The original paper says that they trained *two* models, one of them based on ViT-B, but that model was not released.

Also, despite the names `style vector` and `content vector`, I have noticed by visual inspection that both are about equally good as style embeddings. I don't know why, but I guess that's life?

## How to use it

### Quickstart

Go to `examples` and run the `example.ipynb` notebook, then run `tsne_visualization.py`. It will print something like `Running on http://127.0.0.1:49860`. Open that link and enjoy the interactive picture.

![](examples/style_embedding_tsne.png)

### Loading the model

```python
import copy
import torch
import torch.nn as nn
import clip
from transformers import CLIPProcessor, PretrainedConfig
from huggingface_hub import PyTorchModelHubMixin

class CSDCLIPConfig(PretrainedConfig):
    model_type = "csd_clip"

    def __init__(
        self,
        name="csd_large",
        embedding_dim=1024,
        feature_dim=1024,
        content_dim=768,
        style_dim=768,
        content_proj_head="default",
        **kwargs
    ):
        super().__init__(**kwargs)
        self.name = name
        self.embedding_dim = embedding_dim
        self.content_proj_head = content_proj_head
        self.task_specific_params = None  # Add this line

class CSD_CLIP(nn.Module, PyTorchModelHubMixin):
    """backbone + projection head"""
    def __init__(self, name='vit_large', content_proj_head='default'):
        super(CSD_CLIP, self).__init__()
        self.content_proj_head = content_proj_head
        if name == 'vit_large':
            clipmodel, _ = clip.load("ViT-L/14")
            self.backbone = clipmodel.visual
            self.embedding_dim = 1024
            self.feature_dim = 1024
            self.content_dim = 768
            self.style_dim = 768
            self.name = "csd_large"
        elif name == 'vit_base':
            clipmodel, _ = clip.load("ViT-B/16")
            self.backbone = clipmodel.visual
            self.embedding_dim = 768
            self.feature_dim = 512
            self.content_dim = 512
            self.style_dim = 512
            self.name = "csd_base"
        else:
            raise Exception('This model is not implemented')

        self.last_layer_style = copy.deepcopy(self.backbone.proj)
        self.last_layer_content = copy.deepcopy(self.backbone.proj)

        self.backbone.proj = None

        self.config = CSDCLIPConfig(
            name=self.name,
            embedding_dim=self.embedding_dim,
            feature_dim=self.feature_dim,
            content_dim=self.content_dim,
            style_dim=self.style_dim,
            content_proj_head=self.content_proj_head
        )

    def get_config(self):
        return self.config.to_dict()

    @property
    def dtype(self):
        return self.backbone.conv1.weight.dtype

    @property
    def device(self):
        return next(self.parameters()).device

    def forward(self, input_data):
        feature = self.backbone(input_data)

        style_output = feature @ self.last_layer_style
        style_output = nn.functional.normalize(style_output, dim=1, p=2)

        content_output = feature @ self.last_layer_content
        content_output = nn.functional.normalize(content_output, dim=1, p=2)

        return feature, content_output, style_output

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = CSD_CLIP.from_pretrained("yuxi-liu-wired/CSD")
model.to(device);
```
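
If you only need embeddings for a single image and do not want the pipeline wrapper defined in the next section, a minimal sketch (the image path is a placeholder; `model` and `device` come from the block above):

```python
import torch
from PIL import Image
from transformers import CLIPProcessor

processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
image = Image.open("some_image.jpg")  # placeholder path

# The processor resizes and normalizes the image; cast to the model's dtype
# (fp16 when the CLIP backbone was loaded on GPU).
pixel_values = processor(images=image, return_tensors="pt")["pixel_values"]
pixel_values = pixel_values.to(device).to(model.dtype)

with torch.no_grad():
    feature, content_vector, style_vector = model(pixel_values)

print(feature.shape, content_vector.shape, style_vector.shape)  # (1, 1024), (1, 768), (1, 768)
```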

### Loading the pipeline

```python
import torch
from transformers import Pipeline
from typing import Union, List
from PIL import Image

class CSDCLIPPipeline(Pipeline):
    def __init__(self, model, processor, device=None):
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        super().__init__(model=model, tokenizer=None, device=device)
        self.processor = processor

    def _sanitize_parameters(self, **kwargs):
        return {}, {}, {}

    def preprocess(self, images):
        if isinstance(images, (str, Image.Image)):
            images = [images]

        processed = self.processor(images=images, return_tensors="pt", padding=True, truncation=True)
        return {k: v.to(self.device) for k, v in processed.items()}

    def _forward(self, model_inputs):
        pixel_values = model_inputs['pixel_values'].to(self.model.dtype)
        with torch.no_grad():
            features, content_output, style_output = self.model(pixel_values)
        return {"features": features, "content_output": content_output, "style_output": style_output}

    def postprocess(self, model_outputs):
        return {
            "features": model_outputs["features"].cpu().numpy(),
            "content_output": model_outputs["content_output"].cpu().numpy(),
            "style_output": model_outputs["style_output"].cpu().numpy()
        }

    def __call__(self, images: Union[str, List[str], Image.Image, List[Image.Image]]):
        return super().__call__(images)

processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
pipeline = CSDCLIPPipeline(model=model, processor=processor, device=device)
```
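
For example, to compare the styles of two images with this pipeline (the file names are placeholders):

```python
import numpy as np
from PIL import Image

out_a = pipeline(Image.open("a.jpg"))  # placeholder paths
out_b = pipeline(Image.open("b.jpg"))

style_a = out_a["style_output"].squeeze(0)  # (768,), already L2-normalized
style_b = out_b["style_output"].squeeze(0)

# Since the vectors are normalized, cosine similarity is just a dot product.
print("style similarity:", float(np.dot(style_a, style_b)))
```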

### An example application

First, load the model and the pipeline as described above. Then run the following to load the [yuxi-liu-wired/style-content-grid-SDXL](https://huggingface.co/datasets/yuxi-liu-wired/style-content-grid-SDXL) dataset, compute a style embedding for each image, and write the results to a `parquet` file.

```python
import io
from PIL import Image
from datasets import load_dataset
import pandas as pd
from tqdm import tqdm

def to_jpeg(image):
    buffered = io.BytesIO()
    if image.mode != "RGB":
        image = image.convert("RGB")
    image.save(buffered, format='JPEG')
    return buffered.getvalue()

def scale_image(image, max_resolution):
    # Downscale so the longer side is at most max_resolution, preserving aspect ratio.
    if max(image.width, image.height) > max_resolution:
        scale = max_resolution / max(image.width, image.height)
        image = image.resize((int(image.width * scale), int(image.height * scale)))
    return image

def process_dataset(pipeline, dataset_name, dataset_size=900, max_resolution=192):
    dataset = load_dataset(dataset_name, split='train')
    dataset = dataset.select(range(dataset_size))

    # Print the column names
    print("Dataset columns:", dataset.column_names)

    # Initialize lists to store results
    embeddings = []
    jpeg_images = []

    # Process each item in the dataset
    for item in tqdm(dataset, desc="Processing images"):
        try:
            img = item['image']

            # If img is a string (file path), load the image
            if isinstance(img, str):
                img = Image.open(img)

            output = pipeline(img)
            style_output = output["style_output"].squeeze(0)

            img = scale_image(img, max_resolution)
            jpeg_img = to_jpeg(img)

            # Append results to lists
            embeddings.append(style_output)
            jpeg_images.append(jpeg_img)
        except Exception as e:
            print(f"Error processing item: {e}")

    # Create a DataFrame with the results
    df = pd.DataFrame({
        'embedding': embeddings,
        'image': jpeg_images
    })

    df.to_parquet('processed_dataset.parquet')
    print("Processing complete. Results saved to 'processed_dataset.parquet'")

process_dataset(pipeline, "yuxi-liu-wired/style-content-grid-SDXL",
                dataset_size=900, max_resolution=192)
```

After that, you can go to `examples` and run `tsne_visualization.py` to get an interactive Dash app for browsing the images.

![](examples/style_embedding_tsne.png)
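
For a rough idea of what the visualization script does, here is a minimal, non-interactive sketch (assuming `scikit-learn` and `matplotlib` are installed) that reads the parquet file produced above and plots a 2D t-SNE of the style embeddings:

```python
import numpy as np
import pandas as pd
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

df = pd.read_parquet("processed_dataset.parquet")
X = np.stack(df["embedding"].to_numpy())  # (N, 768) style embeddings

xy = TSNE(n_components=2, metric="cosine", init="pca", random_state=0).fit_transform(X)

plt.figure(figsize=(8, 8))
plt.scatter(xy[:, 0], xy[:, 1], s=5)
plt.title("CSD style embeddings, t-SNE")
plt.show()
```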

examples/example.ipynb
CHANGED

@@ -10,21 +10,12 @@
     "outputs": [],
     "source": [
      "import copy\n",
-     "import os\n",
-     "import io\n",
      "import torch\n",
      "import torch.nn as nn\n",
      "import clip\n",
-     "
-     "from PIL import Image\n",
-     "from tqdm import tqdm\n",
-     "import numpy as np\n",
-     "from transformers import Pipeline, CLIPProcessor, CLIPVisionModel\n",
+     "from transformers import CLIPProcessor\n",
      "from huggingface_hub import PyTorchModelHubMixin\n",
-     "from typing import List, Union\n",
      "from transformers import PretrainedConfig\n",
-     "import json\n",
-     "import safetensors\n",
      "\n",
      "class CSDCLIPConfig(PretrainedConfig):\n",
      "    model_type = \"csd_clip\"\n",
@@ -106,7 +97,7 @@
      "    \n",
      "    return feature, content_output, style_output\n",
      "\n",
-     "device = 'cuda'\n",
+     "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
      "model = CSD_CLIP.from_pretrained(\"yuxi-liu-wired/CSD\")\n",
      "model.to(device);"
     ]
@@ -188,10 +179,8 @@
     "source": [
      "import io\n",
      "from PIL import Image\n",
-     "import requests\n",
      "from datasets import load_dataset\n",
      "import pandas as pd\n",
-     "import numpy as np\n",
      "from tqdm import tqdm\n",
      "\n",
      "def to_jpeg(image):\n",
examples/requirements.txt
ADDED

pandas
numpy
scikit-learn
pillow
datasets
tqdm
clip @ git+https://github.com/openai/CLIP.git@main
torch
examples/style_embedding_tsne.png
ADDED