pix2pix-zero-demo-CPU / utils /generate_synthetic.py
jchwenger's picture
Upload 351 files (#2)
d9272c6 verified
import os, sys, time, re
import torch
from PIL import Image
import hashlib
from tqdm import tqdm
import openai
from utils.direction_utils import *
# Make the pix2pix-zero helper modules (edit_pipeline, ddim_inv, scheduler,
# imported just below) importable as top-level modules. The path is relative,
# so it is resolved against the current working directory — NOTE(review): this
# assumes the process is launched from the repository root; confirm.
p = "submodules/pix2pix-zero/src/utils"
if p not in sys.path:
    sys.path.append(p)
from diffusers import DDIMScheduler
from edit_pipeline import EditingPipeline
from ddim_inv import DDIMInversion
from scheduler import DDIMInverseScheduler
from lavis.models import load_model_and_preprocess
from transformers import T5Tokenizer, AutoTokenizer, T5ForConditionalGeneration, BloomForCausalLM
# Device every model and tensor in this module is placed on: the GPU when torch
# can see one, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
def load_sentence_embeddings(l_sentences, tokenizer, text_encoder, device=DEVICE):
    """Return the mean text-encoder embedding of a collection of sentences.

    Each sentence is tokenized on its own (padded and truncated to
    ``tokenizer.model_max_length``) and encoded with ``text_encoder``. The
    per-sentence outputs are concatenated along dim 0 and averaged over that
    dimension.

    Args:
        l_sentences: Iterable of sentences to encode.
        tokenizer: Tokenizer callable returning an object with ``input_ids``
            (PyTorch tensors are requested via ``return_tensors="pt"``).
        text_encoder: Model whose first output is the sentence embedding.
        device: Device the token ids are moved to before encoding.

    Returns:
        The averaged embedding with a new leading dimension of size 1.
        Assuming each encoder output has a leading batch dimension of 1, this
        has the same shape as a single sentence's embedding (TODO confirm).

    Raises:
        ValueError: if ``l_sentences`` is empty, because there is nothing to average.
    """
    l_embeddings = []
    with torch.no_grad():
        for sent in tqdm(l_sentences):
            text_inputs = tokenizer(
                sent,
                padding="max_length",
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            prompt_embeds = text_encoder(text_inputs.input_ids.to(device), attention_mask=None)[0]
            l_embeddings.append(prompt_embeds)
    if not l_embeddings:
        raise ValueError("load_sentence_embeddings needs at least one sentence to average")
    # torch.cat is available on every torch version; torch.concatenate is only a
    # newer alias of it.
    return torch.cat(l_embeddings, dim=0).mean(dim=0).unsqueeze(0)
def launch_generate_sample(prompt, seed, negative_scale, num_ddim):
    """Sample a synthetic image with Stable Diffusion v1.4 and cache its start noise.

    The initial latent ``z`` (shape ``(1, 4, 64, 64)``) is drawn from the seeded
    global RNG and saved to ``tmp/<sha256 of z>_ddim_<num_ddim>_inv.pt``, so a
    later edit can start from exactly the same noise map. It is then denoised
    with ``num_ddim`` DDIM steps.

    Args:
        prompt: Text prompt for the sample.
        seed: Random seed (anything ``int()`` accepts).
        negative_scale: Passed to the pipeline as ``guidance_scale``; the empty
            string is used as the negative prompt.
        num_ddim: Number of DDIM inference steps.

    Returns:
        ``(image, z_path)``: the first element of the pipeline output and the
        path of the saved latent.
    """
    os.makedirs("tmp", exist_ok=True)
    edit_pipe = EditingPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float32).to(DEVICE)
    try:
        edit_pipe.scheduler = DDIMScheduler.from_config(edit_pipe.scheduler.config)
        # torch.manual_seed seeds the CPU generator and every CUDA device, so one
        # call covers both possible values of DEVICE.
        torch.manual_seed(int(seed))
        z = torch.randn((1, 4, 64, 64), device=DEVICE)
        # The latent's content hash names the cache file so identical noise maps share it.
        z_hashname = hashlib.sha256(z.cpu().numpy().tobytes()).hexdigest()
        z_inv_fname = f"tmp/{z_hashname}_ddim_{num_ddim}_inv.pt"
        torch.save(z, z_inv_fname)
        rec_pil = edit_pipe(prompt,
            num_inference_steps=num_ddim, x_in=z,
            only_sample=True,  # this flag will only generate the sampled image, not the edited image
            guidance_scale=negative_scale,
            negative_prompt="",  # use the empty string for the negative prompt
        )
    finally:
        # Release the pipeline (and cached GPU memory) even if sampling failed.
        del edit_pipe
        torch.cuda.empty_cache()
    return rec_pil[0], z_inv_fname
def clean_l_sentences(ls):
    """Strip list numbering and punctuation noise from generated captions.

    For every sentence, in order:
    1. Remove every decimal digit (any Unicode digit, as matched by the regex class).
    2. Delete every ".", "-" and ")" character.
    3. Strip surrounding whitespace.

    For example, ``"1. A cat."`` becomes ``"A cat"``. The output list has the same
    length and order as ``ls``.
    """
    # One C-level pass deletes all three punctuation characters.
    drop_punct = str.maketrans("", "", ".-)")
    # Raw string: "\d" in a normal literal is an invalid escape sequence
    # (SyntaxWarning on Python 3.12+).
    return [re.sub(r"\d", "", x).translate(drop_punct).strip() for x in ls]
def gpt3_compute_word2sentences(task_type, word, num=100):
    """Collect at least ``num`` captions about ``word`` from the OpenAI Completion API.

    The legacy ``openai.Completion`` endpoint (model ``text-davinci-002``) is
    called repeatedly with a prompt chosen by ``task_type``. Reply lines of 10
    characters or fewer are ignored. A caption that does not contain every
    space-separated part of ``word`` gets ``", {word}"`` appended. The result is
    passed through ``clean_l_sentences``.

    Args:
        task_type: ``"object"`` or ``"style"``; selects the prompt template.
        word: Concept the captions should describe.
        num: Minimum number of captions to collect. Whole replies are added, so
            the result may be longer.

    Returns:
        List of cleaned caption strings.

    Raises:
        ValueError: if ``task_type`` is neither ``"object"`` nor ``"style"``.
    """
    if task_type == "object":
        template_prompt = f"Provide many captions for images containing {word}."
    elif task_type == "style":
        template_prompt = f"Provide many captions for images that are in the {word} style."
    else:
        raise ValueError(f"task_type must be 'object' or 'style', got {task_type!r}")
    subwords = word.split(" ")
    l_sentences = []
    while True:
        ret = openai.Completion.create(
            model="text-davinci-002",
            prompt=template_prompt,
            max_tokens=1000,
            temperature=1.0)
        for line in ret.choices[0].text.split("\n"):
            line = line.strip()
            # Short lines are not kept as captions.
            if len(line) <= 10:
                continue
            # Make sure every caption mentions each part of the concept.
            if all(subword in line for subword in subwords):
                l_sentences.append(line)
            else:
                l_sentences.append(line + f", {word}")
        time.sleep(0.05)
        print(len(l_sentences))
        if len(l_sentences) >= num:
            break
    return clean_l_sentences(l_sentences)
def flant5xl_compute_word2sentences(word, num=100):
    """Sample at least ``num`` captions about ``word`` from google/flan-t5-xl.

    Batches of 16 captions are sampled (temperature 0.9, ``max_length=128``)
    until at least ``num`` non-empty captions have been collected. A caption
    that does not contain every space-separated part of ``word`` gets
    ``", {word}"`` appended. The list is passed through ``clean_l_sentences``.
    The model and tokenizer are released before returning.

    Args:
        word: Concept the captions should describe.
        num: Minimum number of captions to collect. Whole batches are added, so
            the result may be longer.

    Returns:
        List of cleaned caption strings.
    """
    text_input = f"Provide a caption for images containing a {word}. The captions should be in English and should be no longer than 150 characters."
    tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-xl")
    # NOTE(review): float16 weights on a CPU-only host may be slow or hit
    # unsupported half-precision kernels — confirm for the CPU demo.
    model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-xl", device_map="auto", torch_dtype=torch.float16)
    # device_map="auto" decides where the weights live, so send the prompt there.
    input_ids = tokenizer(text_input, return_tensors="pt").input_ids.to(model.device)
    subwords = word.split(" ")
    l_sentences = []
    try:
        while True:
            outputs = model.generate(input_ids, temperature=0.9, num_return_sequences=16, do_sample=True, max_length=128)
            # T5 is an encoder-decoder model: generate() returns decoder tokens
            # only (the prompt is not echoed back), so the whole sequence is the
            # caption. Slicing off len(prompt) tokens would discard real text.
            for line in tokenizer.batch_decode(outputs, skip_special_tokens=True):
                line = line.strip()
                if not line:
                    continue
                if all(subword in line for subword in subwords):
                    l_sentences.append(line)
                else:
                    l_sentences.append(line + f", {word}")
            print(len(l_sentences))
            if len(l_sentences) >= num:
                break
    finally:
        del model, tokenizer
        torch.cuda.empty_cache()
    return clean_l_sentences(l_sentences)
def bloomz_compute_sentences(word, num=100, max_retries=5):
    """Sample at least ``num`` captions about ``word`` from bigscience/bloomz-7b1.

    BLOOMZ is decoder-only, so ``generate`` returns the prompt followed by the
    continuation. The first ``input_length`` tokens of every sequence are
    sliced off before decoding. A caption that does not contain every
    space-separated part of ``word`` gets ``", {word}"`` appended. The list is
    passed through ``clean_l_sentences``. The model and tokenizer are released
    before returning.

    Args:
        word: Concept the captions should describe.
        num: Minimum number of captions to collect. Batches of 16 are added, so
            the result may be longer.
        max_retries: Number of consecutive failed generate/decode calls
            tolerated before giving up.

    Returns:
        List of cleaned caption strings.

    Raises:
        RuntimeError: if generation fails more than ``max_retries`` times in a
            row. The last error is chained as the cause.
    """
    tokenizer = AutoTokenizer.from_pretrained("bigscience/bloomz-7b1")
    model = BloomForCausalLM.from_pretrained("bigscience/bloomz-7b1", device_map="auto", torch_dtype=torch.float16)
    input_text = f"Provide a caption for images containing a {word}. The captions should be in English and should be no longer than 150 characters. Caption:"
    # device_map="auto" decides where the weights live, so send the prompt there.
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(model.device)
    input_length = input_ids.shape[1]
    t = 0.95
    eta = 1e-5
    # NOTE(review): Transformers counts the prompt in min_length, so for this
    # decoder-only model 15 may already be met by the prompt alone;
    # min_new_tokens may be what was intended — confirm.
    min_length = 15
    subwords = word.split(" ")
    l_sentences = []
    consecutive_failures = 0
    try:
        while True:
            try:
                outputs = model.generate(input_ids, temperature=t, num_return_sequences=16, do_sample=True, max_length=128, min_length=min_length, eta_cutoff=eta)
                output = tokenizer.batch_decode(outputs[:, input_length:], skip_special_tokens=True)
            except Exception as err:
                # Retry failed calls, but bound consecutive failures so a
                # persistent error surfaces instead of spinning forever.
                # `except Exception` also lets KeyboardInterrupt through.
                consecutive_failures += 1
                if consecutive_failures > max_retries:
                    raise RuntimeError(
                        f"BLOOMZ generation failed {consecutive_failures} times in a row"
                    ) from err
                print(f"BLOOMZ generation failed ({err!r}); retrying")
                continue
            consecutive_failures = 0
            for line in output:
                line = line.strip()
                if all(subword in line for subword in subwords):
                    l_sentences.append(line)
                else:
                    l_sentences.append(line + f", {word}")
            print(len(l_sentences))
            if len(l_sentences) >= num:
                break
    finally:
        del model, tokenizer
        torch.cuda.empty_cache()
    return clean_l_sentences(l_sentences)
def _write_sentence_log(prefix, description, sentences):
    """Write ``sentences`` one per line to ``tmp/<prefix>_sentences_<description>.txt``.

    ``description`` is free-form user text. Every character other than word
    characters, "-" and spaces is replaced with "_" (and the result is capped at
    100 characters), so it cannot add path components to the file name.
    """
    os.makedirs("tmp", exist_ok=True)
    safe_desc = re.sub(r"[^\w\- ]", "_", description)[:100]
    with open(f"tmp/{prefix}_sentences_{safe_desc}.txt", "w", encoding="utf-8") as f:
        f.writelines(line + "\n" for line in sentences)


def make_custom_dir(description, sent_type, api_key, org_key, l_custom_sentences):
    """Build the averaged text embedding for a user-defined concept.

    Sentences about ``description`` are produced according to ``sent_type``:

    - ``"fixed-template"``: ``generate_image_prompts_with_templates(description)``.
    - contains ``"GPT3"``: 1000 captions from the OpenAI API, using ``api_key``
      and ``org_key``.
    - contains ``"flan-t5-xl"`` / ``"BLOOMZ-7B"``: 1000 captions from that
      model, also logged to ``tmp/``.
    - ``"custom sentences"``: the non-blank lines of ``l_custom_sentences``.

    The sentences are then encoded with the Stable Diffusion v1.4 text encoder
    and averaged by ``load_sentence_embeddings``.

    Returns:
        The averaged embedding tensor.

    Raises:
        ValueError: if ``sent_type`` is not recognised or no sentences remain.
    """
    if sent_type == "fixed-template":
        l_sentences = generate_image_prompts_with_templates(description)
    elif "GPT3" in sent_type:
        openai.organization = org_key
        openai.api_key = api_key
        # Presumably a credentials / model-access check before the long sampling loop.
        _ = openai.Model.retrieve("text-davinci-002")
        l_sentences = gpt3_compute_word2sentences("object", description, num=1000)
    elif "flan-t5-xl" in sent_type:
        l_sentences = flant5xl_compute_word2sentences(description, num=1000)
        _write_sentence_log("flant5xl", description, l_sentences)
    elif "BLOOMZ-7B" in sent_type:
        l_sentences = bloomz_compute_sentences(description, num=1000)
        _write_sentence_log("bloomz", description, l_sentences)
    elif sent_type == "custom sentences":
        # Blank lines (e.g. a trailing newline) would otherwise be embedded and
        # dilute the average.
        l_sentences = [s for s in l_custom_sentences.split("\n") if s.strip()]
    else:
        raise ValueError(f"Unknown sentence type: {sent_type!r}")
    if not l_sentences:
        raise ValueError(f"No sentences to embed for {description!r} ({sent_type})")
    print(f"length of new sentence is {len(l_sentences)}")
    pipe = EditingPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float32).to(DEVICE)
    try:
        emb = load_sentence_embeddings(l_sentences, pipe.tokenizer, pipe.text_encoder, device=DEVICE)
    finally:
        del pipe
        torch.cuda.empty_cache()
    return emb
def _launch_cache_component(text, max_len=64):
    """Return a filesystem-safe, collision-resistant rendering of ``text`` for cache file names.

    Characters other than word characters, "-" and spaces become "_", so free
    user input cannot add path separators. The result is truncated to
    ``max_len`` characters to avoid over-long names. An 8-hex-digit SHA-256
    prefix of the original text keeps distinct inputs from sharing a file.
    """
    text = str(text)
    safe = re.sub(r"[^\w\- ]", "_", text)[:max_len]
    digest = hashlib.sha256(text.encode("utf-8")).hexdigest()[:8]
    return f"{safe}_{digest}"


def _launch_direction_embedding(choice, custom_desc, sent_type, api_key, org_key, custom_sentences, d_desc2name):
    """Return the embedding for one end (source or destination) of the edit direction.

    A preset ``choice`` is resolved with ``get_emb``. For ``"make your own!"``
    the embedding is built by ``make_custom_dir`` and cached under ``tmp/``.
    """
    if choice != "make your own!":
        return get_emb(d_desc2name[choice])
    name = f"{_launch_cache_component(custom_desc)}_{_launch_cache_component(sent_type)}"
    if sent_type == "custom sentences":
        # The embedding depends on the sentences themselves, not only on the
        # description, so they must be part of the cache key.
        name += "_" + hashlib.sha256(str(custom_sentences).encode("utf-8")).hexdigest()[:16]
    outf_name = f"tmp/template_emb_{name}.pt"
    if os.path.exists(outf_name):
        return torch.load(outf_name, map_location=torch.device('cpu'), weights_only=True)
    emb = make_custom_dir(custom_desc, sent_type, api_key, org_key, custom_sentences)
    torch.save(emb, outf_name)
    return emb


def _launch_resize_longer_side(img, target=512):
    """Resize a PIL image so its longer side is ``target`` pixels, keeping the aspect ratio (LANCZOS)."""
    width, height = img.size
    scale_factor = target / max(width, height)
    new_size = (int(width * scale_factor), int(height * scale_factor))
    return img.resize(new_size, Image.Resampling.LANCZOS)


def _launch_caption_image(img, caption_fname):
    """Return a BLIP (base_coco) caption for ``img``, computed once and cached in ``caption_fname``."""
    if os.path.exists(caption_fname):
        with open(caption_fname, "r", encoding="utf-8") as f:
            return f.read().strip()
    model_blip, vis_processors, _ = load_model_and_preprocess(name="blip_caption", model_type="base_coco", is_eval=True, device=torch.device(DEVICE))
    image_tensor = vis_processors["eval"](img).unsqueeze(0).to(DEVICE)
    prompt_str = model_blip.generate({"image": image_tensor})[0]
    del model_blip
    torch.cuda.empty_cache()
    with open(caption_fname, "w", encoding="utf-8") as f:
        f.write(prompt_str)
    return prompt_str


def _launch_invert_image(img, prompt_str, num_ddim, inv_fname):
    """Return the DDIM-inverted latent of ``img``, computed once and cached at ``inv_fname``."""
    if os.path.exists(inv_fname):
        return torch.load(inv_fname, map_location=torch.device('cpu'), weights_only=True)
    pipe_inv = DDIMInversion.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float32).to(DEVICE)
    pipe_inv.scheduler = DDIMInverseScheduler.from_config(pipe_inv.scheduler.config)
    x_inv, _, _ = pipe_inv(prompt_str,
        guidance_scale=1, num_inversion_steps=num_ddim,
        img=img, torch_dtype=torch.float32)
    x_inv = x_inv.detach()
    torch.save(x_inv, inv_fname)
    del pipe_inv
    torch.cuda.empty_cache()
    return x_inv


def _launch_edit_latent(prompt, x_in, text_dir, num_ddim, xa_guidance, guidance_scale, negative_prompt):
    """Run a fresh SD v1.4 editing pipeline on latent ``x_in`` along ``text_dir``; return the edited image."""
    edit_pipe = EditingPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float32).to(DEVICE)
    edit_pipe.scheduler = DDIMScheduler.from_config(edit_pipe.scheduler.config)
    _, edit_pil = edit_pipe(prompt,
        num_inference_steps=num_ddim,
        x_in=x_in,
        edit_dir=text_dir,
        guidance_amount=xa_guidance,
        guidance_scale=guidance_scale,
        negative_prompt=negative_prompt,
    )
    del edit_pipe
    torch.cuda.empty_cache()
    return edit_pil[0]


def launch_main(img_in_real, img_in_synth, src, src_custom, dest, dest_custom, num_ddim, xa_guidance, edit_mul, fpath_z_gen, gen_prompt, sent_type_src, sent_type_dest, api_key, org_key, custom_sentences_src, custom_sentences_dest):
    """Edit a real or a synthetic image along the direction ``dest - src``.

    The edit direction is ``(dest_emb - src_emb) * edit_mul``. Each embedding is
    either a preset (looked up via ``get_all_directions_names`` / ``get_emb``) or,
    for ``"make your own!"``, a cached custom embedding built by ``make_custom_dir``.

    Exactly one of ``img_in_real`` / ``img_in_synth`` must be given:

    - Real image: resized so its longer side is 512 px, captioned with BLIP,
      DDIM-inverted (caption and inversion are cached under ``tmp/``, keyed by
      the image's SHA-256), then edited with its caption as the negative prompt.
    - Synthetic image: the latent saved at ``fpath_z_gen`` is edited with
      ``gen_prompt`` and an empty negative prompt.

    Returns:
        The edited image (first element of the pipeline's edited output).

    Raises:
        ValueError: if both or neither of the two image inputs are given.
    """
    d_name2desc = get_all_directions_names()
    d_desc2name = {v: k for k, v in d_name2desc.items()}
    os.makedirs("tmp", exist_ok=True)
    # generate custom direction first
    src_emb = _launch_direction_embedding(src, src_custom, sent_type_src, api_key, org_key, custom_sentences_src, d_desc2name)
    dest_emb = _launch_direction_embedding(dest, dest_custom, sent_type_dest, api_key, org_key, custom_sentences_dest, d_desc2name)
    text_dir = (dest_emb.to(DEVICE) - src_emb.to(DEVICE)) * edit_mul

    if img_in_real is not None and img_in_synth is None:
        print("using real image")
        img_in_real = _launch_resize_longer_side(img_in_real, 512)
        img_hash = hashlib.sha256(img_in_real.tobytes()).hexdigest()
        inv_fname = f"tmp/{img_hash}_ddim_{num_ddim}_inv.pt"
        caption_fname = f"tmp/{img_hash}_caption.txt"
        prompt_str = _launch_caption_image(img_in_real, caption_fname)
        print(f"CAPTION: {prompt_str}")
        x_inv = _launch_invert_image(img_in_real, prompt_str, num_ddim, inv_fname)
        # use the unedited prompt for the negative prompt
        return _launch_edit_latent(prompt_str, x_inv, text_dir, num_ddim, xa_guidance,
                                   guidance_scale=5.0, negative_prompt=prompt_str)
    if img_in_real is None and img_in_synth is not None:
        print("using synthetic image")
        x_inv = torch.load(fpath_z_gen, map_location=torch.device('cpu'), weights_only=True)
        # use the empty string for the negative prompt
        return _launch_edit_latent(gen_prompt, x_inv, text_dir, num_ddim, xa_guidance,
                                   guidance_scale=5, negative_prompt="")
    raise ValueError(f"Invalid image type found: {img_in_real} {img_in_synth}")
if __name__ == "__main__":
    # Manual smoke run: sample 100 FLAN-T5-XL captions for a fixed concept and print them.
    captions = flant5xl_compute_word2sentences("cat wearing sunglasses", num=100)
    print(captions)