huggingnft / app.py
drdata's picture
Update app.py
c3dea8f
import json
import torch
from huggingnft.lightweight_gan.train import timestamped_filename
from streamlit_option_menu import option_menu
from huggingface_hub import hf_hub_download, file_download
from PIL import Image
from huggingface_hub.hf_api import HfApi
import streamlit as st
from huggingnft.lightweight_gan.lightweight_gan import Generator, LightweightGAN, evaluate_in_chunks, Trainer
from accelerate import Accelerator
from huggan.pytorch.cyclegan.modeling_cyclegan import GeneratorResNet
from torchvision import transforms as T
from torchvision.transforms import Compose, Resize, ToTensor, Normalize, RandomCrop, RandomHorizontalFlip
from torchvision.utils import make_grid
import requests
hfapi = HfApi()
# Short names (everything after the "<owner>/" prefix) of every model listed
# under the "huggingnft" author on the Hugging Face Hub. This is a network
# call executed whenever this module/script runs.
model_names = [
    repo_id[repo_id.index("/") + 1:]
    for repo_id in (info.modelId for info in hfapi.list_models(author="huggingnft"))
]
# User-facing copy rendered on the pages of the app.
ABOUT_TEXT = "🤗 Hugging NFT - Generate NFT by OpenSea collection name."
# Markdown: the underscores wrap the credit line in italics.
CONTACT_TEXT = """
_Built by Data ❤️_
"""
GENERATE_IMAGE_TEXT = "Generate NFT by selecting existing model based on OpenSea collection. You can create new model or improve existing in a few clicks."
INTERPOLATION_TEXT = "Generate interpolation between two NFTs by selecting existing model based on OpenSea collection. You can create new model or improve existing in a few clicks."
COLLECTION2COLLECTION_TEXT = "Generate first NFT with existing model and transform it to another collection by selecting existing model based on OpenSea collections. You can create new model or improve existing in a few clicks."
TRAIN_TEXT = "> If you think that the results of the model are not good enough and they can be improved, you can train the model more in a few clicks. If you notice that the model is overtrained, then you can easily return to the best version. "
# Models whose names contain any of these substrings are filtered out of
# every model picker.
STOPWORDS = ["-old"]
# Marker found in collection-to-collection translator repo names; the part
# before it names the source collection's generator (presumably the part
# after it is the target collection). Such models are only offered on the
# Collection2Collection page.
COLLECTION2COLLECTION_KEYS = ["__2__"]
def load_lightweight_model(model_name):
    """Download a LightweightGAN config from the Hub and build a ready trainer.

    Args:
        model_name: Hub repo id of the form ``"<organization>/<name>"``.

    Returns:
        A ``Trainer`` constructed from the repo's ``config.json``, with its
        checkpoint loaded via ``Trainer.load(use_cpu=True)`` and a fresh
        ``Accelerator`` attached.

    Raises:
        ValueError: if ``model_name`` does not contain exactly one ``"/"``.
    """
    config_path = file_download.hf_hub_download(
        repo_id=model_name,
        filename="config.json",
    )
    # Use a context manager so the handle is closed deterministically instead
    # of leaking until garbage collection; JSON is UTF-8 by specification.
    with open(config_path, encoding="utf-8") as config_file:
        config = json.load(config_file)
    organization_name, name = model_name.split("/")
    model = Trainer(**config, organization_name=organization_name, name=name)
    model.load(use_cpu=True)
    model.accelerator = Accelerator()
    return model
def clean_models(model_names, stopwords):
    """Return the model names that contain none of the given substrings.

    Args:
        model_names: Iterable of model-name strings.
        stopwords: Iterable of substrings; a name containing any of them is
            dropped.

    Returns:
        A new list of the surviving names, in their original order.
    """
    return [
        model_name
        for model_name in model_names
        if not any(stopword in model_name for stopword in stopwords)
    ]
def get_concat_h(im1, im2):
    """Concatenate two PIL images horizontally, ``im1`` on the left.

    The canvas is as wide as both images combined and as tall as the taller
    one, so neither image is cropped when their heights differ; any area not
    covered by an image keeps the ``Image.new`` default fill (black).

    Args:
        im1: Left image.
        im2: Right image, pasted starting at x = ``im1.width``.

    Returns:
        A new RGB ``PIL.Image.Image``.
    """
    dst = Image.new("RGB", (im1.width + im2.width, max(im1.height, im2.height)))
    dst.paste(im1, (0, 0))
    dst.paste(im2, (im1.width, 0))
    return dst
# Drop models whose names contain a STOPWORDS entry (e.g. "-old") from every
# picker below.
model_names = clean_models(model_names, STOPWORDS)

# Sidebar navigation. The returned label selects which page body renders
# below; every label here must have a matching `if choose == ...` branch and
# vice versa ("Contact" previously had a branch but no menu entry, so it was
# unreachable).
with st.sidebar:
    choose = option_menu(
        "Hugging NFT",
        ["About", "Generate image", "Interpolation", "Collection2Collection", "Contact"],
        icons=['house', 'camera fill', 'bi bi-youtube', 'book', 'envelope'],
        menu_icon="app-indicator",
        default_index=0,
        styles={
            "container": {"border-radius": ".0rem"},
        },
    )
# "About" page: render the project README from GitHub. The banner's
# width="1200" attribute is rewritten to 700 so it fits the main column.
# HTML is allowed because the README embeds <img> tags; the source is the
# project's own repository.
if choose == "About":
    try:
        # Bounded wait so a slow/unreachable GitHub cannot hang the page.
        readme_response = requests.get(
            "https://raw.githubusercontent.com/dr-data/huggingnft/main/README.md",
            timeout=10,
        )
        # Without this, a 404/5xx error body would be rendered as the README.
        readme_response.raise_for_status()
    except requests.RequestException as err:
        st.error(f"Could not load the project README: {err}")
    else:
        README = readme_response.text.replace('width="1200"', 'width="700"')
        st.markdown(README, unsafe_allow_html=True)
# "Contact" page: heading followed by the author credit text.
if choose == "Contact":
    st.title("Contact")
    st.markdown(CONTACT_TEXT)
# "Generate image" page: sample a grid of NFTs from a chosen LightweightGAN.
if choose == "Generate image":
    st.title(choose)
    st.markdown(GENERATE_IMAGE_TEXT)
    # Translator repos (names with a COLLECTION2COLLECTION_KEYS marker) are
    # not plain generators, so they are hidden from this picker.
    generator_choices = clean_models(model_names, COLLECTION2COLLECTION_KEYS)
    model_name = st.selectbox('Choose model:', generator_choices)
    generation_type = st.selectbox('Select generation type:', ["default", "ema"])
    nrows = st.number_input(
        "Number of rows:", min_value=1, max_value=10, step=1, value=8,
    )
    if st.button("Generate"):
        with st.spinner(text="Downloading selected model..."):
            model = load_lightweight_model(f"huggingnft/{model_name}")
        with st.spinner(text="Generating..."):
            generated = model.generate_app(
                num=timestamped_filename(),
                nrow=nrows,
                checkpoint=-1,
                types=generation_type,
            )
            # First element of the return value is what gets displayed.
            grid_image = generated[0]
        st.markdown(TRAIN_TEXT)
        st.image(grid_image)
# "Interpolation" page: render an interpolation between generated NFTs.
if choose == "Interpolation":
    st.title(choose)
    st.markdown(INTERPOLATION_TEXT)
    generator_choices = clean_models(model_names, COLLECTION2COLLECTION_KEYS)
    model_name = st.selectbox('Choose model:', generator_choices)
    nrows = st.number_input(
        "Number of rows:", min_value=1, max_value=4, step=1, value=1,
    )
    num_steps = st.number_input(
        "Number of steps:", min_value=1, max_value=200, step=1, value=100,
    )
    if st.button("Generate"):
        with st.spinner(text="Downloading selected model..."):
            model = load_lightweight_model(f"huggingnft/{model_name}")
        # The progress widget is handed to the model, which presumably
        # advances it per step; it is removed once generation returns.
        progress_widget = st.progress(0)
        interpolation = model.generate_interpolation(
            num=timestamped_filename(),
            num_image_tiles=nrows,
            num_steps=num_steps,
            save_frames=False,
            progress_bar=progress_widget,
        )
        progress_widget.empty()
        st.markdown(TRAIN_TEXT)
        st.image(interpolation)
# "Collection2Collection" page: sample images from the source collection's
# LightweightGAN, translate them with a GeneratorResNet translator, and show
# each input/output pair side by side.
if choose == "Collection2Collection":
    st.title(choose)
    st.markdown(COLLECTION2COLLECTION_TEXT)
    # Only translator repos (names containing a COLLECTION2COLLECTION_KEYS
    # marker) are offered. A comprehension keeps the Hub listing order, so the
    # options and default pick are stable across restarts; iterating a set of
    # strings instead depends on per-process hash randomization.
    translator_names = [
        name
        for name in model_names
        if any(key in name for key in COLLECTION2COLLECTION_KEYS)
    ]
    model_name = st.selectbox('Choose model:', translator_names)
    nrows = st.number_input(
        "Number of images to generate:", min_value=1, max_value=10, step=1, value=1,
    )
    generate_image_button = st.button("Generate")
    if generate_image_button:
        n_channels = 3
        image_size = 256
        with st.spinner(text="Downloading selected model..."):
            translator = GeneratorResNet.from_pretrained(
                f'huggingnft/{model_name}',
                input_shape=(n_channels, image_size, image_size),
                num_residual_blocks=9,
            )
        with st.spinner(text="Downloading selected model..."):
            # The source collection's generator is named by the part of the
            # translator name before the "__2__" marker.
            model = load_lightweight_model(f"huggingnft/{model_name.split('__2__')[0]}")
        with st.spinner(text="Generating input images..."):
            punks = model.generate_app(
                num=timestamped_filename(),
                nrow=nrows,
                checkpoint=-1,
                types="default",
            )[1]
        # Resize the generated batch to the translator's input resolution.
        source_images = T.Resize((image_size, image_size))(punks)
        to_pil = T.ToPILImage()
        with st.spinner(text="Generating output images..."):
            # Inference only: skip autograd bookkeeping to save memory/time.
            with torch.no_grad():
                translated_images = translator(source_images)
            results = [
                get_concat_h(
                    to_pil(make_grid(source, nrow=1, normalize=True)),
                    to_pil(make_grid(target, nrow=1, normalize=True)),
                )
                for source, target in zip(source_images, translated_images)
            ]
        st.markdown(TRAIN_TEXT)
        for result in results:
            st.image(result)