# steerers/collect_features/collect_i2p_flux.py
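"""Collect T5 text-encoder activations from FLUX.1-dev for the I2P prompt set.

For every prompt in datasets/i2p.csv, this script runs a single inference step
of FLUX.1-dev through HookedFluxPipeline, caches the output of T5 encoder
block 22, and writes the activations plus generation metadata to disk in
fixed-size chunks.

Typical invocation (arguments are exposed on the command line via python-fire):
    python collect_i2p_flux.py --save_path I2P_FLUX/T5 --start_at 0 --finish_at 90000 --chunk_size 1000
"""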
import datetime
import json
import os
import sys

import fire
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm

# Make the repo root importable so the local SDLens package can be found
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from SDLens.hooked_flux_pipeline import HookedFluxPipeline

def to_kwargs(kwargs_to_save):
    """Turn the JSON-serializable generation arguments into pipeline kwargs,
    replacing the integer seed with a seeded torch.Generator."""
    kwargs = kwargs_to_save.copy()
    seed = kwargs.pop('seed')
    kwargs['generator'] = torch.Generator(device="cpu").manual_seed(seed)
    return kwargs
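
# Example: to_kwargs({'prompt': 'a cat', 'seed': 0}) returns the same dict with
# 'seed' replaced by a 'generator' entry: a CPU torch.Generator seeded with 0.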


def main(save_path='I2P_FLUX/T5', start_at=0, finish_at=90000, chunk_size=1000):
    # Cache the output of block 22 of the T5 text encoder (text_encoder_2)
    blocks_to_save = ['text_encoder_2.encoder.block.22']
    block = blocks_to_save[0]  # filename stem for saved chunk files

    # Load and concatenate CSV data
    csv_filepaths = [
        "datasets/i2p.csv"
    ]
    data_frames = [pd.read_csv(filepath) for filepath in csv_filepaths]
    data = pd.concat(data_frames, ignore_index=True)
    prompts = data['prompt'].to_numpy()

    # Different CSV variants name the seed/guidance columns differently;
    # fall back to fixed defaults when neither column is present.
    try:
        seeds = data['evaluation_seed'].to_numpy()
    except KeyError:
        try:
            seeds = data['sd_seed'].to_numpy()
        except KeyError:
            seeds = [42 for _ in range(len(prompts))]
    try:
        guidance_scales = data['evaluation_guidance'].to_numpy()
    except KeyError:
        try:
            guidance_scales = data['sd_guidance_scale'].to_numpy()
        except KeyError:
            guidance_scales = [7.5 for _ in range(len(prompts))]

    # Initialize the hooked FLUX pipeline
    pipe = HookedFluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
    pipe.to('cuda')
    pipe.set_progress_bar_config(disable=True)

    # Create a timestamped save directory (filesystem-safe, unlike str(datetime))
    ct = datetime.datetime.now()
    save_path = os.path.join(save_path, ct.strftime('%Y-%m-%d_%H-%M-%S'))
    os.makedirs(save_path, exist_ok=True)

    data_tensors = []
    metadata = []
    chunk_idx = 0
    chunk_start_idx = start_at

    # Process prompts in [start_at, finish_at)
    for num_document in tqdm(range(len(prompts)), desc="Processing Prompts", unit="prompt"):
        if num_document < start_at:
            continue
        if num_document >= finish_at:
            break
        kwargs_to_save = {
            'prompt': prompts[num_document],
            'positions_to_cache': blocks_to_save,
            'save_input': True,
            'save_output': True,
            'num_inference_steps': 1,
            'guidance_scale': guidance_scales[num_document],
            'seed': int(seeds[num_document]),
            'output_type': 'pil',
        }
        kwargs = to_kwargs(kwargs_to_save)
        output, cache = pipe.run_with_cache(**kwargs)
        combined_output = cache['output'][blocks_to_save[0]].squeeze(1)  # (512, 4096)
        data_tensors.append(combined_output.cpu())  # Store output tensor on CPU

        # Store metadata alongside the activations
        metadata.append({
            "sample_id": num_document,
            "gen_args": kwargs_to_save
        })

        # Save a chunk once it reaches the specified size
        if len(data_tensors) >= chunk_size:
            chunk_end_idx = chunk_start_idx + len(data_tensors) - 1
            save_chunk(data_tensors, metadata, save_path, chunk_start_idx, chunk_end_idx, chunk_idx, block)
            chunk_start_idx += len(data_tensors)
            data_tensors = []
            metadata = []
            chunk_idx += 1

    # Flush any remaining samples as a final (possibly smaller) chunk
    if data_tensors:
        chunk_end_idx = chunk_start_idx + len(data_tensors) - 1
        save_chunk(data_tensors, metadata, save_path, chunk_start_idx, chunk_end_idx, chunk_idx, block)
    print(f"Data saved in chunks to {save_path}")


def save_chunk(data_tensors, metadata, save_path, start_idx, end_idx, chunk_idx, block):
    """Save a chunk of tensors and metadata with index tracking."""
    chunk_path = os.path.join(save_path, f'{block}_{start_idx:06d}_{end_idx:06d}.pt')
    metadata_path = os.path.join(save_path, f'metadata_{start_idx:06d}_{end_idx:06d}.json')
    # Concatenate tensors along the batch dimension and save
    torch.save(torch.cat(data_tensors), chunk_path)
    # Save metadata as JSON, converting numpy/torch scalars to native Python types
    with open(metadata_path, 'w') as f:
        json.dump(metadata, f, indent=4,
                  default=lambda o: o.item() if isinstance(o, (np.generic, torch.Tensor)) else str(o))
    print(f"Saved chunk {chunk_idx}: {chunk_path}")


if __name__ == '__main__':
    fire.Fire(main)