# GENsONG_PRC / app.py
# Commit 2391729 (verified): "Update app.py" by bdjwhdwjb
# (Hugging Face file view: raw · history · blame · 4.2 kB)
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import urllib.parse
# Set Streamlit page configuration: wide layout, sidebar open on first load,
# and the browser-tab title/icon. This is the first Streamlit call in the script.
st.set_page_config(
    layout="wide",
    initial_sidebar_state="expanded",
    page_title="Text-to-Music Generator 🎡",
    page_icon="🎡",
)
# Add custom CSS for styling (injected as a raw <style> block).
# NOTE(review): Streamlit appears to paint its own background on the `.stApp`
# container, which hides any background set on <body>. The page-level rule
# therefore targets both selectors so the intended colours/font take effect.
# Confirm against the Streamlit version in use.
st.markdown("""
<style>
body, .stApp {
    background-color: #eaf6fb;
    color: #003366;
    font-family: 'Arial', sans-serif;
}
.stButton>button {
    background-color: #4dabf5;
    color: white;
    font-weight: bold;
    border-radius: 12px;
    padding: 10px 20px;
}
.stButton>button:hover {
    background-color: #007bb5;
}
.stTextArea textarea {
    border: 2px solid #4dabf5;
    border-radius: 8px;
}
iframe {
    border: 2px solid #4dabf5;
    border-radius: 8px;
}
.title {
    text-align: center;
    font-size: 36px;
    font-weight: bold;
    color: #003366;
}
.description {
    text-align: center;
    font-size: 18px;
    color: #005792;
}
</style>
""", unsafe_allow_html=True)
# Hugging Face checkpoint used for both the tokenizer and the model.
MODEL_NAME = 'sander-wood/text-to-music'


# Initialize the Hugging Face model and tokenizer
@st.cache_resource
def load_model(model_name: str = MODEL_NAME):
    """Load and return ``(tokenizer, model)`` for the given checkpoint.

    Wrapped in ``st.cache_resource`` so Streamlit can reuse the loaded objects
    across script reruns instead of reloading them on every interaction.

    Args:
        model_name: Hugging Face model id or local path. Defaults to
            ``MODEL_NAME``, the checkpoint this app was built around.

    Returns:
        A ``(tokenizer, model)`` tuple from ``AutoTokenizer`` and
        ``AutoModelForSeq2SeqLM``.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    return tokenizer, model


# Load model and tokenizer
tokenizer, model = load_model()
# Streamlit App UI
st.markdown("<div class='title'>🎡 Text-to-Music Generator 🎡</div>", unsafe_allow_html=True)
# The description sits inside a raw <div> HTML block, and Markdown inline
# syntax is not parsed inside HTML blocks. A "[text](url)" link here would be
# shown literally, so the link is written as an HTML anchor instead.
st.markdown("""
<div class='description'>
Enter a textual description, and the model will generate music in ABC notation.
You can use tools like <a href="http://abc.sourceforge.net/abcMIDI/" target="_blank" rel="noopener noreferrer">abc2midi</a> to convert the notation into a playable file.
</div>
""", unsafe_allow_html=True)
# Input Fields: the free-text prompt plus the three sampling controls
# (output length cap, top-p, temperature) used when "Generate Music" is pressed.
with st.container():
    text_input = st.text_area(
        label="Enter a description for the music:",
        placeholder="e.g., This is a traditional Irish dance music.",
    )
    max_length = st.slider(
        label="Maximum Length of Generated Music:",
        min_value=128,
        max_value=2048,
        value=1024,
        step=128,
    )
    top_p = st.slider(
        label="Top-p (Nucleus Sampling):",
        min_value=0.1,
        max_value=1.0,
        value=0.9,
        step=0.05,
    )
    temperature = st.slider(
        label="Temperature (Sampling Diversity):",
        min_value=0.1,
        max_value=2.0,
        value=1.0,
        step=0.1,
    )
# Generate Music Button
if st.button("Generate Music 🎢"):
    if not text_input.strip():
        st.error("Please enter a valid description!")
    else:
        # Models with learned position embeddings cannot encode or decode past
        # `config.max_position_embeddings` tokens. Requesting more (the slider
        # goes up to 2048) fails with an index error partway through
        # generation, so clamp when the config exposes the limit.
        # NOTE(review): assumes a BART-style config for this checkpoint — confirm.
        model_limit = getattr(model.config, "max_position_embeddings", None)
        gen_length = max_length if model_limit is None else min(max_length, model_limit)
        if gen_length < max_length:
            st.warning(f"Maximum length reduced to {gen_length}, the model's positional limit.")
        try:
            # st.spinner disappears when the block exits, whereas st.info would
            # stay on screen next to the success/error message.
            with st.spinner("Generating music... This might take a few seconds."):
                # Tokenize input, keeping the attention mask to pass to generate().
                encoded = tokenizer(text_input, return_tensors='pt', truncation=True, max_length=gen_length)
                # Generate music with nucleus (top-p) sampling at the chosen temperature.
                generated_ids = model.generate(
                    encoded['input_ids'],
                    attention_mask=encoded.get('attention_mask'),
                    max_length=gen_length,
                    do_sample=True,
                    top_p=top_p,
                    temperature=temperature,
                    eos_token_id=tokenizer.eos_token_id,
                )
            # Decode generated music and prepend the ABC "X:" reference-number header.
            tune = "X:1\n" + tokenizer.decode(generated_ids[0], skip_special_tokens=True)
            st.success("Music generated successfully!")
            # Display raw generated music in the app
            st.text_area("Generated Music (ABC Notation):", value=tune, height=300)
            # Percent-encode the tune so it forms a single query value. This also
            # escapes quotes/angle brackets before the URL goes into the HTML attribute.
            encoded_tune = urllib.parse.quote(tune)
            editor_url = f"https://www.abcjs.net/abcjs-editor?abc={encoded_tune}"
            st.markdown(f"""
### ABCJS Editor Preview
You can edit or play the music below:
<iframe src="{editor_url}" width="100%" height="500" style="border:none;"></iframe>
""", unsafe_allow_html=True)
        except Exception as e:
            # Top-level UI boundary: report any tokenizer/model failure to the user.
            st.error(f"An error occurred: {e}")