from fastapi import FastAPI
import streamlit as st
from transformers import (
    AutoTokenizer,
    AutoConfig,
    AutoModelForCausalLM,
    StoppingCriteriaList,
    MaxLengthCriteria,
)

# NOTE(review): no routes are registered on `app`; the UI below is served by
# Streamlit. The object is kept only so any external reference to `app`
# (e.g. a `uvicorn module:app` command) still resolves.
app = FastAPI()

# Directory holding the fine-tuned checkpoint (tokenizer files + model weights).
MODEL_DIR = "./TaylorSwiftFineTunedModel/"


@st.cache_resource
def _load_model_and_tokenizer():
    """Load the fine-tuned tokenizer and model once and reuse them across reruns.

    Streamlit re-executes this whole script on every widget interaction, so
    without caching the checkpoint would be read from disk on every click.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
    model = AutoModelForCausalLM.from_pretrained(MODEL_DIR)
    # set pad_token_id to eos_token_id because OPT does not have a PAD token
    model.config.pad_token_id = model.config.eos_token_id
    return tokenizer, model


def song_generator(input_prompt, *, max_length=300):
    """Continue ``input_prompt`` with the fine-tuned model using contrastive search.

    Args:
        input_prompt: Seed lyrics to continue.
        max_length: Maximum total length, in tokens and including the prompt,
            of the generated sequence. Defaults to 300.

    Returns:
        The list produced by ``tokenizer.batch_decode``: one decoded string per
        generated sequence, with special tokens stripped.
    """
    tokenizer, model = _load_model_and_tokenizer()
    inputs = tokenizer(input_prompt, return_tensors="pt")
    # `model.contrastive_search(...)` was deprecated and later made private in
    # transformers. `generate` with penalty_alpha > 0, top_k > 1 and no
    # sampling selects the same contrastive-search decoding.
    outputs = model.generate(
        **inputs,
        penalty_alpha=0.6,
        top_k=15,
        do_sample=False,
        max_length=max_length,
        pad_token_id=tokenizer.eos_token_id,
    )
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)


st.title('Taylor-swift style song generator')
st.header('Song generation Model')

query = st.text_input("Enter 2 or 3 verses ", "")
submit = st.button('Generate')
input_song = query

if submit:
    # Guard against generating from an empty / whitespace-only prompt.
    if not input_song.strip():
        st.warning('Please enter a few lines of lyrics before generating.')
    else:
        st.subheader('Song generated is ')
        with st.spinner(text='This may take a moment...'):
            output_sentence = song_generator(input_song)
        st.write(output_sentence[0])