TechJamM2I / app.py
Monke64's picture
Changed requirements.txt
01859cc
raw
history blame
2.82 kB
import streamlit as st
from flask.Emotion_spotting_service import _Emotion_spotting_service
from flask.Genre_spotting_service import _Genre_spotting_service
from flask.Beat_tracking_service import _Beat_tracking_service
#from diffusers import StableDiffusionPipeline
import torch
# Module-level scratch lists (three distinct list objects). Nothing in the
# visible part of this script appends to or reads from them.
emo_list, gen_list, tempo_list = [], [], []
@st.cache_resource
def load_emo_model():
    """Build the emotion-spotting service from ``flask/emotion_model.h5``.

    Decorated with ``st.cache_resource`` so Streamlit can hand back the same
    service object on later script reruns instead of constructing it again.
    """
    return _Emotion_spotting_service("flask/emotion_model.h5")
@st.cache_resource
def load_genre_model():
    """Build the genre-spotting service from ``flask/Genre_classifier_model.h5``.

    Decorated with ``st.cache_resource`` so Streamlit can hand back the same
    service object on later script reruns instead of constructing it again.
    """
    return _Genre_spotting_service("flask/Genre_classifier_model.h5")
@st.cache_resource
def load_beat_model():
    """Build the beat-tracking service (it takes no constructor arguments).

    Decorated with ``st.cache_resource`` so Streamlit can hand back the same
    service object on later script reruns instead of constructing it again.
    """
    return _Beat_tracking_service()
#@st.cache_resource
# def load_image_model():
# pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5",torch_dtype=torch.float16).to("cuda")
# pipeline.load_lora_weights("Weights/pytorch_lora_weights.safetensors", weight_name="pytorch_lora_weights.safetensors")
# return pipeline
# Seed one session-state slot per analysis result. Each slot stays None until
# the matching "Generate ..." button has produced a value.
for _slot in ("emotion", "genre", "beat"):
    if _slot not in st.session_state:
        st.session_state[_slot] = None

# Obtain the three analysis services through their cached loader functions.
emotion_service = load_emo_model()
genre_service = load_genre_model()
beat_service = load_beat_model()
#image_service = load_image_model()
def _run_analysis(label, state_key, analyse, audio_file):
    """Render one analysis button and the latest stored result beneath it.

    Args:
        label: Caption of the button.
        state_key: Key in ``st.session_state`` under which the result is kept.
        analyse: Callable invoked with ``audio_file`` when the button is clicked.
        audio_file: The uploaded file, or ``None`` when nothing is uploaded yet.
    """
    if st.button(label):
        if audio_file is None:
            # Without an upload there is nothing to analyse; tell the user
            # instead of passing None to the service.
            st.warning("Please upload a wav/mp3 file first.")
        else:
            st.session_state[state_key] = analyse(audio_file)
    # Always show the stored value (None until a result has been produced) so
    # an earlier result stays visible when another button triggers a rerun.
    st.text(st.session_state[state_key])


st.title("Music2Image webpage")
user_input = st.file_uploader(
    "Upload your wav/mp3 files here", type=["wav", "mp3"], key="file_uploader"
)
st.caption("Generate images from your audio file")
# Only render the audio player once a file has actually been uploaded.
if user_input is not None:
    st.audio(user_input)

# One column per analysis: emotion, genre and beat.
c1, c2, c3 = st.columns([1, 1, 1])
with c1:
    _run_analysis("Generate emotion", "emotion", emotion_service.predict, user_input)
with c2:
    _run_analysis("Generate genre", "genre", genre_service.predict, user_input)
with c3:
    _run_analysis("Generate beat", "beat", beat_service.get_beat, user_input)

# The description needs all three results. Identity checks are used because a
# stored result may be an array-like, for which `!= None` is not a plain bool.
results_ready = all(
    st.session_state[key] is not None for key in ("emotion", "genre", "beat")
)
if results_ready:
    text_output = None
    if st.button("Generate text description to be fed into stable diffusion"):
        st.caption("Text description of your music file")
        # The genre result is indexed, so its first element is used as the label.
        text_output = (
            f"This piece of music falls under the {st.session_state.genre[0]} genre. "
            f"It is of tempo {int(st.session_state.beat)} "
            f"and evokes a sense of {st.session_state.emotion}."
        )
        st.text(text_output)
#if text_output:
# if st.button("Generate image from text description"):
#image = image_service(text_output)
#st.image(image)