yue-here committed
Commit 5edc0a2
Parent(s): d5851a1
first commit

Files changed:
- app.py +17 -0
- glyffuser_utils.py +174 -0
- t5.py +119 -0
app.py
ADDED
@@ -0,0 +1,17 @@
import gradio as gr
from glyffuser_utils import GlyffuserPipeline

pipeline = GlyffuserPipeline.from_pretrained("yuewu/glyffuser")

def infer(text):
    generated_images = pipeline(
        [text],  # the pipeline expects a list of prompts
        batch_size=1,  # Generate one image at a time for each step
        # generator=torch.Generator(device='cuda').manual_seed(config.seed), # Generator can be on GPU here
        num_inference_steps=50,
    ).images
    return generated_images[0]

demo = gr.Interface(fn=infer, inputs="text", outputs="image")
demo.launch()
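
For reference, a minimal sketch of calling the pipeline directly rather than through the Gradio interface; the prompt text and seed below are placeholders, not part of this commit:

import torch
from glyffuser_utils import GlyffuserPipeline

pipe = GlyffuserPipeline.from_pretrained("yuewu/glyffuser")
images = pipe(
    ["mountain; hill"],                          # placeholder glyph definition
    generator=torch.Generator().manual_seed(0),  # placeholder seed for reproducibility
    num_inference_steps=50,
).images
images[0].save("sample.png")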
glyffuser_utils.py
ADDED
@@ -0,0 +1,174 @@
import os

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as T
import t5
from torch.nn.utils.rnn import pad_sequence

from PIL import Image

from datasets import load_dataset

from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from typing import List, Optional, Tuple, Union
from diffusers.utils.torch_utils import randn_tensor


# Collator adjusted for local dataset
class Collator:
    def __init__(self, image_size, text_label, image_label, name, channels):
        self.text_label = text_label
        self.image_label = image_label
        self.name = name
        self.channels = channels
        self.transform = T.Compose([
            T.Resize((image_size, image_size)),
            T.ToTensor(),
        ])

    def __call__(self, batch):
        texts = []
        masks = []
        images = []
        for item in batch:
            image_path = 'data/' + item[self.image_label]  # Assuming this is a path to the image file
            try:
                # Load image from local file
                with Image.open(image_path) as img:
                    image = self.transform(img.convert(self.channels))
            except Exception as e:
                print(f"Failed to process image {image_path}: {e}")
                continue

            # Encode the text
            text, mask = t5.t5_encode_text(
                [item[self.text_label]],
                name=self.name,
                return_attn_mask=True
            )
            texts.append(torch.squeeze(text))
            masks.append(torch.squeeze(mask))
            images.append(image)

        if len(texts) == 0:
            return None

        # Are these strictly necessary? Pads embeddings and masks to a common length
        texts = pad_sequence(texts, True)
        masks = pad_sequence(masks, True)

        newbatch = []
        for i in range(len(texts)):
            newbatch.append((images[i], texts[i], masks[i]))

        return torch.utils.data.dataloader.default_collate(newbatch)


class GlyffuserPipeline(DiffusionPipeline):
    r'''
    Pipeline for text-to-image generation from the glyffuser model.

    Parameters:
        unet (`UNet2DConditionModel`)
        scheduler (`SchedulerMixin`)
        text_encoder (`TextEncoder`) - T5 small
    '''
    def __init__(self, unet, scheduler):
        super().__init__()
        self.register_modules(
            unet=unet,
            scheduler=scheduler,
        )

    @torch.no_grad()
    def __call__(
        self,
        texts: List[str],
        text_encoder: str = "google-t5/t5-small",
        batch_size: int = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        num_inference_steps: int = 1000,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
    ) -> Union[ImagePipelineOutput, Tuple]:
        '''
        Denoise from Gaussian noise to glyph images, conditioned on T5 embeddings of `texts`.
        '''
        # Get text embeddings
        # Encode the text
        # text_embeddings = []
        # for text in texts:
        #     embedding = t5.t5_encode_text(text, name=text_encoder)
        #     text_embeddings.append(torch.squeeze(embedding))
        # text_embeddings = pad_sequence(text_embeddings, True)

        batch_size = len(texts)

        text_embeddings, masks = t5.t5_encode_text(texts, name=text_encoder, return_attn_mask=True)

        # Sample gaussian noise to begin loop
        if isinstance(self.unet.config.sample_size, int):
            image_shape = (
                batch_size,
                self.unet.config.in_channels,
                self.unet.config.sample_size,
                self.unet.config.sample_size,
            )
        else:
            image_shape = (batch_size, self.unet.config.in_channels, *self.unet.config.sample_size)

        # if self.device.type == "mps": # MPS is apple silicon
        #     # randn does not work reproducibly on mps
        #     image = randn_tensor(image_shape, generator=generator)
        #     image = image.to(self.device)
        # else:
        image = randn_tensor(image_shape, generator=generator, device=self.device)

        # set step values
        self.scheduler.set_timesteps(num_inference_steps)

        for t in self.progress_bar(self.scheduler.timesteps):
            # 1. predict noise model_output
            model_output = self.unet(
                image,
                t,
                encoder_hidden_states=text_embeddings,  # Add text encoding input
                encoder_attention_mask=masks,  # Add attention mask
                return_dict=False
            )[0]  # `sample` is an attribute of the BaseOutput class, of type torch.FloatTensor

            # 2. compute previous image: x_t -> x_t-1
            image = self.scheduler.step(model_output, t, image, generator=generator, return_dict=False)[0]

        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()
        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)


def make_grid(images, rows, cols):
    w, h = images[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))
    for i, image in enumerate(images):
        grid.paste(image, box=(i % cols * w, i // cols * h))
    return grid


def evaluate(config, epoch, texts, pipeline):
    images = pipeline(
        texts,
        batch_size=config.eval_batch_size,
        generator=torch.Generator(device='cpu').manual_seed(config.seed),  # Generator must be on CPU for sampling during training
    ).images

    # Make a grid out of the images
    image_grid = make_grid(images, rows=4, cols=4)

    # Save the images
    test_dir = os.path.join(config.output_dir, "samples")
    os.makedirs(test_dir, exist_ok=True)
    image_grid.save(f"{test_dir}/{epoch:04d}.png")
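
As context for how these pieces fit together, a minimal sketch of wiring the Collator into a training DataLoader; the dataset file, column names, image size, and batch size are assumptions for illustration, not part of this commit:

from datasets import load_dataset
from torch.utils.data import DataLoader
from glyffuser_utils import Collator

# Hypothetical metadata file with a caption column and a relative image-path column
dataset = load_dataset("csv", data_files="data/metadata.csv", split="train")

collate_fn = Collator(
    image_size=128,             # assumed training resolution
    text_label="text",          # assumed caption column
    image_label="image",        # assumed image-path column (relative to data/)
    name="google-t5/t5-small",  # matches the pipeline's default text encoder
    channels="L",               # assumed single-channel glyph images
)

train_loader = DataLoader(dataset, batch_size=16, shuffle=True, collate_fn=collate_fn)
# Each batch collates to (images, text_embeddings, attention_masks)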
t5.py
ADDED
@@ -0,0 +1,119 @@
import torch
import transformers
from typing import List
from transformers import T5Tokenizer, T5EncoderModel, T5Config
from einops import rearrange

transformers.logging.set_verbosity_error()

def exists(val):
    return val is not None

def default(val, d):
    if exists(val):
        return val
    return d() if callable(d) else d

# config

MAX_LENGTH = 256

DEFAULT_T5_NAME = 'google/t5-v1_1-base'

T5_CONFIGS = {}

# singleton globals

def get_tokenizer(name):
    tokenizer = T5Tokenizer.from_pretrained(name, model_max_length=MAX_LENGTH)
    return tokenizer

def get_model(name):
    model = T5EncoderModel.from_pretrained(name)
    return model

def get_model_and_tokenizer(name):
    global T5_CONFIGS

    if name not in T5_CONFIGS:
        T5_CONFIGS[name] = dict()
    if "model" not in T5_CONFIGS[name]:
        T5_CONFIGS[name]["model"] = get_model(name)
    if "tokenizer" not in T5_CONFIGS[name]:
        T5_CONFIGS[name]["tokenizer"] = get_tokenizer(name)

    return T5_CONFIGS[name]['model'], T5_CONFIGS[name]['tokenizer']

def get_encoded_dim(name):
    if name not in T5_CONFIGS:
        # avoids loading the model if we only want to get the dim
        config = T5Config.from_pretrained(name)
        T5_CONFIGS[name] = dict(config=config)
    elif "config" in T5_CONFIGS[name]:
        config = T5_CONFIGS[name]["config"]
    elif "model" in T5_CONFIGS[name]:
        config = T5_CONFIGS[name]["model"].config
    else:
        assert False
    return config.d_model

# encoding text

def t5_tokenize(
    texts: List[str],
    name = DEFAULT_T5_NAME
):
    t5, tokenizer = get_model_and_tokenizer(name)

    if torch.cuda.is_available():
        t5 = t5.cuda()

    device = next(t5.parameters()).device

    encoded = tokenizer.batch_encode_plus(
        texts,
        return_tensors = "pt",
        padding = 'longest',
        max_length = MAX_LENGTH,
        truncation = True
    )

    input_ids = encoded.input_ids.to(device)
    attn_mask = encoded.attention_mask.to(device)
    return input_ids, attn_mask

def t5_encode_tokenized_text(
    token_ids,
    attn_mask = None,
    pad_id = None,
    name = DEFAULT_T5_NAME
):
    assert exists(attn_mask) or exists(pad_id)
    t5, _ = get_model_and_tokenizer(name)

    attn_mask = default(attn_mask, lambda: (token_ids != pad_id).long())

    t5.eval()

    with torch.no_grad():
        output = t5(input_ids = token_ids, attention_mask = attn_mask)
        encoded_text = output.last_hidden_state.detach()

    attn_mask = attn_mask.bool()

    # force all embeddings at padding positions to be zero
    encoded_text = encoded_text.masked_fill(~rearrange(attn_mask, '... -> ... 1'), 0.)
    return encoded_text

def t5_encode_text(
    texts: List[str],
    name = DEFAULT_T5_NAME,
    return_attn_mask = False
):
    token_ids, attn_mask = t5_tokenize(texts, name = name)
    encoded_text = t5_encode_tokenized_text(token_ids, attn_mask = attn_mask, name = name)

    if return_attn_mask:
        attn_mask = attn_mask.bool()
        return encoded_text, attn_mask

    return encoded_text
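
For reference, a minimal sketch of using the encoding helpers on their own; the prompt strings are placeholders, and the first call downloads the chosen T5 checkpoint:

import t5

prompts = ["mountain; hill", "a small bird"]  # placeholder definitions
embeddings, mask = t5.t5_encode_text(prompts, name="google-t5/t5-small", return_attn_mask=True)

print(embeddings.shape)  # (batch, seq_len, d_model); d_model == t5.get_encoded_dim("google-t5/t5-small")
print(mask.shape)        # (batch, seq_len); True at non-padding positions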