# Hugging Face Spaces status when this file was captured: Build error
import streamlit as st | |
from flask.Emotion_spotting_service import _Emotion_spotting_service | |
from flask.Genre_spotting_service import _Genre_spotting_service | |
from flask.Beat_tracking_service import _Beat_tracking_service | |
from diffusers import StableDiffusionPipeline | |
import torch | |
import os | |
import logging | |
import psutil | |
import tensorflow as tf | |
# By default TensorFlow maps nearly all memory of every visible GPU up front;
# enabling memory growth makes it allocate on demand so PyTorch (Stable
# Diffusion, loaded below) can share the device. TensorFlow requires the
# growth setting to be identical on all GPUs and to be applied before any GPU
# is initialised, hence the loop over every device and the RuntimeError guard.
physical_devices = tf.config.list_physical_devices('GPU')
try:
    for _gpu in physical_devices:
        tf.config.experimental.set_memory_growth(_gpu, True)
except RuntimeError as _err:
    # Raised when a GPU was already initialised (possibly by one of the
    # service imports above); the setting can no longer change, so continue.
    logging.getLogger(__name__).warning(
        "Could not enable TensorFlow GPU memory growth: %s", _err
    )
# Root logging configuration: records at INFO and above are emitted, which is
# the level used by the memory diagnostics in this module.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
def print_memory_info():
    """Log available CPU RAM and, when CUDA is present, free GPU memory.

    Both figures are reported in GB (bytes / 1024**3) at INFO level through
    the module ``logger``; nothing is returned.
    """
    gib = 1024 ** 3

    # psutil's ``available`` is memory that can be given to processes without
    # swapping, i.e. the budget left for the next model load.
    free_cpu_mem = psutil.virtual_memory().available / gib
    logger.info("Free CPU Memory: %.2f GB", free_cpu_mem)

    # Querying CUDA device 0 raises on hosts without a GPU, so skip the GPU
    # figure there instead of crashing the whole app.
    if not torch.cuda.is_available():
        logger.info("Free GPU Memory: n/a (CUDA not available)")
        return

    # mem_get_info wraps cudaMemGetInfo and reports device-wide free memory,
    # so it also reflects TensorFlow and other processes. By contrast,
    # ``total - memory_reserved`` only subtracts PyTorch's own cache.
    free_gpu_bytes, _total_gpu_bytes = torch.cuda.mem_get_info(0)
    logger.info("Free GPU Memory: %.2f GB", free_gpu_bytes / gib)
# Three independent, initially empty module-level lists (emotion, genre,
# tempo); they are not referenced in the visible part of this file —
# presumably filled by later UI code, verify against the rest of the script.
emo_list, gen_list, tempo_list = [], [], []
@st.cache_resource
def load_emo_model(model_path="flask/emotion_model.h5"):
    """Return an emotion-spotting service for the model file at ``model_path``.

    Streamlit re-executes this whole script on every user interaction, so
    the result is cached with ``st.cache_resource``. The model is then built
    once per process (per distinct ``model_path``) rather than on every rerun.

    Args:
        model_path: Model file passed to ``_Emotion_spotting_service``;
            defaults to the bundled ``flask/emotion_model.h5``.

    Returns:
        The ``_Emotion_spotting_service`` instance.
    """
    return _Emotion_spotting_service(model_path)
@st.cache_resource
def load_genre_model(model_path="flask/Genre_classifier_model.h5"):
    """Return a genre-spotting service for the model file at ``model_path``.

    Cached with ``st.cache_resource`` so the model is built once per process
    (per distinct ``model_path``) instead of on every Streamlit rerun.

    Args:
        model_path: Model file passed to ``_Genre_spotting_service``;
            defaults to the bundled ``flask/Genre_classifier_model.h5``.

    Returns:
        The ``_Genre_spotting_service`` instance.
    """
    return _Genre_spotting_service(model_path)
@st.cache_resource
def load_beat_model():
    """Return the beat-tracking service, constructed once per process.

    ``st.cache_resource`` keeps the instance across Streamlit reruns so it is
    not re-created on every user interaction.
    """
    return _Beat_tracking_service()
@st.cache_resource
def load_image_model():
    """Build the Stable Diffusion v1.5 pipeline and apply the project's LoRA weights.

    Cached with ``st.cache_resource`` so the multi-GB pipeline is created once
    per process rather than on every Streamlit rerun.

    Returns:
        The ``StableDiffusionPipeline``: on CUDA in float16 when a GPU is
        available, otherwise on CPU in float32.
    """
    # Release unoccupied blocks held by PyTorch's CUDA caching allocator
    # before the large allocation below.
    torch.cuda.empty_cache()
    use_cuda = torch.cuda.is_available()
    # variant="fp16" only selects which checkpoint files are fetched. The dtype
    # the weights are materialised in comes from torch_dtype, which otherwise
    # defaults to float32 and doubles the memory footprint. Half precision is
    # kept GPU-only because fp16 inference is poorly supported on CPU.
    pipeline = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        variant='fp16',
        torch_dtype=torch.float16 if use_cuda else torch.float32,
    )
    # Fall back to CPU so GPU-less hosts can still construct the pipeline.
    pipeline.to("cuda" if use_cuda else "cpu")
    pipeline.load_lora_weights(
        "Weights/pytorch_lora_weights.safetensors",
        weight_name="pytorch_lora_weights.safetensors",
    )
    return pipeline
# Seed the per-session result slots with None on first run. Keys that
# already exist (from an earlier rerun in the same session) are left as-is.
for _state_key in ('emotion', 'genre', 'beat'):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = None
def _log_memory(phase):
    """Announce which loading phase is being measured, then log memory stats."""
    logger.info(f"Measuring memory {phase}")
    print_memory_info()


# Load every model in turn, snapshotting memory before each load and once
# after the last one, so each model's footprint can be read off the log.
_log_memory("before load_emo_model")
emotion_service = load_emo_model()
_log_memory("before load_genre_model")
genre_service = load_genre_model()
_log_memory("before load_beat_model")
beat_service = load_beat_model()
_log_memory("before load_image_model")
image_service = load_image_model()
_log_memory("after load_image_model")