# t2i-cmp / run.py
# Zeqiang-Lai
# add all files
# 021120f
# raw
# history blame contribute delete
# No virus
# 1.66 kB
import os
import csv
import torch
from diffusers import AutoPipelineForText2Image
def load_prompts(path):
    """Load the text prompts listed in a supported prompt file.

    Only files named ``ViLG-300.csv`` are supported. The file is parsed as a
    comma-separated CSV with a header row, and the values of its ``Prompt``
    column are returned.

    Args:
        path: Path to the prompt file; the base name selects the parser.

    Returns:
        A list of unique prompts in the order they first appear in the file.

    Raises:
        NotImplementedError: If the file name is not a supported prompt set.
        KeyError: If the CSV has no ``Prompt`` column.
    """
    if os.path.basename(path) != 'ViLG-300.csv':
        raise NotImplementedError(f'Unsupported prompt file: {path}')

    # 'utf-8-sig' strips a leading byte-order mark when present, so the first
    # header is read as 'Prompt' whether or not the file starts with a BOM.
    with open(path, 'r', encoding='utf-8-sig', newline='') as csv_file:
        csv_reader = csv.DictReader(csv_file, delimiter=',')
        # dict.fromkeys drops repeated prompts while keeping first-seen order.
        return list(dict.fromkeys(row['Prompt'] for row in csv_reader))
def main(
    model_id="runwayml/stable-diffusion-v1-5",
    prompt_path="assets/ViLG-300.csv",
    save_path=None,
    dtype='fp16',
    variant=None,
):
    """Generate one image per prompt with a text-to-image pipeline.

    Args:
        model_id: Model identifier passed to
            ``AutoPipelineForText2Image.from_pretrained``.
        prompt_path: Prompt file handed to ``load_prompts``.
        save_path: Output directory. When ``None``, defaults to
            ``saved/<model_id>`` with every ``/`` in the id replaced by ``_``.
        dtype: ``'fp32'`` selects ``torch.float32``; any other value selects
            ``torch.float16``.
        variant: Forwarded unchanged to ``from_pretrained``.
    """
    if save_path is None:
        folder_name = model_id.replace('/', '_')
        save_path = os.path.join('saved', folder_name)
    os.makedirs(save_path, exist_ok=True)

    prompts = load_prompts(prompt_path)

    if dtype == 'fp32':
        torch_dtype = torch.float32
    else:
        torch_dtype = torch.float16

    pipeline = AutoPipelineForText2Image.from_pretrained(
        model_id,
        variant=variant,
        torch_dtype=torch_dtype,
    )
    pipeline.to(device='cuda')
    # Clear the pipeline's safety checker attribute before generating.
    pipeline.safety_checker = None

    for i, prompt in enumerate(prompts):
        print(f'{i}|{len(prompts)}: {prompt}')
        result = pipeline(prompt)
        image = result.images[0]
        # Images are named by the prompt's position in the prompt list.
        target = os.path.join(save_path, f'{i}.jpg')
        image.save(target)
if __name__ == '__main__':
    # Imported only when run as a script; python-fire exposes main's keyword
    # arguments as command-line flags.
    from fire import Fire

    Fire(main)