Spaces:
Paused
Paused
File size: 1,723 Bytes
fe08291 f04c9cc cdd9a51 f04c9cc 585cc65 e133530 f04c9cc cdd9a51 585cc65 cdd9a51 585cc65 cdd9a51 f04c9cc cdd9a51 585cc65 ee3757e 585cc65 e133530 585cc65 f04c9cc fe08291 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 |
import streamlit as st
from PIL import Image
from inference import inference
import torch
import io
from diffusion import DiffusionImageAPI
import math
# Genre name -> integer code used in the conditioning tensor.
# A genre with code c is written into slot c - 1 of the tensor.
_GENRES = {
    'Action': 1,
    'Adventure': 2,
    'Animation': 3,
    'Comedy': 4,
    'Drama': 5,
    'Family': 6,
    'Horror': 7,
    'Music': 8,
    'Romance': 9,
    'Science Fiction': 10,
    'Western': 11,
    'Fantasy': 12,
    'Thriller': 13,
}

# Length of the conditioning vector handed to ``inference``. It is one longer
# than the number of genres; NOTE(review): presumably this is the width the
# model's conditioning expects — confirm against inference().
_COND_LEN = 14


def _build_condition(selected_genres):
    """Return the int64 conditioning tensor for the selected genre names.

    Every selected genre with code ``c`` sets ``cond[c - 1] = c``; all other
    slots stay 0. If at least one genre was selected, one extra randomly
    chosen genre slot (index ``0 .. len(_GENRES) - 1``) is also set to its
    code, which may coincide with an already-selected genre.

    Args:
        selected_genres: iterable of keys of ``_GENRES``.

    Raises:
        KeyError: if a name is not a known genre.
    """
    cond = torch.zeros(_COND_LEN, dtype=torch.long)
    for genre in selected_genres:
        code = _GENRES[genre]
        cond[code - 1] = code
    # NOTE(review): the random extra genre is added only when the user DID
    # pick something; an empty selection leaves cond all zeros. If the intent
    # was "pick a random genre when none is selected", this test is inverted
    # (it would be ``not torch.any(cond != 0)``) — confirm with the author.
    if torch.any(cond != 0):
        extra = torch.randint(0, len(_GENRES), (1,)).item()
        cond[extra] = extra + 1
    return cond


def _make_progress_callback(progress_placeholder, image_placeholder):
    """Return an ``inference`` callback that renders intermediate results.

    The callback converts the intermediate image tensor to a PNG and writes
    it, together with a progress message, into the given Streamlit
    placeholders (overwriting their previous content).
    """
    # Build the tensor->image converter once per generation rather than on
    # every callback invocation.
    converter = DiffusionImageAPI(None)

    def callback(image, progress):
        # squeeze(0) drops a leading dimension; presumably a batch of size 1
        # — TODO confirm the tensor layout produced by inference().
        pil_image = converter.tensor_to_image(image.squeeze(0))
        img_buffer = io.BytesIO()
        pil_image.save(img_buffer, format="PNG")
        img_buffer.seek(0)
        # The x110 scaling (capped at 100) is kept from the original;
        # presumably the reported progress never quite reaches 1 — confirm.
        progress_placeholder.write(
            f"Generating Image...\nProgress: {min(progress * 110, 100):.2f}%"
        )
        image_placeholder.image(img_buffer, caption='Generated Image', width=300)

    return callback


def main():
    """Render the Movie Diffusion page and generate an image on demand.

    The sidebar offers a multi-select of genres. Clicking "Generate Image"
    builds a conditioning tensor from the selection and runs ``inference``
    with a callback that streams intermediate images and progress into two
    placeholders.
    """
    st.title("Movie Diffusion")
    selected_genres = st.sidebar.multiselect('Select Genres', list(_GENRES.keys()))
    # Placeholders are created before the button so updates replace content
    # in place instead of appending new elements.
    progress_placeholder = st.empty()
    image_placeholder = st.empty()
    if st.button('Generate Image'):
        cond = _build_condition(selected_genres)
        callback = _make_progress_callback(progress_placeholder, image_placeholder)
        inference(cond, callback=callback)
if __name__ == "__main__":
    # Launch the app only when this file is executed as the main script,
    # not when it is imported.
    main()
|