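"""Cache SDXL text-encoder activations for I2P prompts.

Runs each prompt from the I2P CSV through HookedStableDiffusionXLPipeline for a
single inference step, captures the outputs of two text-encoder layers, and
saves the concatenated activations to disk in fixed-size chunks alongside JSON
metadata describing the generation arguments.
"""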
import datetime
import json
import os
import sys

import fire
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm

sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from SDLens.hooked_sd_pipeline import HookedStableDiffusionXLPipeline


def to_kwargs(kwargs_to_save):
    """Convert saved kwargs into pipeline kwargs, replacing the integer seed
    with a seeded torch.Generator."""
    kwargs = kwargs_to_save.copy()
    seed = kwargs.pop('seed')
    kwargs['generator'] = torch.Generator(device="cpu").manual_seed(seed)
    return kwargs


def main(save_path='I2P_SDXL', start_at=0, finish_at=90000, chunk_size=1000):
    blocks_to_save = [
        'text_encoder.text_model.encoder.layers.10',
        'text_encoder_2.text_model.encoder.layers.28',
    ]
    block = 'text_encoder.text_model.encoder.layers.10.28'

    # Load and concatenate CSV data
    csv_filepaths = [
        "datasets/i2p.csv"
    ]
    data_frames = [pd.read_csv(filepath) for filepath in csv_filepaths]
    data = pd.concat(data_frames, ignore_index=True)
    prompts = data['prompt'].to_numpy()

    # Fall back through the seed/guidance column names used by different
    # dataset variants, with fixed defaults when neither column exists.
    try:
        seeds = data['evaluation_seed'].to_numpy()
    except KeyError:
        try:
            seeds = data['sd_seed'].to_numpy()
        except KeyError:
            seeds = [42 for _ in range(len(prompts))]
    try:
        guidance_scales = data['evaluation_guidance'].to_numpy()
    except KeyError:
        try:
            guidance_scales = data['sd_guidance_scale'].to_numpy()
        except KeyError:
            guidance_scales = [7.5 for _ in range(len(prompts))]
    # Initialize pipeline
    pipe = HookedStableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0"
    )
    pipe.to('cuda')
    pipe.set_progress_bar_config(disable=True)

    # Create a timestamped save directory
    ct = datetime.datetime.now()
    save_path = os.path.join(save_path, str(ct))
    os.makedirs(save_path, exist_ok=True)

    data_tensors = []
    metadata = []
    chunk_idx = 0
    chunk_start_idx = start_at
    # Process prompts in [start_at, finish_at)
    for num_document in tqdm(range(len(prompts)), desc="Processing Prompts", unit="prompt"):
        if num_document < start_at:
            continue
        if num_document >= finish_at:
            break

        kwargs_to_save = {
            'prompt': prompts[num_document],
            'positions_to_cache': blocks_to_save,
            'save_input': True,
            'save_output': True,
            'num_inference_steps': 1,
            'guidance_scale': guidance_scales[num_document],
            'seed': int(seeds[num_document]),
            'output_type': 'pil',
        }
        kwargs = to_kwargs(kwargs_to_save)
        output, cache = pipe.run_with_cache(**kwargs)

        # Concatenate the two cached layer outputs along the feature dimension
        combined_output = torch.cat(
            [cache['output'][blocks_to_save[0]], cache['output'][blocks_to_save[1]]],
            dim=-1,
        ).squeeze(1)
        data_tensors.append(combined_output.cpu())  # Store output tensor

        # Store metadata
        metadata.append({
            "sample_id": num_document,
            "gen_args": kwargs_to_save
        })

        # Save chunk if it reaches the specified size
        if len(data_tensors) >= chunk_size:
            chunk_end_idx = chunk_start_idx + len(data_tensors) - 1
            save_chunk(data_tensors, metadata, save_path, chunk_start_idx, chunk_end_idx, chunk_idx, block)
            chunk_start_idx += len(data_tensors)
            data_tensors = []
            metadata = []
            chunk_idx += 1

    # Save any remaining samples as a final partial chunk
    if data_tensors:
        chunk_end_idx = chunk_start_idx + len(data_tensors) - 1
        save_chunk(data_tensors, metadata, save_path, chunk_start_idx, chunk_end_idx, chunk_idx, block)
    print(f"Data saved in chunks to {save_path}")


def save_chunk(data_tensors, metadata, save_path, start_idx, end_idx, chunk_idx, block):
    """Save a chunk of tensors and metadata with index tracking."""
    chunk_path = os.path.join(save_path, f'{block}_{start_idx:06d}_{end_idx:06d}.pt')
    metadata_path = os.path.join(save_path, f'metadata_{start_idx:06d}_{end_idx:06d}.json')

    # Stack tensors and save
    torch.save(torch.cat(data_tensors), chunk_path)

    # Save metadata as JSON, coercing numpy/torch integer scalars to plain ints
    with open(metadata_path, 'w') as f:
        json.dump(metadata, f, indent=4,
                  default=lambda o: int(o) if isinstance(o, (np.integer, torch.Tensor)) else o)
    print(f"Saved chunk {chunk_idx}: {chunk_path}")


if __name__ == '__main__':
    fire.Fire(main)
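
# Example invocation (script name illustrative; arguments map to main() via Fire):
#   python collect_i2p_activations.py --save_path I2P_SDXL --start_at 0 --finish_at 90000 --chunk_size 1000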