# Spaces: Runtime error
from fastapi import FastAPI | |
import streamlit as st | |
from transformers import ( | |
AutoTokenizer, | |
AutoConfig, | |
AutoModelForCausalLM, | |
StoppingCriteriaList, | |
MaxLengthCriteria, | |
) | |
# FastAPI application object. No routes are registered on it in the visible
# part of this file; the user interface further down is built with Streamlit.
app = FastAPI()
# Sample prompt for manual testing: "Heart is in love"
def song_generator(input_prompt):
    """Generate song lyrics that continue ``input_prompt``.

    Loads the tokenizer and causal language model stored in the local
    ``./TaylorSwiftFineTunedModel/`` directory and decodes with contrastive
    search (``penalty_alpha=0.6``, ``top_k=15``). The output is capped at
    300 tokens in total, prompt included.

    Args:
        input_prompt: Text used to seed the generation.

    Returns:
        The list returned by ``tokenizer.batch_decode`` with special tokens
        skipped: one decoded string per generated sequence.
    """
    model_dir = "./TaylorSwiftFineTunedModel/"

    # NOTE(review): tokenizer and model are reloaded from disk on every call.
    # If this becomes too slow, cache the loaded pair with whatever resource
    # cache the hosting framework offers.
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = AutoModelForCausalLM.from_pretrained(model_dir)

    encoded_prompt = tokenizer(input_prompt, return_tensors="pt")

    # Contrastive search is requested through the public ``generate`` entry
    # point (penalty_alpha > 0 together with top_k > 1) instead of calling the
    # internal ``contrastive_search`` method directly. The length cap is passed
    # as ``max_length`` because ``generate`` builds its own max-length stopping
    # criterion.
    outputs = model.generate(
        **encoded_prompt,
        penalty_alpha=0.6,
        top_k=15,
        max_length=300,
        # The end-of-sequence token doubles as padding; the original author
        # noted that the model has no dedicated PAD token.
        pad_token_id=tokenizer.eos_token_id,
    )

    return tokenizer.batch_decode(outputs, skip_special_tokens=True)
# Streamlit page: a header, a text box for the seed verses and a button.
# Streamlit re-runs this script on every interaction, so ``submit`` is True
# only during the run triggered by pressing the button.
st.header('Taylor-swift style song generator')
query = st.text_input("Enter 2 or 3 verses ", "")
submit = st.button('Generate')
input_song = query

if submit:
    st.subheader('Song generated is ')
    # Keep the spinner visible while the model is loaded and generates text.
    with st.spinner(text='This may take a moment...'):
        output_sentence = song_generator(input_song)
    # song_generator returns a list of decoded strings; show the first one.
    st.write(output_sentence[0])