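# Hugging Face file-view header (page residue, kept for provenance):
#   avatar: espejelomar's picture
#   last commit: "Update utils.py" (e149a99)
#   page links: raw / history / blame
#   file size: 3.94 kB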
import torch
from huggan.pytorch.lightweight_gan.lightweight_gan import LightweightGAN
# from datasets import load_dataset
# from PIL import Image
import numpy as np
# import paddlehub as hub
# import random
# from PIL import ImageDraw,ImageFont
# import streamlit as st
# @st.experimental_singleton
# def load_bg_model():
# bg_model = hub.Module(name='U2NetP', directory='assets/models/')
# return bg_model
# bg_model = load_bg_model()
# def remove_bg(img):
# result = bg_model.Segmentation(
# images=[np.array(img)[:,:,::-1]],
# paths=None,
# batch_size=1,
# input_size=320,
# output_dir=None,
# visualization=False)
# output = result[0]
# mask=Image.fromarray(output['mask'])
# front=Image.fromarray(output['front'][:,:,::-1]).convert("RGBA")
# front.putalpha(mask)
# return front
# meme_template=Image.open("./assets/pigeon_meme.jpg").convert("RGBA")
# def make_meme(pigeon,text="Is this a pigeon?",show_text=True,remove_background=True):
# meme=meme_template.copy()
# approx_butterfly_center=(850,30)
# if remove_background:
# pigeon=remove_bg(pigeon)
# else:
# pigeon=Image.fromarray(pigeon).convert("RGBA")
# random_rotate=random.randint(-30,30)
# random_size=random.randint(150,200)
# pigeon=pigeon.resize((random_size,random_size)).rotate(random_rotate,expand=True)
# meme.alpha_composite(pigeon, approx_butterfly_center)
# #ref: https://blog.lipsumarium.com/caption-memes-in-python/
# def drawTextWithOutline(text, x, y):
# draw.text((x-2, y-2), text,(0,0,0),font=font)
# draw.text((x+2, y-2), text,(0,0,0),font=font)
# draw.text((x+2, y+2), text,(0,0,0),font=font)
# draw.text((x-2, y+2), text,(0,0,0),font=font)
# draw.text((x, y), text, (255,255,255), font=font)
# if show_text:
# draw = ImageDraw.Draw(meme)
# font_size=52
# font = ImageFont.truetype("assets/impact.ttf", font_size)
# w, h = draw.textsize(text, font) # measure the size the text will take
# drawTextWithOutline(text, meme.width/2 - w/2, meme.height - font_size*2)
# meme = meme.convert("RGB")
# return meme
# def get_train_data(dataset_name="huggan/smithsonian_butterflies_subset"):
# dataset=load_dataset(dataset_name)
# dataset=dataset.sort("sim_score")
# return dataset["train"]
# from transformers import BeitFeatureExtractor, BeitForImageClassification
# emb_feature_extractor = BeitFeatureExtractor.from_pretrained('microsoft/beit-base-patch16-224')
# emb_model = BeitForImageClassification.from_pretrained('microsoft/beit-base-patch16-224')
# def embed(images):
# inputs = emb_feature_extractor(images=images, return_tensors="pt")
# outputs = emb_model(**inputs,output_hidden_states= True)
# last_hidden=outputs.hidden_states[-1]
# pooler=emb_model.base_model.pooler
# final_emb=pooler(last_hidden).detach().numpy()
# return final_emb
# def build_index():
# dataset=get_train_data()
# ds_with_embeddings = dataset.map(lambda x: {"beit_embeddings":embed(x["image"])},batched=True,batch_size=20)
# ds_with_embeddings.add_faiss_index(column='beit_embeddings')
# ds_with_embeddings.save_faiss_index('beit_embeddings', 'beit_index.faiss')
# def get_dataset():
# dataset=get_train_data()
# dataset.load_faiss_index('beit_embeddings', 'beit_index.faiss')
# return dataset
def load_model(model_name='ceyda/butterfly_cropped_uniq1K_512', model_version=None):
    """Load a pretrained ``LightweightGAN`` and call ``eval()`` on it.

    Args:
        model_name: Identifier forwarded to ``LightweightGAN.from_pretrained``
            (presumably a Hugging Face Hub repo id -- confirm against huggan).
        model_version: Forwarded as the ``version`` keyword; ``None`` leaves
            the choice to ``from_pretrained``.

    Returns:
        The loaded model object, after ``eval()`` has been called on it.
    """
    pretrained = LightweightGAN.from_pretrained(
        model_name,
        version=model_version,
    )
    # Called for its side effect only; the model object itself is returned.
    pretrained.eval()
    return pretrained
def generate(gan, batch_size=1):
    """Sample ``batch_size`` images from the GAN's generator as uint8 arrays.

    Args:
        gan: Object exposing a callable generator ``G`` and an integer
            ``latent_dim`` (the width of each noise vector passed to ``G``).
        batch_size: Number of latent vectors to draw, one per output item.

    Returns:
        A ``numpy.uint8`` array built from the 4-D generator output: values
        are clamped to ``[0, 1]``, scaled by 255, and axis 1 is moved to the
        end (presumably channels-first -> channels-last; confirm against
        ``G``).
    """
    # Draw the noise on the generator's own device: a CPU-only tensor makes
    # the forward pass fail once the model has been moved to an accelerator.
    # Generators exposing no parameters fall back to torch's default device.
    parameters = getattr(gan.G, "parameters", None)
    first_param = next(iter(parameters()), None) if callable(parameters) else None
    device = first_param.device if first_param is not None else None

    with torch.no_grad():
        latents = torch.randn(batch_size, gan.latent_dim, device=device)
        ims = gan.G(latents).clamp_(0., 1.) * 255
        # The uint8 cast truncates the fractional part (e.g. 127.5 -> 127).
        ims = ims.permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8)
    return ims
# def interpolate():
# pass