# huggingnft / app.py
# Author: AlekseyKorshuk
# Last commit: "Update app.py" (fd63a18)
import json
from huggingnft.lightweight_gan.train import timestamped_filename
from streamlit_option_menu import option_menu
from huggingface_hub import hf_hub_download, file_download
from huggingface_hub.hf_api import HfApi
import streamlit as st
from huggingnft.lightweight_gan.lightweight_gan import Generator, LightweightGAN, evaluate_in_chunks, Trainer
from accelerate import Accelerator
# Hub client, queried at import time to discover the published models.
hfapi = HfApi()
# Every model published under the "huggingnft" author, with the leading
# "<author>/" part of each repository id cut off.
model_names = [
    repo_info.modelId[repo_info.modelId.index("/") + 1:]
    for repo_info in hfapi.list_models(author="huggingnft")
]
# streamlit-option-menu
# st.set_page_config(page_title="Sharone's Streamlit App Gallery", page_icon="", layout="wide")
# sysmenu = '''
# <style>
# #MainMenu {visibility:hidden;}
# footer {visibility:hidden;}
# '''
# st.markdown(sysmenu,unsafe_allow_html=True)
# # Add a logo (optional) in the sidebar
# logo = Image.open(r'C:\Users\13525\Desktop\Insights_Bees_logo.png')
# profile = Image.open(r'C:\Users\13525\Desktop\medium_profile.png')
# Markdown copy displayed under the title of each sidebar page.
ABOUT_TEXT = "🤗 Hugging NFT - Generate NFT by OpenSea collection name."
CONTACT_TEXT = "Here is some contact info"
GENERATE_IMAGE_TEXT = "Text about generation"
INTERPOLATION_TEXT = "Text about Interpolation"
COLLECTION2COLLECTION_TEXT = "Text about Collection2Collection"

# Any model whose name contains one of these substrings is dropped from
# the app's model list altogether.
STOPWORDS = [
    "-old",
]
# A model whose name contains one of these substrings is treated as a
# Collection2Collection model: it is hidden from the single-collection
# pages and offered only on the Collection2Collection page.
COLLECTION2COLLECTION_KEYS = [
    "2",
]
def load_lightweight_model(model_name):
    """Fetch a model's ``config.json`` from the Hub and build its ``Trainer``.

    Args:
        model_name: Hub repository id of the form ``"<organization>/<name>"``
            (the app passes ``"huggingnft/<collection>"``).

    Returns:
        A ``Trainer`` constructed from the repository's ``config.json`` plus
        the organization and name taken from ``model_name``. ``load`` has been
        called on it with ``use_cpu=True``, and a new ``Accelerator`` has been
        assigned to its ``accelerator`` attribute.
    """
    file_path = file_download.hf_hub_download(
        repo_id=model_name,
        filename="config.json",
    )
    # Read the config inside a context manager so the handle is closed right
    # away; JSON is UTF-8, so don't depend on the platform default encoding.
    with open(file_path, encoding="utf-8") as config_file:
        config = json.load(config_file)
    organization_name, name = model_name.split("/")
    model = Trainer(**config, organization_name=organization_name, name=name)
    model.load(use_cpu=True)
    # NOTE(review): the accelerator is swapped after loading — presumably to
    # use a default Accelerator for inference in the app; confirm intent.
    model.accelerator = Accelerator()
    return model
def clean_models(model_names, stopwords):
    """Return the names in ``model_names`` that contain none of ``stopwords``.

    A name is rejected as soon as any stopword occurs in it as a substring;
    the surviving names keep their original order.

    Args:
        model_names: Iterable of model name strings.
        stopwords: Iterable of substrings that disqualify a name.

    Returns:
        A new list holding the names that passed the filter.
    """
    return [
        name
        for name in model_names
        if not any(word in name for word in stopwords)
    ]
# Drop models whose names contain any STOPWORDS entry (e.g. "-old").
model_names = clean_models(model_names=model_names, stopwords=STOPWORDS)
# Sidebar: the page-selection menu followed by repository / social links.
with st.sidebar:
    choose = option_menu(
        menu_title="Hugging NFT",
        options=["About", "Generate image", "Interpolation", "Collection2Collection", "Contact"],
        icons=['house', 'camera fill', 'bi bi-youtube', 'book', 'person lines fill'],
        menu_icon="app-indicator",
        default_index=0,
    )
    # Inside the `with st.sidebar:` context, st.markdown renders into the sidebar.
    st.markdown(
        """
<style>
.aligncenter {
text-align: center;
}
</style>
<p style='text-align: center'>
<a href="https://github.com/AlekseyKorshuk/huggingnft" target="_blank">Project Repository</a>
</p>
<p class="aligncenter">
<a href="https://github.com/AlekseyKorshuk/huggingnft" target="_blank">
<img src="https://img.shields.io/github/stars/AlekseyKorshuk/huggingnft?style=social"/>
</a>
</p>
<p class="aligncenter">
<a href="https://twitter.com/alekseykorshuk" target="_blank">
<img src="https://img.shields.io/twitter/follow/alekseykorshuk?style=social"/>
</a>
</p>
""",
        unsafe_allow_html=True,
    )
# Render the page picked in the sidebar menu. `choose` holds exactly one menu
# entry, so the branches are mutually exclusive.
if choose == "About":
    st.title(choose)
    st.markdown(ABOUT_TEXT)
elif choose == "Contact":
    st.title(choose)
    st.markdown(CONTACT_TEXT)
elif choose == "Generate image":
    st.title(choose)
    st.markdown(GENERATE_IMAGE_TEXT)
    # Single-collection models only: names containing a
    # COLLECTION2COLLECTION_KEYS substring are filtered out.
    model_name = st.selectbox(
        'Choose model:',
        clean_models(model_names, COLLECTION2COLLECTION_KEYS)
    )
    generation_type = st.selectbox(
        'Select generation type:',
        ["default", "ema"]
    )
    nrows = st.number_input(
        "Number of rows:",
        min_value=1,
        max_value=10,
        step=1,
        value=8,
    )
    generate_image_button = st.button("Generate")
    if generate_image_button:
        # load_lightweight_model runs again on every click; nothing in this
        # script caches the resulting model.
        with st.spinner(text="Downloading selected model..."):
            model = load_lightweight_model(f"huggingnft/{model_name}")
        with st.spinner(text="Generating..."):
            st.image(
                model.generate_app(
                    num=timestamped_filename(),
                    nrow=nrows,
                    checkpoint=-1,
                    types=generation_type
                )
            )
elif choose == "Interpolation":
    st.title(choose)
    st.markdown(INTERPOLATION_TEXT)
    model_name = st.selectbox(
        'Choose model:',
        clean_models(model_names, COLLECTION2COLLECTION_KEYS)
    )
    nrows = st.number_input(
        "Number of rows:",
        min_value=1,
        max_value=10,
        step=1,
        value=1,
    )
    num_steps = st.number_input(
        "Number of steps:",
        min_value=1,
        max_value=1000,
        step=1,
        value=100,
    )
    generate_image_button = st.button("Generate")
    if generate_image_button:
        with st.spinner(text="Downloading selected model..."):
            model = load_lightweight_model(f"huggingnft/{model_name}")
        # The progress bar widget is handed to the model, which presumably
        # advances it per step; it is removed once the result is back.
        my_bar = st.progress(0)
        result = model.generate_interpolation(
            num=timestamped_filename(),
            num_image_tiles=nrows,
            num_steps=num_steps,
            save_frames=False,
            progress_bar=my_bar
        )
        my_bar.empty()
        with st.spinner(text="Uploading result..."):
            st.image(result)
elif choose == "Collection2Collection":
    st.title(choose)
    # Was INTERPOLATION_TEXT, which left COLLECTION2COLLECTION_TEXT unused.
    st.markdown(COLLECTION2COLLECTION_TEXT)
    # Collection2Collection models are exactly those that the
    # COLLECTION2COLLECTION_KEYS filter removes. Keep them in model_names
    # order: a bare set difference iterates in hash order, which varies
    # between Python processes and would reshuffle the select box.
    single_collection_models = set(clean_models(model_names, COLLECTION2COLLECTION_KEYS))
    model_name = st.selectbox(
        'Choose model:',
        [name for name in model_names if name not in single_collection_models]
    )
    generate_image_button = st.button("Generate")
    if generate_image_button:
        st.markdown("generating Collection2Collection")