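# Demo utilities for pix2pix-zero-style image editing with Stable Diffusion:
# generating caption sentences for a concept (GPT-3, Flan-T5-XL, or BLOOMZ),
# turning them into text-embedding edit directions, DDIM-inverting real images,
# and applying the edit with the EditingPipeline.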
import os, sys, time, re
import torch
from PIL import Image
import hashlib
from tqdm import tqdm
import openai

# provides get_all_directions_names, get_emb, generate_image_prompts_with_templates
from utils.direction_utils import *

# make the bundled pix2pix-zero sources importable
p = "submodules/pix2pix-zero/src/utils"
if p not in sys.path:
    sys.path.append(p)

from diffusers import DDIMScheduler
from edit_pipeline import EditingPipeline
from ddim_inv import DDIMInversion
from scheduler import DDIMInverseScheduler
from lavis.models import load_model_and_preprocess
from transformers import T5Tokenizer, AutoTokenizer, T5ForConditionalGeneration, BloomForCausalLM

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
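
# Encode each sentence with the Stable Diffusion text encoder and average the
# results into a single prompt embedding (used later to build an edit direction).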
def load_sentence_embeddings(l_sentences, tokenizer, text_encoder, device=DEVICE):
    with torch.no_grad():
        l_embeddings = []
        for sent in tqdm(l_sentences):
            text_inputs = tokenizer(
                sent,
                padding="max_length",
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            prompt_embeds = text_encoder(text_input_ids.to(device), attention_mask=None)[0]
            l_embeddings.append(prompt_embeds)
    return torch.cat(l_embeddings, dim=0).mean(dim=0).unsqueeze(0)
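
# Sample an image from Stable Diffusion for a given prompt and seed, and cache
# the initial noise map so the same sample can be edited later.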
def launch_generate_sample(prompt, seed, negative_scale, num_ddim):
    os.makedirs("tmp", exist_ok=True)
    # sampling pipeline (EditingPipeline with only_sample=True)
    edit_pipe = EditingPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float32).to(DEVICE)
    edit_pipe.scheduler = DDIMScheduler.from_config(edit_pipe.scheduler.config)

    # set the random seed and sample the input noise map
    torch.manual_seed(int(seed))
    if torch.cuda.is_available():
        torch.cuda.manual_seed(int(seed))
    z = torch.randn((1, 4, 64, 64), device=DEVICE)
    z_hashname = hashlib.sha256(z.cpu().numpy().tobytes()).hexdigest()
    z_inv_fname = f"tmp/{z_hashname}_ddim_{num_ddim}_inv.pt"
    torch.save(z, z_inv_fname)

    rec_pil = edit_pipe(prompt,
        num_inference_steps=num_ddim, x_in=z,
        only_sample=True,  # this flag will only generate the sampled image, not the edited image
        guidance_scale=negative_scale,
        negative_prompt=""  # use the empty string for the negative prompt
    )

    del edit_pipe
    torch.cuda.empty_cache()
    return rec_pil[0], z_inv_fname
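
# Strip list numbering and leftover punctuation from generated captions.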
def clean_l_sentences(ls):
    s = [re.sub(r"\d", "", x) for x in ls]
    s = [x.replace(".", "").replace("-", "").replace(")", "").strip() for x in s]
    return s
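
# Ask GPT-3 (text-davinci-002 via the legacy Completion API) for captions that
# mention `word`, looping until at least `num` usable sentences are collected.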
def gpt3_compute_word2sentences(task_type, word, num=100):
    l_sentences = []
    if task_type == "object":
        template_prompt = f"Provide many captions for images containing {word}."
    elif task_type == "style":
        template_prompt = f"Provide many captions for images that are in the {word} style."
    while True:
        ret = openai.Completion.create(
            model="text-davinci-002",
            prompt=template_prompt,
            max_tokens=1000,
            temperature=1.0)
        raw_return = ret.choices[0].text
        for line in raw_return.split("\n"):
            line = line.strip()
            if len(line) > 10:
                skip = False
                for subword in word.split(" "):
                    if subword not in line:
                        skip = True
                if not skip:
                    l_sentences.append(line)
                else:
                    l_sentences.append(line + f", {word}")
            time.sleep(0.05)
        print(len(l_sentences))
        if len(l_sentences) >= num:
            break
    l_sentences = clean_l_sentences(l_sentences)
    return l_sentences
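
# Generate candidate captions with Flan-T5-XL instead of GPT-3 (no API key needed).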
def flant5xl_compute_word2sentences(word, num=100):
    text_input = f"Provide a caption for images containing a {word}. The captions should be in English and should be no longer than 150 characters."
    l_sentences = []
    tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-xl")
    model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-xl", device_map="auto", torch_dtype=torch.float16)
    input_ids = tokenizer(text_input, return_tensors="pt").input_ids.to(DEVICE)
    while True:
        outputs = model.generate(input_ids, temperature=0.9, num_return_sequences=16, do_sample=True, max_length=128)
        # T5 is an encoder-decoder model, so generate() already returns only the
        # newly generated tokens; decode them directly.
        output = tokenizer.batch_decode(outputs, skip_special_tokens=True)
        for line in output:
            line = line.strip()
            skip = False
            for subword in word.split(" "):
                if subword not in line:
                    skip = True
            if not skip:
                l_sentences.append(line)
            else:
                l_sentences.append(line + f", {word}")
        print(len(l_sentences))
        if len(l_sentences) >= num:
            break
    l_sentences = clean_l_sentences(l_sentences)
    del model
    del tokenizer
    torch.cuda.empty_cache()
    return l_sentences
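
# Generate candidate captions with BLOOMZ-7B1, retrying generation if it fails.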
def bloomz_compute_sentences(word, num=100):
    l_sentences = []
    tokenizer = AutoTokenizer.from_pretrained("bigscience/bloomz-7b1")
    model = BloomForCausalLM.from_pretrained("bigscience/bloomz-7b1", device_map="auto", torch_dtype=torch.float16)
    input_text = f"Provide a caption for images containing a {word}. The captions should be in English and should be no longer than 150 characters. Caption:"
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(DEVICE)
    input_length = input_ids.shape[1]
    t = 0.95
    eta = 1e-5
    min_length = 15
    while True:
        try:
            outputs = model.generate(input_ids, temperature=t, num_return_sequences=16, do_sample=True, max_length=128, min_length=min_length, eta_cutoff=eta)
            # BLOOM is decoder-only, so strip the prompt tokens before decoding
            output = tokenizer.batch_decode(outputs[:, input_length:], skip_special_tokens=True)
        except Exception:
            continue
        for line in output:
            line = line.strip()
            skip = False
            for subword in word.split(" "):
                if subword not in line:
                    skip = True
            if not skip:
                l_sentences.append(line)
            else:
                l_sentences.append(line + f", {word}")
        print(len(l_sentences))
        if len(l_sentences) >= num:
            break
    l_sentences = clean_l_sentences(l_sentences)
    del model
    del tokenizer
    torch.cuda.empty_cache()
    return l_sentences
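
# Build a text-embedding direction for a user-provided concept: collect
# sentences for it (via templates, GPT-3, Flan-T5-XL, BLOOMZ, or custom input)
# and average their Stable Diffusion text-encoder embeddings.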
def make_custom_dir(description, sent_type, api_key, org_key, l_custom_sentences):
    if sent_type == "fixed-template":
        l_sentences = generate_image_prompts_with_templates(description)
    elif "GPT3" in sent_type:
        openai.organization = org_key
        openai.api_key = api_key
        # quick check that the supplied OpenAI credentials are valid
        _ = openai.Model.retrieve("text-davinci-002")
        l_sentences = gpt3_compute_word2sentences("object", description, num=1000)
    elif "flan-t5-xl" in sent_type:
        l_sentences = flant5xl_compute_word2sentences(description, num=1000)
        # save the sentences to file
        with open(f"tmp/flant5xl_sentences_{description}.txt", "w") as f:
            for line in l_sentences:
                f.write(line + "\n")
    elif "BLOOMZ-7B" in sent_type:
        l_sentences = bloomz_compute_sentences(description, num=1000)
        # save the sentences to file
        with open(f"tmp/bloomz_sentences_{description}.txt", "w") as f:
            for line in l_sentences:
                f.write(line + "\n")
    elif sent_type == "custom sentences":
        l_sentences = l_custom_sentences.split("\n")
    else:
        raise ValueError(f"Unknown sentence type: {sent_type}")
    print(f"number of sentences for the custom direction: {len(l_sentences)}")

    pipe = EditingPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float32).to(DEVICE)
    emb = load_sentence_embeddings(l_sentences, pipe.tokenizer, pipe.text_encoder, device=DEVICE)
    del pipe
    torch.cuda.empty_cache()
    return emb
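
# Main editing entry point: pick or build the source/target embeddings, form the
# edit direction as their (scaled) difference, then either DDIM-invert a real
# image (captioned with BLIP) or reuse a cached synthetic noise map, and apply
# the edit with the EditingPipeline.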
def launch_main(img_in_real, img_in_synth, src, src_custom, dest, dest_custom, num_ddim, xa_guidance, edit_mul, fpath_z_gen, gen_prompt, sent_type_src, sent_type_dest, api_key, org_key, custom_sentences_src, custom_sentences_dest):
    d_name2desc = get_all_directions_names()
    d_desc2name = {v: k for k, v in d_name2desc.items()}
    os.makedirs("tmp", exist_ok=True)

    # generate custom direction first
    if src == "make your own!":
        outf_name = f"tmp/template_emb_{src_custom}_{sent_type_src}.pt"
        if not os.path.exists(outf_name):
            src_emb = make_custom_dir(src_custom, sent_type_src, api_key, org_key, custom_sentences_src)
            torch.save(src_emb, outf_name)
        else:
            src_emb = torch.load(outf_name, map_location=torch.device("cpu"), weights_only=True)
    else:
        src_emb = get_emb(d_desc2name[src])

    if dest == "make your own!":
        outf_name = f"tmp/template_emb_{dest_custom}_{sent_type_dest}.pt"
        if not os.path.exists(outf_name):
            dest_emb = make_custom_dir(dest_custom, sent_type_dest, api_key, org_key, custom_sentences_dest)
            torch.save(dest_emb, outf_name)
        else:
            dest_emb = torch.load(outf_name, map_location=torch.device("cpu"), weights_only=True)
    else:
        dest_emb = get_emb(d_desc2name[dest])

    text_dir = (dest_emb.to(DEVICE) - src_emb.to(DEVICE)) * edit_mul
    if img_in_real is not None and img_in_synth is None:
        print("using real image")
        # resize the image so that the longer side is 512
        width, height = img_in_real.size
        if width > height:
            scale_factor = 512 / width
        else:
            scale_factor = 512 / height
        new_size = (int(width * scale_factor), int(height * scale_factor))
        img_in_real = img_in_real.resize(new_size, Image.Resampling.LANCZOS)

        img_hash = hashlib.sha256(img_in_real.tobytes()).hexdigest()
        inv_fname = f"tmp/{img_hash}_ddim_{num_ddim}_inv.pt"
        caption_fname = f"tmp/{img_hash}_caption.txt"

        # make the caption if it hasn't been made before
        if not os.path.exists(caption_fname):
            # BLIP
            model_blip, vis_processors, _ = load_model_and_preprocess(name="blip_caption", model_type="base_coco", is_eval=True, device=torch.device(DEVICE))
            _image = vis_processors["eval"](img_in_real).unsqueeze(0).to(DEVICE)
            prompt_str = model_blip.generate({"image": _image})[0]
            del model_blip
            torch.cuda.empty_cache()
            with open(caption_fname, "w") as f:
                f.write(prompt_str)
        else:
            with open(caption_fname, "r") as f:
                prompt_str = f.read().strip()
        print(f"CAPTION: {prompt_str}")

        # do the inversion if it hasn't been done before
        if not os.path.exists(inv_fname):
            # inversion pipeline
            pipe_inv = DDIMInversion.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float32).to(DEVICE)
            pipe_inv.scheduler = DDIMInverseScheduler.from_config(pipe_inv.scheduler.config)
            x_inv, x_inv_image, x_dec_img = pipe_inv(prompt_str,
                guidance_scale=1, num_inversion_steps=num_ddim,
                img=img_in_real, torch_dtype=torch.float32)
            x_inv = x_inv.detach()
            torch.save(x_inv, inv_fname)
            del pipe_inv
            torch.cuda.empty_cache()
        else:
            x_inv = torch.load(inv_fname, map_location=torch.device("cpu"), weights_only=True)

        # do the editing
        edit_pipe = EditingPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float32).to(DEVICE)
        edit_pipe.scheduler = DDIMScheduler.from_config(edit_pipe.scheduler.config)
        _, edit_pil = edit_pipe(prompt_str,
            num_inference_steps=num_ddim,
            x_in=x_inv,
            edit_dir=text_dir,
            guidance_amount=xa_guidance,
            guidance_scale=5.0,
            negative_prompt=prompt_str  # use the unedited prompt for the negative prompt
        )
        del edit_pipe
        torch.cuda.empty_cache()
        return edit_pil[0]
    elif img_in_real is None and img_in_synth is not None:
        print("using synthetic image")
        x_inv = torch.load(fpath_z_gen, map_location=torch.device("cpu"), weights_only=True)
        pipe = EditingPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float32).to(DEVICE)
        pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
        rec_pil, edit_pil = pipe(gen_prompt,
            num_inference_steps=num_ddim,
            x_in=x_inv,
            edit_dir=text_dir,
            guidance_amount=xa_guidance,
            guidance_scale=5,
            negative_prompt=""  # use the empty string for the negative prompt
        )
        del pipe
        torch.cuda.empty_cache()
        return edit_pil[0]
    else:
        raise ValueError(f"Expected exactly one of a real or a synthetic input image, got real={img_in_real} synth={img_in_synth}")

if __name__ == "__main__":
    print(flant5xl_compute_word2sentences("cat wearing sunglasses", num=100))