Spaces:
Runtime error
Runtime error
import streamlit as st | |
from PIL import Image | |
import imageio.v3 as iio | |
from inference import inference | |
from src.utils.opt import Opts | |
import os | |
# Per-column CSS hook. Both rule bodies are currently empty, so injecting this
# block has no visible effect; it is a placeholder for column styling.
_COLUMN_CSS = """
<style>
div[data-testid="column"]:nth-of-type(1)
{
}
div[data-testid="column"]:nth-of-type(2)
{
}
</style>
"""

# set_page_config is the first Streamlit call the script makes.
st.set_page_config(layout='wide')
st.markdown(_COLUMN_CSS, unsafe_allow_html=True)
# Three side-by-side panels: source views, style image, rendered videos.
col1, col2, col3 = st.columns(3)

# Values seeded into st.session_state on a session's first run, in this order.
_SESSION_DEFAULTS = {"video_path": None, "image": None, "counter": 0}

# "counter" acts as the marker that this session was already initialized.
if 'counter' not in st.session_state:
    for _state_key, _state_value in _SESSION_DEFAULTS.items():
        st.session_state[_state_key] = _state_value
def showVideo(image):
    """Run style transfer with *image* and record the rendered video.

    Used as the ``on_click`` callback of the "Run Style Transfer" button.
    On success, stores ``result["video_path"]`` returned by ``inference``
    in ``st.session_state.video_path`` and increments
    ``st.session_state.counter``; the page shows a video once the counter
    is positive.

    Args:
        image: the style image to transfer, or ``None`` when nothing has
            been uploaded, in which case a notice is written to ``col2``.
    """
    # Validate the value that is actually handed to inference instead of
    # re-reading session state, so the check and the use cannot diverge.
    if image is None:
        col2.write("No uploaded image")
        return
    # The config file is re-parsed on every click.
    cfg = Opts(cfg="configs/style_inference.yml").parse_args()
    result = inference(cfg, "render_test", image=image)
    st.session_state.video_path = result["video_path"]
    st.session_state.counter += 1
with col1:
    col1.subheader("Source multiview images")

    # Folder holding the preview frames of the source scene.
    source_dir = 'data/nerf_llff_data/trex/streamlit_images'
    images_per_row = 4

    # os.listdir order is arbitrary; sort so the views appear in a stable
    # order on every rerun.
    source_images = []
    for file_name in sorted(os.listdir(source_dir)):
        try:
            with Image.open(os.path.join(source_dir, file_name)) as frame:
                # Decode now so the file handle is released when the
                # with-block exits; the pixel data stays in memory.
                frame.load()
        except OSError:
            # Not a readable image (PIL's UnidentifiedImageError is an
            # OSError), e.g. a stray metadata file or subdirectory: skip it.
            continue
        source_images.append(frame)

    # Lay the frames out in rows of `images_per_row`. zip() stops at the
    # shorter sequence, so a final partial row no longer raises IndexError.
    for start in range(0, len(source_images), images_per_row):
        row_cells = col1.columns(images_per_row)
        row_images = source_images[start:start + images_per_row]
        for cell, frame in zip(row_cells, row_images):
            cell.image(frame, use_column_width=True)
with col2:
    col2.subheader("Style image")
    uploaded_file = col2.file_uploader("Choose a image file")
    if uploaded_file:
        try:
            st.session_state.image = Image.open(uploaded_file)
        except OSError:
            # PIL could not identify the upload as an image
            # (UnidentifiedImageError subclasses OSError).
            st.session_state.image = None
            col2.error("The uploaded file is not a readable image")
        else:
            col2.image(st.session_state.image, caption='Style Image', use_column_width=True)
    else:
        # The uploader was cleared: forget the previous image so a stale
        # style image is not sent to the next style-transfer run.
        st.session_state.image = None
    # The current image is bound as the callback's single positional argument.
    col2.button('Run Style Transfer', on_click=showVideo, args=(st.session_state.image,))
col3.subheader("Style videos")
# counter > 0 means showVideo has stored at least one rendered video path.
if st.session_state.counter > 0:
    # Close the file once its bytes are read instead of leaking the handle
    # on every rerun.
    with open(st.session_state.video_path, 'rb') as video_file:
        video_bytes = video_file.read()
    col3.video(video_bytes)