import torch
from huggan.pytorch.lightweight_gan.lightweight_gan import LightweightGAN
from datasets import load_dataset
from transformers import BeitFeatureExtractor, BeitForImageClassification


def get_train_data(dataset_name="ceyda/smithsonian_butterflies_transparent_cropped", data_limit=1000):
    # Keep the data_limit most similar images (lowest sim_score) and convert them to RGB.
    dataset = load_dataset(dataset_name)
    dataset = dataset.sort("sim_score")
    score_thresh = dataset["train"][data_limit]['sim_score']
    dataset = dataset.filter(lambda x: x['sim_score'] < score_thresh)
    dataset = dataset.map(lambda x: {'image': x['image'].convert("RGB")})
    return dataset["train"]


# BEiT is used here only to produce image embeddings for similarity search.
feature_extractor = BeitFeatureExtractor.from_pretrained('microsoft/beit-base-patch16-224')
model = BeitForImageClassification.from_pretrained('microsoft/beit-base-patch16-224')


def embed(images):
    # Run BEiT and pool the last hidden state into one embedding vector per image.
    inputs = feature_extractor(images=images, return_tensors="pt")
    outputs = model(**inputs, output_hidden_states=True)
    last_hidden = outputs.hidden_states[-1]
    pooler = model.base_model.pooler
    final_emb = pooler(last_hidden).detach().numpy()
    return final_emb


def build_index():
    # Embed the training images in batches and persist a FAISS index to disk.
    dataset = get_train_data()
    ds_with_embeddings = dataset.map(
        lambda x: {"beit_embeddings": embed(x["image"])}, batched=True, batch_size=20
    )
    ds_with_embeddings.add_faiss_index(column='beit_embeddings')
    ds_with_embeddings.save_faiss_index('beit_embeddings', 'beit_index.faiss')


def get_dataset():
    # Reload the training split and attach the previously saved FAISS index.
    dataset = get_train_data()
    dataset.load_faiss_index('beit_embeddings', 'beit_index.faiss')
    return dataset


def load_model(model_name='ceyda/butterfly_cropped_uniq1K_512'):
    gan = LightweightGAN.from_pretrained(model_name)
    gan.eval()
    return gan


def generate(gan, batch_size=1):
    # Sample latents, run the generator, and return images as (B, H, W, C) arrays in [0, 1].
    with torch.no_grad():
        ims = gan.G(torch.randn(batch_size, gan.latent_dim)).clamp_(0., 1.)
        ims = ims.permute(0, 2, 3, 1).detach().cpu().numpy()
    return ims


def interpolate():
    # Placeholder: latent-space interpolation is not implemented yet.
    pass
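

# Hypothetical usage sketch (not part of the original module): build the FAISS
# index once, then load the GAN and sample a few images. The __main__ guard,
# the call order, and the batch size below are assumptions for illustration.
if __name__ == "__main__":
    build_index()                         # one-off: embeds the images and writes beit_index.faiss
    dataset = get_dataset()               # training split with the FAISS index attached
    gan = load_model()
    images = generate(gan, batch_size=4)  # numpy array of shape (4, H, W, 3), values in [0, 1]
    print(images.shape)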