import os, pdb import argparse import numpy as np import torch import requests from PIL import Image from diffusers import DDIMScheduler from utils.edit_pipeline import EditingPipeline ## convert sentences to sentence embeddings def load_sentence_embeddings(l_sentences, tokenizer, text_encoder, device="cuda"): with torch.no_grad(): l_embeddings = [] for sent in l_sentences: text_inputs = tokenizer( sent, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids prompt_embeds = text_encoder(text_input_ids.to(device), attention_mask=None)[0] l_embeddings.append(prompt_embeds) return torch.concatenate(l_embeddings, dim=0).mean(dim=0).unsqueeze(0) if __name__=="__main__": parser = argparse.ArgumentParser() parser.add_argument('--file_source_sentences', required=True) parser.add_argument('--file_target_sentences', required=True) parser.add_argument('--output_folder', required=True) parser.add_argument('--model_path', type=str, default='CompVis/stable-diffusion-v1-4') args = parser.parse_args() # load the model pipe = EditingPipeline.from_pretrained(args.model_path, torch_dtype=torch.float16).to("cuda") bname_src = os.path.basename(args.file_source_sentences).strip(".txt") outf_src = os.path.join(args.output_folder, bname_src+".pt") if os.path.exists(outf_src): print(f"Skipping source file {outf_src} as it already exists") else: with open(args.file_source_sentences, "r") as f: l_sents = [x.strip() for x in f.readlines()] mean_emb = load_sentence_embeddings(l_sents, pipe.tokenizer, pipe.text_encoder, device="cuda") print(mean_emb.shape) torch.save(mean_emb, outf_src) bname_tgt = os.path.basename(args.file_target_sentences).strip(".txt") outf_tgt = os.path.join(args.output_folder, bname_tgt+".pt") if os.path.exists(outf_tgt): print(f"Skipping target file {outf_tgt} as it already exists") else: with open(args.file_target_sentences, "r") as f: l_sents = [x.strip() for x in f.readlines()] mean_emb = load_sentence_embeddings(l_sents, pipe.tokenizer, pipe.text_encoder, device="cuda") print(mean_emb.shape) torch.save(mean_emb, outf_tgt)