yue-here committed
Commit 5edc0a2
1 Parent(s): d5851a1

first commit

Files changed (3)
  1. app.py +17 -0
  2. glyffuser_utils.py +174 -0
  3. t5.py +119 -0
app.py ADDED
@@ -0,0 +1,17 @@
+ import gradio as gr
+ from glyffuser_utils import GlyffuserPipeline
+
+ pipeline = GlyffuserPipeline.from_pretrained("yuewu/glyffuser")
+
+ def infer(text):
+     generated_images = pipeline(
+         [text],  # the pipeline expects a list of prompts
+         batch_size=1,  # overridden inside the pipeline by len(texts)
+         # generator=torch.Generator(device='cuda').manual_seed(config.seed),  # generator can be on GPU here
+         num_inference_steps=50
+     ).images
+
+     return generated_images[0]
+
+ demo = gr.Interface(fn=infer, inputs="text", outputs="image")
+ demo.launch()
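
For reference, the same pipeline can be exercised without Gradio; a minimal sketch (the prompt string is just an example):

    from glyffuser_utils import GlyffuserPipeline

    pipeline = GlyffuserPipeline.from_pretrained("yuewu/glyffuser")
    image = pipeline(["rain falling on a window"], num_inference_steps=50).images[0]
    image.save("sample.png")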
glyffuser_utils.py ADDED
@@ -0,0 +1,174 @@
+ import os  # used by evaluate() below to save sample grids
+ import torch
+ from torch import nn
+ from torch.utils.data import Dataset, DataLoader
+ from torchvision import transforms as T
+ import t5
+ from torch.nn.utils.rnn import pad_sequence
+
+ from PIL import Image
+
+ from datasets import load_dataset
+
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+ from typing import List, Optional, Tuple, Union
+ from diffusers.utils.torch_utils import randn_tensor
+
+
+ # Collator adjusted for a local dataset: loads each image from disk and
+ # encodes its caption with T5 on the fly.
+ class Collator:
+     def __init__(self, image_size, text_label, image_label, name, channels):
+         self.text_label = text_label
+         self.image_label = image_label
+         self.name = name
+         self.channels = channels
+         self.transform = T.Compose([
+             T.Resize((image_size, image_size)),
+             T.ToTensor(),
+         ])
+
+     def __call__(self, batch):
+         texts = []
+         masks = []
+         images = []
+         for item in batch:
+             try:
+                 # Load the image from a local file (the dataset stores relative paths)
+                 image_path = 'data/' + item[self.image_label]
+                 with Image.open(image_path) as img:
+                     image = self.transform(img.convert(self.channels))
+             except Exception as e:
+                 print(f"Failed to process image {image_path}: {e}")
+                 continue
+
+             # Encode the caption
+             text, mask = t5.t5_encode_text(
+                 [item[self.text_label]],
+                 name=self.name,
+                 return_attn_mask=True
+             )
+             texts.append(torch.squeeze(text))
+             masks.append(torch.squeeze(mask))
+             images.append(image)
+
+         if len(texts) == 0:
+             return None
+
+         # Pad embeddings and masks to the longest sequence in the batch
+         texts = pad_sequence(texts, True)
+         masks = pad_sequence(masks, True)
+
+         newbatch = []
+         for i in range(len(texts)):
+             newbatch.append((images[i], texts[i], masks[i]))
+
+         return torch.utils.data.dataloader.default_collate(newbatch)
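
load_dataset and DataLoader are imported above but never used in this file; presumably the training script wires them to the collator. A minimal sketch under that assumption (the data file, column names, and sizes below are hypothetical):

    # Hypothetical wiring; the real dataset location and column names are not in this commit.
    dataset = load_dataset("json", data_files="data/metadata.json", split="train")
    collate = Collator(image_size=128, text_label="text", image_label="image",
                       name="google-t5/t5-small", channels="L")
    loader = DataLoader(dataset, batch_size=16, shuffle=True, collate_fn=collate)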
+
+
+ class GlyffuserPipeline(DiffusionPipeline):
+     r'''
+     Pipeline for text-to-image generation from the glyffuser model.
+
+     Parameters:
+         unet (`UNet2DConditionModel`): the denoising UNet.
+         scheduler (`SchedulerMixin`): the noise scheduler used for sampling.
+
+     Text conditioning is computed at call time by the t5 module
+     (T5 small by default) rather than registered as a pipeline component.
+     '''
+     def __init__(self, unet, scheduler):
+         super().__init__()
+         self.register_modules(
+             unet=unet,
+             scheduler=scheduler,
+         )
+
+     @torch.no_grad()
+     def __call__(
+         self,
+         texts: List[str],
+         text_encoder: str = "google-t5/t5-small",
+         batch_size: int = 1,
+         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+         num_inference_steps: int = 1000,
+         output_type: Optional[str] = "pil",
+         return_dict: bool = True,
+     ) -> Union[ImagePipelineOutput, Tuple]:
+         '''
+         Generate one image per prompt in `texts`. Returns an
+         ImagePipelineOutput (or a plain tuple if return_dict is False).
+         Note that `batch_size` is overridden by len(texts).
+         '''
+         # One image per prompt
+         batch_size = len(texts)
+
+         # Encode all prompts in a single batch, with attention masks
+         text_embeddings, masks = t5.t5_encode_text(texts, name=text_encoder, return_attn_mask=True)
+
+         # Sample Gaussian noise to begin the loop
+         if isinstance(self.unet.config.sample_size, int):
+             image_shape = (
+                 batch_size,
+                 self.unet.config.in_channels,
+                 self.unet.config.sample_size,
+                 self.unet.config.sample_size,
+             )
+         else:
+             image_shape = (batch_size, self.unet.config.in_channels, *self.unet.config.sample_size)
+
+         # Note: randn is not reproducible on MPS (Apple silicon); if that matters,
+         # sample on CPU with randn_tensor(image_shape, generator=generator) and
+         # move the result to self.device.
+         image = randn_tensor(image_shape, generator=generator, device=self.device)
+
+         # Set step values
+         self.scheduler.set_timesteps(num_inference_steps)
+
+         for t in self.progress_bar(self.scheduler.timesteps):
+             # 1. Predict the noise residual
+             model_output = self.unet(
+                 image,
+                 t,
+                 encoder_hidden_states=text_embeddings,  # text conditioning
+                 encoder_attention_mask=masks,  # mask out padded tokens
+                 return_dict=False
+             )[0]  # with return_dict=False the UNet returns a tuple; [0] is the sample tensor
+
+             # 2. Compute the previous image: x_t -> x_t-1
+             image = self.scheduler.step(model_output, t, image, generator=generator, return_dict=False)[0]
+
+         # Rescale from [-1, 1] to [0, 1] and convert to HWC numpy arrays
+         image = (image / 2 + 0.5).clamp(0, 1)
+         image = image.cpu().permute(0, 2, 3, 1).numpy()
+         if output_type == "pil":
+             image = self.numpy_to_pil(image)
+
+         if not return_dict:
+             return (image,)
+
+         return ImagePipelineOutput(images=image)
+
+
+ def make_grid(images, rows, cols):
+     # Paste the PIL images into a rows x cols contact sheet
+     w, h = images[0].size
+     grid = Image.new('RGB', size=(cols * w, rows * h))
+     for i, image in enumerate(images):
+         grid.paste(image, box=(i % cols * w, i // cols * h))
+     return grid
+
+
+ def evaluate(config, epoch, texts, pipeline):
+     images = pipeline(
+         texts,
+         batch_size=config.eval_batch_size,
+         generator=torch.Generator(device='cpu').manual_seed(config.seed),  # generator must be on CPU for sampling during training
+     ).images
+
+     # Make a grid out of the images
+     image_grid = make_grid(images, rows=4, cols=4)
+
+     # Save the grid
+     test_dir = os.path.join(config.output_dir, "samples")
+     os.makedirs(test_dir, exist_ok=True)
+     image_grid.save(f"{test_dir}/{epoch:04d}.png")
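
evaluate assumes a config object exposing eval_batch_size, seed, and output_dir, none of which are defined in this commit. A minimal sketch of such a config (field values are hypothetical):

    from dataclasses import dataclass

    @dataclass
    class TrainingConfig:  # hypothetical; only the fields evaluate() reads are shown
        eval_batch_size: int = 16   # 16 prompts fill the hard-coded 4x4 grid exactly
        seed: int = 0
        output_dir: str = "glyffuser-out"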
t5.py ADDED
@@ -0,0 +1,119 @@
+ import torch
+ import transformers
+ from typing import List
+ from transformers import T5Tokenizer, T5EncoderModel, T5Config
+ from einops import rearrange
+
+ transformers.logging.set_verbosity_error()
+
+ def exists(val):
+     return val is not None
+
+ def default(val, d):
+     if exists(val):
+         return val
+     return d() if callable(d) else d
+
+ # config
+
+ MAX_LENGTH = 256
+
+ DEFAULT_T5_NAME = 'google/t5-v1_1-base'
+
+ T5_CONFIGS = {}
+
+ # singleton globals
+
+ def get_tokenizer(name):
+     tokenizer = T5Tokenizer.from_pretrained(name, model_max_length=MAX_LENGTH)
+     return tokenizer
+
+ def get_model(name):
+     model = T5EncoderModel.from_pretrained(name)
+     return model
+
+ def get_model_and_tokenizer(name):
+     global T5_CONFIGS
+
+     if name not in T5_CONFIGS:
+         T5_CONFIGS[name] = dict()
+     if "model" not in T5_CONFIGS[name]:
+         T5_CONFIGS[name]["model"] = get_model(name)
+     if "tokenizer" not in T5_CONFIGS[name]:
+         T5_CONFIGS[name]["tokenizer"] = get_tokenizer(name)
+
+     return T5_CONFIGS[name]['model'], T5_CONFIGS[name]['tokenizer']
+
+ def get_encoded_dim(name):
+     if name not in T5_CONFIGS:
+         # avoids loading the model if we only want the embedding dimension
+         config = T5Config.from_pretrained(name)
+         T5_CONFIGS[name] = dict(config=config)
+     elif "config" in T5_CONFIGS[name]:
+         config = T5_CONFIGS[name]["config"]
+     elif "model" in T5_CONFIGS[name]:
+         config = T5_CONFIGS[name]["model"].config
+     else:
+         assert False
+     return config.d_model
+
+ # encoding text
+
+ def t5_tokenize(
+     texts: List[str],
+     name = DEFAULT_T5_NAME
+ ):
+     t5, tokenizer = get_model_and_tokenizer(name)
+
+     if torch.cuda.is_available():
+         t5 = t5.cuda()
+
+     device = next(t5.parameters()).device
+
+     encoded = tokenizer.batch_encode_plus(
+         texts,
+         return_tensors = "pt",
+         padding = 'longest',
+         max_length = MAX_LENGTH,
+         truncation = True
+     )
+
+     input_ids = encoded.input_ids.to(device)
+     attn_mask = encoded.attention_mask.to(device)
+     return input_ids, attn_mask
+
+ def t5_encode_tokenized_text(
+     token_ids,
+     attn_mask = None,
+     pad_id = None,
+     name = DEFAULT_T5_NAME
+ ):
+     assert exists(attn_mask) or exists(pad_id)
+     t5, _ = get_model_and_tokenizer(name)
+
+     attn_mask = default(attn_mask, lambda: (token_ids != pad_id).long())
+
+     t5.eval()
+
+     with torch.no_grad():
+         output = t5(input_ids = token_ids, attention_mask = attn_mask)
+         encoded_text = output.last_hidden_state.detach()
+
+     attn_mask = attn_mask.bool()
+
+     # Zero out the embeddings at padded positions
+     encoded_text = encoded_text.masked_fill(~rearrange(attn_mask, '... -> ... 1'), 0.)
+     return encoded_text
+
+ def t5_encode_text(
+     texts: List[str],
+     name = DEFAULT_T5_NAME,
+     return_attn_mask = False
+ ):
+     token_ids, attn_mask = t5_tokenize(texts, name = name)
+     encoded_text = t5_encode_tokenized_text(token_ids, attn_mask = attn_mask, name = name)
+
+     if return_attn_mask:
+         attn_mask = attn_mask.bool()
+         return encoded_text, attn_mask
+
+     return encoded_text
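
A quick sanity check of the encoder (this downloads the model on first use; the shapes assume the default 'google/t5-v1_1-base', whose d_model is 768):

    emb, mask = t5_encode_text(["a glyph", "a longer descriptive caption"], return_attn_mask=True)
    print(emb.shape)   # torch.Size([2, seq_len, 768]); padded positions are zeroed
    print(mask.shape)  # torch.Size([2, seq_len]); boolean attention mask
    print(get_encoded_dim(DEFAULT_T5_NAME))  # 768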