|
import streamlit as st |
|
|
|
from utils import generation,load_model |
|
|
|
|
|
# --- Main page header -------------------------------------------------------
st.title("Gen of butterfly")
st.markdown("This is lightweight_gan")

# --- Sidebar: short blurb, logo image and author link -----------------------
sidebar = st.sidebar
sidebar.subheader("This is generated")
sidebar.image("assets/logo.png", width=200)
sidebar.caption("https://wgcv.me")
|
|
|
|
|
# Identifier handed to utils.load_model; presumably a Hugging Face Hub repo
# id -- confirm against utils.load_model.
model_id = "ceyda/butterfly_cropped_uniq1K_512"

# Streamlit re-executes this entire script on every widget interaction (e.g.
# each click of the generate button). Calling load_model() unconditionally at
# module level would therefore reload the model on every rerun. Keep a single
# instance per browser session in session_state and reuse it afterwards.
if "model" not in st.session_state:
    st.session_state["model"] = load_model(model_id)
model = st.session_state["model"]

# Number of images requested per generation run; also used as the number of
# display columns further down.
n_gen = 16
|
|
|
def run():
    """Generate a new batch of images and store it in session state.

    Called once on a session's first script run and also wired up as the
    ``on_click`` callback of the generate button. Reads the module-level
    ``model`` and ``n_gen`` and writes the produced batch to
    ``st.session_state["ims"]``, where the rendering code picks it up.
    """
    # The model has already been loaded by the time this runs; the slow part
    # here is sampling, so the spinner describes generation, not loading.
    with st.spinner("Generating butterflies..."):
        ims = generation(model, batch_size=n_gen)
        st.session_state["ims"] = ims
|
|
|
# On a session's first script run nothing has been generated yet: reserve the
# "ims" slot and produce an initial batch so the page is not empty.
_needs_initial_batch = "ims" not in st.session_state
if _needs_initial_batch:
    st.session_state["ims"] = None  # placeholder until run() overwrites it
    run()
|
|
|
|
|
# Latest generated batch (None if generation has not produced anything yet).
ims = st.session_state["ims"]

# Clicking the button invokes run() as a callback, which refreshes
# session_state["ims"] before the script reruns.
run_button = st.button(
    "Gen AI butterfly",
    on_click=run,
    help="This would run the model",
)

if ims is not None:
    # One column per requested image; if the batch were larger than n_gen,
    # the modulo wraps extra images around into the existing columns.
    cols = st.columns(n_gen)
    for idx, picture in enumerate(ims):
        cols[idx % n_gen].image(picture, use_column_width=True)
|
|