butterfly-gan / demo.py
Ceyda Cinarel
add nearest neighbor
9fbe234
raw
history blame
No virus
1.92 kB
import torch
from huggan.pytorch.lightweight_gan.lightweight_gan import LightweightGAN
from datasets import load_dataset
def get_train_data(dataset_name="ceyda/smithsonian_butterflies_transparent_cropped", data_limit=1000):
    """Load the dataset and keep the rows scoring below a ``sim_score`` cut-off.

    The dataset is sorted by ``sim_score`` and the score found at position
    ``data_limit`` of the sorted "train" split is used as an exclusive upper
    bound for the filter. Assuming the library's default ascending sort, this
    keeps at most ``data_limit`` rows (fewer when several rows tie with the
    cut-off score). Every remaining image is converted to RGB.

    Args:
        dataset_name: Identifier passed to ``load_dataset``.
        data_limit: Position in the sorted "train" split whose ``sim_score``
            becomes the cut-off. When it is not a valid position because the
            split is too short, no rows are filtered out.

    Returns:
        The "train" split after filtering and RGB conversion.
    """
    dataset = load_dataset(dataset_name)
    dataset = dataset.sort("sim_score")

    # Indexing past the end of the split would raise IndexError; when the
    # split has no row at `data_limit` there is nothing to cut, so keep it all.
    if data_limit < len(dataset["train"]):
        score_thresh = dataset["train"][data_limit]["sim_score"]
        dataset = dataset.filter(lambda x: x["sim_score"] < score_thresh)

    dataset = dataset.map(lambda x: {"image": x["image"].convert("RGB")})
    return dataset["train"]
from transformers import BeitFeatureExtractor, BeitForImageClassification
# Preprocessor and classifier are loaded once, at import time, from the same
# pretrained checkpoint; embed() below reads both module-level objects.
_BEIT_CHECKPOINT = 'microsoft/beit-base-patch16-224'
feature_extractor = BeitFeatureExtractor.from_pretrained(_BEIT_CHECKPOINT)
model = BeitForImageClassification.from_pretrained(_BEIT_CHECKPOINT)
def embed(images):
    """Compute pooled BEiT embeddings for one or more images.

    The images are preprocessed with the module-level ``feature_extractor``,
    passed through the module-level ``model`` with hidden states enabled, and
    the base model's pooler is applied to the last hidden state.

    Args:
        images: Image input forwarded to ``feature_extractor(images=...)``.

    Returns:
        The pooled embeddings as a NumPy array.
    """
    inputs = feature_extractor(images=images, return_tensors="pt")

    # Pure inference: disabling autograd avoids building a graph that would
    # only be thrown away, which saves memory when embedding many batches.
    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)
        last_hidden = outputs.hidden_states[-1]
        pooled = model.base_model.pooler(last_hidden)

    # Tensors created under no_grad do not require grad, so they convert
    # to NumPy directly.
    return pooled.numpy()
def build_index(index_path='beit_index.faiss', batch_size=20):
    """Embed the training images and save a FAISS index over the embeddings.

    Args:
        index_path: File the FAISS index is written to. The default matches
            the file name this module has always used.
        batch_size: Number of images handed to ``embed`` per ``map`` batch.

    Returns:
        None. The index is written to ``index_path`` as a side effect.
    """
    dataset = get_train_data()

    def add_embeddings(batch):
        # Batched map: batch["image"] holds the images of the whole batch.
        return {"beit_embeddings": embed(batch["image"])}

    ds_with_embeddings = dataset.map(add_embeddings, batched=True, batch_size=batch_size)
    ds_with_embeddings.add_faiss_index(column='beit_embeddings')
    ds_with_embeddings.save_faiss_index('beit_embeddings', index_path)
def get_dataset(index_path='beit_index.faiss'):
    """Return the training data with a saved FAISS index attached.

    Args:
        index_path: File to read the FAISS index from. The default is the
            file name ``build_index`` saves to.

    Returns:
        The dataset from ``get_train_data()`` with the index registered
        under the name ``'beit_embeddings'``.
    """
    dataset = get_train_data()
    # NOTE(review): the index is loaded from disk rather than rebuilt, so it
    # presumably only lines up with the rows if get_train_data() yields the
    # same rows in the same order as when the index was built; confirm.
    dataset.load_faiss_index('beit_embeddings', index_path)
    return dataset
def load_model(model_name='ceyda/butterfly_cropped_uniq1K_512'):
    """Load a pretrained LightweightGAN and switch it to evaluation mode.

    Args:
        model_name: Identifier passed to ``LightweightGAN.from_pretrained``.

    Returns:
        The loaded GAN, after ``eval()`` has been called on it.
    """
    pretrained_gan = LightweightGAN.from_pretrained(model_name)
    # Evaluation mode is set once here so callers can generate straight away.
    pretrained_gan.eval()
    return pretrained_gan
def generate(gan, batch_size=1):
    """Sample images from the GAN's generator.

    Random latent vectors of size ``gan.latent_dim`` are passed through
    ``gan.G``; the output is clamped to ``[0, 1]``, its axis 1 is moved to the
    last position (presumably channels-first to channels-last; confirm
    against the generator), and the result is returned as a NumPy array.

    Args:
        gan: Object exposing a callable ``G`` and an integer ``latent_dim``.
        batch_size: Number of latent vectors (and thus images) to sample.

    Returns:
        A 4-D NumPy array with the batch on axis 0.
    """
    generator = gan.G

    # Sample the noise on the generator's own device so a generator that was
    # moved off the CPU does not fail with a device mismatch. Fall back to the
    # CPU when the generator exposes no parameters.
    try:
        device = next(generator.parameters()).device
    except (AttributeError, StopIteration):
        device = torch.device("cpu")

    with torch.no_grad():
        latents = torch.randn(batch_size, gan.latent_dim, device=device)
        ims = generator(latents).clamp_(0., 1.)

    return ims.permute(0, 2, 3, 1).detach().cpu().numpy()
def interpolate():
    """Placeholder with no implementation yet; does nothing and returns None."""