huggingnft / app.py
AlekseyKorshuk's picture
Update app.py
e774b65
raw history blame
No virus
6.78 kB
import subprocess
from pathlib import Path
import einops
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from torch import nn
from torchvision.utils import save_image
from huggingface_hub.hf_api import HfApi
import streamlit as st
# Shared Hugging Face Hub API client; used further down in this script to
# list the models published under the "huggingnft" author.
hfapi = HfApi()
class Generator(nn.Module):
    """DCGAN-style generator built from a stack of transposed convolutions.

    For a noise tensor of shape (N, latent_dim, 1, 1) the network produces
    images of shape (N, num_channels, 64, 64) with values in [-1, 1] (Tanh).

    Args:
        num_channels: Number of channels in the generated image.
        latent_dim: Number of channels of the input noise tensor.
        hidden_size: Base feature-map width; earlier layers use multiples of it.
    """

    def __init__(self, num_channels=4, latent_dim=100, hidden_size=64):
        super().__init__()
        widths = [hidden_size * 8, hidden_size * 4, hidden_size * 2, hidden_size]

        # Project the 1x1 noise to a (hidden_size*8) x 4 x 4 feature map.
        layers = [
            nn.ConvTranspose2d(latent_dim, widths[0], 4, 1, 0, bias=False),
            nn.BatchNorm2d(widths[0]),
            nn.ReLU(True),
        ]
        # Each stage halves the width and doubles the resolution: 4 -> 8 -> 16 -> 32.
        for in_width, out_width in zip(widths, widths[1:]):
            layers += [
                nn.ConvTranspose2d(in_width, out_width, 4, 2, 1, bias=False),
                nn.BatchNorm2d(out_width),
                nn.ReLU(True),
            ]
        # Last upsampling step to (num_channels) x 64 x 64.
        layers += [
            nn.ConvTranspose2d(widths[-1], num_channels, 4, 2, 1, bias=False),
            nn.Tanh(),
        ]
        # Positional construction keeps submodule indices 0..13, so state-dict
        # keys ("model.0.weight", ...) are unchanged.
        self.model = nn.Sequential(*layers)

    def forward(self, noise):
        """Return the images generated from a batch of noise tensors."""
        return self.model(noise)
@torch.no_grad()
def interpolate(model, save_dir='./lerp/', frames=100, rows=8, cols=8, latent_dim=100):
    """Render a back-and-forth latent interpolation and save it as ``out.gif``.

    Two random latent batches are sampled and linearly blended over ``frames``
    steps; the sequence is then replayed in reverse. Every step is written as
    a PNG grid into ``save_dir`` and all steps are combined into ``out.gif``
    in the current working directory.

    Args:
        model: Callable mapping a (rows * cols, latent_dim, 1, 1) noise tensor
            to an image batch accepted by ``save_image``.
        save_dir: Directory for the per-frame PNG files (created if missing).
        frames: Number of interpolation steps in one direction (>= 1).
        rows: Rows of the image grid.
        cols: Columns of the image grid.
        latent_dim: Channel size of the sampled noise.

    Raises:
        ValueError: If ``frames`` is smaller than 1.
    """
    if frames < 1:
        raise ValueError("frames must be at least 1")

    save_dir = Path(save_dir)
    save_dir.mkdir(exist_ok=True, parents=True)

    batch_size = rows * cols
    z1 = torch.randn(batch_size, latent_dim, 1, 1)
    z2 = torch.randn(batch_size, latent_dim, 1, 1)

    zs = []
    for i in range(frames):
        alpha = i / frames
        zs.append((1 - alpha) * z1 + alpha * z2)
    zs += zs[::-1]  # also go in reverse order to complete loop

    # Separate name so the integer ``frames`` parameter is not shadowed.
    gif_frames = []
    for i, z in enumerate(zs):
        frame_path = save_dir / f"{i:03}.png"
        # nrow=cols makes the grid honour the requested layout (8 by default).
        save_image(model(z), frame_path, normalize=True, nrow=cols)
        # Re-open the grid and force it fully opaque before reuse in the GIF.
        img = Image.open(frame_path).convert('RGBA')
        img.putalpha(255)
        img.save(frame_path)
        gif_frames.append(img)

    # The first frame is the base image; append only the remaining ones so it
    # is not written twice.
    gif_frames[0].save("out.gif", format="GIF", append_images=gif_frames[1:],
                       save_all=True, duration=100, loop=1)
def predict(model_name, choice, seed):
    """Generate an output file with a pretrained ``huggingnft`` generator.

    Args:
        model_name: Repository name under the ``huggingnft`` Hub namespace.
        choice: ``'interpolation'`` produces an animated GIF; any other value
            produces a single image grid.
        seed: Value passed to ``torch.manual_seed`` before noise is sampled.

    Returns:
        Path of the written file: ``'out.gif'`` or ``'image.png'``.
    """
    model = Generator()
    weights_path = hf_hub_download(f'huggingnft/{model_name}', 'pytorch_model.bin')
    # NOTE(review): torch.load unpickles the downloaded checkpoint; this is
    # only safe while the "huggingnft" repositories are trusted.
    model.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu')))
    torch.manual_seed(seed)

    if choice == 'interpolation':
        interpolate(model)
        return 'out.gif'

    z = torch.randn(64, 100, 1, 1)
    # Inference only: skip autograd bookkeeping, as interpolate() already does.
    with torch.no_grad():
        punks = model(z)
    save_image(punks, "image.png", normalize=True)
    # Re-open the grid and force it fully opaque.
    img = Image.open("image.png").convert('RGBA')
    img.putalpha(255)
    img.save("image.png")
    return 'image.png'
def _repo_basename(repo_id):
    """Return the part of *repo_id* that follows its first "/"."""
    return repo_id[repo_id.index("/") + 1:]


# Names of all models published by the "huggingnft" author, owner prefix removed.
model_names = [_repo_basename(info.modelId) for info in hfapi.list_models(author="huggingnft")]

st.set_page_config(page_title="Hugging NFT")
st.title("Hugging NFT")
# Sidebar banner: centered preview image.
_SIDEBAR_BANNER_HTML = """
<style>
.aligncenter {
text-align: center;
}
</style>
<p class="aligncenter">
<img src="https://raw.githubusercontent.com/AlekseyKorshuk/optimum-transformers/master/data/social_preview.png" width="300" />
</p>
"""
st.sidebar.markdown(_SIDEBAR_BANNER_HTML, unsafe_allow_html=True)

# Sidebar links: GitHub repository, GitHub stars badge and Twitter follow badge.
_SIDEBAR_LINKS_HTML = """
<style>
.aligncenter {
text-align: center;
}
</style>
<p style='text-align: center'>
<a href="https://github.com/AlekseyKorshuk/huggingnft" target="_blank">GitHub</a>
</p>
<p class="aligncenter">
<a href="https://github.com/AlekseyKorshuk/huggingnft" target="_blank">
<img src="https://img.shields.io/github/stars/AlekseyKorshuk/huggingnft?style=social"/>
</a>
</p>
<p class="aligncenter">
<a href="https://twitter.com/alekseykorshuk" target="_blank">
<img src="https://img.shields.io/twitter/follow/alekseykorshuk?style=social"/>
</a>
</p>
"""
st.sidebar.markdown(_SIDEBAR_LINKS_HTML, unsafe_allow_html=True)
# Project introduction shown at the top of the main page. The emoji are
# written as escapes so a wrong file encoding cannot garble them.
st.markdown(
    "\U0001F917 [Hugging NFT](https://github.com/AlekseyKorshuk/huggingnft) - Generate NFT by OpenSea collection name."  # hugging face
)
st.markdown(
    "\U0001F680\uFE0F SN-GAN used to train all models."  # rocket
)
st.markdown(
    "\u2049\uFE0F Want to train your model? Check [project repository](https://github.com/AlekseyKorshuk/huggingnft) and make it in a few clicks!"  # exclamation question mark
)
#
# st.markdown("🚀 Up to 1ms on Bert-based transformers")
#
# st.markdown(
# "‼️ NOTE: This Space **does not show** the real power of this project because: low resources, not possible to optimize models. Check [project repository](https://github.com/AlekseyKorshuk/optimum-transformers) with real benchmarks!")
# st.sidebar.header("Settings:")
# User controls: which model to run, what kind of output to produce, and the
# random seed used for sampling.
model_name = st.selectbox('Choose model:', model_names)
output_type = st.selectbox('Output type:', ['image', 'interpolation'])
seed_value = st.slider("Seed:", min_value=1, max_value=1000, step=1, value=100)
# HTML card template with DISPLAY_1 / USER_PROFILE / USER_NAME / USER_HANDLE
# placeholders.
# NOTE(review): not referenced anywhere in the visible file; the
# "HuggingArtists" title and genius.com link suggest it was carried over from
# another project - confirm before removing.
model_html = """
<div class="inline-flex flex-col" style="line-height: 1.5;">
<div class="flex">
<div
\t\t\tstyle="display:DISPLAY_1; margin-left: auto; margin-right: auto; width: 92px; height:92px; border-radius: 50%; background-size: cover; background-image: url(&#39;USER_PROFILE&#39;)">
</div>
</div>
<div style="text-align: center; margin-top: 3px; font-size: 16px; font-weight: 800">\U0001F916 HuggingArtists Model \U0001F916</div>
<div style="text-align: center; font-size: 16px; font-weight: 800">USER_NAME</div>
<a href="https://genius.com/artists/USER_HANDLE">
\t<div style="text-align: center; font-size: 14px;">@USER_HANDLE</div>
</a>
</div>
"""
# Generate only when the user clicks the button; the spinner stays visible
# while predict() runs and the resulting file is displayed.
if st.button("Run"):
    with st.spinner(text="Generating..."):
        result_path = predict(model_name, output_type, seed_value)
        st.image(result_path)
# Footer: call to action with GitHub stars and Twitter follow badges.
st.subheader("Please star project repository, this space and follow my Twitter:")
_FOOTER_BADGES_HTML = """
<style>
.aligncenter {
text-align: center;
}
</style>
<p class="aligncenter">
<a href="https://github.com/AlekseyKorshuk/huggingnft" target="_blank">
<img src="https://img.shields.io/github/stars/AlekseyKorshuk/huggingnft?style=social"/>
</a>
</p>
<p class="aligncenter">
<a href="https://twitter.com/alekseykorshuk" target="_blank">
<img src="https://img.shields.io/twitter/follow/alekseykorshuk?style=social"/>
</a>
</p>
"""
st.markdown(_FOOTER_BADGES_HTML, unsafe_allow_html=True)