"""Streamlit entry script: load the audio-analysis and image-generation models.

At import time this script configures TensorFlow GPU memory growth, sets up
logging, initialises the Streamlit session-state keys it uses, and loads the
emotion, genre, beat and Stable Diffusion models, logging available memory
before and after each load.
"""

import streamlit as st
# NOTE(review): `flask` here appears to be a project-local package (the model
# files below live under "flask/"), not the Flask web framework -- confirm.
from flask.Emotion_spotting_service import _Emotion_spotting_service
from flask.Genre_spotting_service import _Genre_spotting_service
from flask.Beat_tracking_service import _Beat_tracking_service
from diffusers import StableDiffusionPipeline
import torch
import os
import logging
import psutil
import tensorflow as tf

# Let TensorFlow allocate GPU memory on demand so that memory remains for the
# PyTorch pipeline loaded below. TensorFlow requires the memory-growth setting
# to be the same on every GPU, so apply it to all detected devices rather than
# only the first one.
physical_devices = tf.config.experimental.list_physical_devices('GPU')
for _gpu in physical_devices:
    tf.config.experimental.set_memory_growth(_gpu, True)

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

_BYTES_PER_GB = 1024 ** 3


def print_memory_info():
    """Log the currently available CPU memory and the GPU memory headroom.

    The GPU figure is the total memory of CUDA device 0 minus the memory
    reserved by PyTorch's allocator in this process. When no CUDA device is
    usable, the GPU line reports that instead of raising, so this diagnostic
    helper never aborts the script on its own.
    """
    free_cpu_mem = psutil.virtual_memory().available / _BYTES_PER_GB
    logger.info("Free CPU Memory: %.2f GB", free_cpu_mem)

    if not torch.cuda.is_available():
        logger.info("Free GPU Memory: unavailable (no CUDA device)")
        return

    total_gpu_bytes = torch.cuda.get_device_properties(0).total_memory
    reserved_gpu_bytes = torch.cuda.memory_reserved(0)
    free_gpu_mem = (total_gpu_bytes - reserved_gpu_bytes) / _BYTES_PER_GB
    logger.info("Free GPU Memory: %.2f GB", free_gpu_mem)


emo_list = []
gen_list = []
tempo_list = []


@st.cache_resource
def load_emo_model():
    """Build the emotion-spotting service from its model file."""
    return _Emotion_spotting_service("flask/emotion_model.h5")


@st.cache_resource
def load_genre_model():
    """Build the genre-spotting service from its model file."""
    return _Genre_spotting_service("flask/Genre_classifier_model.h5")


@st.cache_resource
def load_beat_model():
    """Build the beat-tracking service."""
    return _Beat_tracking_service()


@st.cache_resource
def load_image_model():
    """Load the Stable Diffusion v1.5 pipeline on CUDA with the LoRA weights.

    Returns:
        The ``StableDiffusionPipeline`` moved to the ``"cuda"`` device with
        the LoRA weights from ``Weights/`` applied.
    """
    # Release cached allocator blocks before the large pipeline allocation.
    torch.cuda.empty_cache()
    # NOTE(review): variant='fp16' is requested without a torch_dtype argument;
    # confirm whether half precision was intended for the loaded pipeline.
    pipeline = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        variant='fp16',
    )
    pipeline.to("cuda")
    pipeline.load_lora_weights(
        "Weights/pytorch_lora_weights.safetensors",
        weight_name="pytorch_lora_weights.safetensors",
    )
    return pipeline


# Make sure the session-state keys exist, without overwriting existing values.
for _state_key in ('emotion', 'genre', 'beat'):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = None

# Load each model in turn, logging memory before every load and once after the
# last one.
logger.info("Measuring memory before load_emo_model")
print_memory_info()
emotion_service = load_emo_model()

logger.info("Measuring memory before load_genre_model")
print_memory_info()
genre_service = load_genre_model()

logger.info("Measuring memory before load_beat_model")
print_memory_info()
beat_service = load_beat_model()

logger.info("Measuring memory before load_image_model")
print_memory_info()
image_service = load_image_model()

logger.info("Measuring memory after load_image_model")
print_memory_info()