# song_generator / app.py
# Author: Pushpahasa
# Last commit: "Update app.py" (791e003)
from fastapi import FastAPI
import streamlit as st
from transformers import (
AutoTokenizer,
AutoConfig,
AutoModelForCausalLM,
StoppingCriteriaList,
MaxLengthCriteria,
)
# FastAPI application object. No routes are registered on it in the visible code.
# NOTE(review): presumably kept so an external `uvicorn app:app` entry point still
# resolves; confirm it is used before removing it.
app = FastAPI()
# Example prompt for manual testing: "Heart is in love"
MODEL_DIR = "./TaylorSwiftFineTunedModel/"


def _load_model(model_dir):
    """Load the fine-tuned tokenizer and causal language model from *model_dir*.

    Returns a ``(tokenizer, model)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = AutoModelForCausalLM.from_pretrained(model_dir)
    # Set pad_token_id to eos_token_id because OPT does not have a PAD token.
    model.config.pad_token_id = model.config.eos_token_id
    return tokenizer, model


# Streamlit re-executes this whole script on every widget interaction, so a plain
# module-level cache would be rebuilt each time and the model re-read from disk on
# every "Generate" click. st.cache_resource (Streamlit >= 1.18) keeps one loaded
# copy per server process; on older Streamlit versions we fall back to loading on
# each call, as before.
if hasattr(st, "cache_resource"):
    _load_model = st.cache_resource(_load_model)


def song_generator(
    input_prompt,
    *,
    max_length=300,
    penalty_alpha=0.6,
    top_k=15,
    model_dir=MODEL_DIR,
):
    """Continue *input_prompt* with the fine-tuned model using contrastive search.

    Args:
        input_prompt: Seed text (e.g. a couple of verses) for the model to continue.
        max_length: Maximum total sequence length in tokens, prompt included.
        penalty_alpha: Contrastive-search degeneration penalty.
        top_k: Number of candidate tokens considered at each decoding step.
        model_dir: Directory holding the fine-tuned tokenizer and model.

    Returns:
        A list of decoded strings, one per sequence in the batch (a single element
        for a single prompt). Special tokens are stripped.
    """
    tokenizer, model = _load_model(model_dir)
    # The tokenizer returns a BatchEncoding (input_ids + attention_mask), not bare ids.
    inputs = tokenizer(input_prompt, return_tensors="pt")
    # Calling model.contrastive_search() directly was deprecated and later removed
    # from the transformers public API. generate() performs contrastive search when
    # penalty_alpha > 0 and top_k > 1 with sampling disabled. do_sample=False is
    # explicit so a checkpoint's generation_config cannot switch to sampling.
    outputs = model.generate(
        **inputs,
        penalty_alpha=penalty_alpha,
        top_k=top_k,
        do_sample=False,
        max_length=max_length,
        pad_token_id=tokenizer.eos_token_id,
    )
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)
# Streamlit page: read a few seed verses and show the model's continuation.
st.header('Taylor-swift style song generator')
query = st.text_input("Enter 2 or 3 verses ", "")
submit = st.button('Generate')
input_song = query

if submit:
    # Guard against an empty prompt: depending on the tokenizer, "" can yield zero
    # input tokens, which the model cannot continue from.
    if not input_song.strip():
        st.warning("Please enter a few lines of lyrics before generating.")
    else:
        st.subheader('Song generated is ')
        with st.spinner(text='This may take a moment...'):
            output_sentence = song_generator(input_song)
        # st.write renders strings as Markdown, which collapses single newlines;
        # st.text keeps the lyric line breaks intact.
        st.text(output_sentence[0])