# otrotest / app.py
# Uploaded by juan9 ("Upload 10 files", commit 7a83cb9) - 1.43 kB.
from transformers import AutoTokenizer, AutoModelForCausalLM
import streamlit as st
import streamlit_antd_components as sac
# Use the full browser width for the page; configured before any element is rendered.
st.set_page_config(
    layout="wide",
)
@st.cache_resource
class AppModel:
    """Tokenizer plus fine-tuned causal LM used to continue a user's plot prompt.

    The class is wrapped by ``st.cache_resource``, so ``AppModel(...)`` returns a
    cached instance per distinct argument set instead of reloading the weights on
    every Streamlit rerun.
    """

    def __init__(
        self,
        tokenizer_name: str = "gpt2-medium",
        model_path: str = "./output/model/checkpoint-1000/",
        device: str = "cpu",
    ):
        """Load the tokenizer and model.

        Args:
            tokenizer_name: Hub id or local path of the tokenizer.
            model_path: Directory of the fine-tuned checkpoint (resolved relative
                to the current working directory when relative).
            device: Torch device the model and the encoded prompts are placed on.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        # The tokenizer has no pad token of its own here; reuse EOS so one exists.
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.device = device
        self.model = AutoModelForCausalLM.from_pretrained(model_path).to(device)

    def generate_plot(self, prompt: str, max_new_tokens: int = 100):
        """Sample a continuation of ``prompt``.

        Args:
            prompt: Beginning of the plot to continue. Should be non-empty.
            max_new_tokens: Upper bound on the number of tokens generated.

        Returns:
            A list with one decoded string per generated sequence (exactly one
            here). Each string is the prompt followed by the sampled
            continuation, with special tokens such as EOS removed.
        """
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        outputs = self.model.generate(
            input_ids=inputs.input_ids,
            # pad == eos, so generate() cannot infer the mask itself; pass it explicitly.
            attention_mask=inputs.attention_mask,
            pad_token_id=self.tokenizer.eos_token_id,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            top_k=5,
            top_p=0.35,
            temperature=0.2,
            num_return_sequences=1,
        )
        # Drop special tokens so a literal "<|endoftext|>" is not shown to the user.
        return self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
# Cached by st.cache_resource on AppModel, so reruns reuse the loaded model.
model = AppModel()

st.title("Welcome to the GPT Olympics generator")
prompt = st.text_area("Enter the beginning of your plot...")
clicked = st.button("Generate my movie")
# NOTE(review): label='label' / icon='house' look like placeholder values - confirm.
sac.divider(label='label', icon='house', align='center', size='md', color='yellow')

if clicked:
    if not prompt.strip():
        # An empty prompt gives generate() no input tokens to continue from.
        st.warning("Please enter the beginning of your plot first.")
    else:
        # Sampling 100 tokens on CPU takes a while; show progress meanwhile.
        with st.spinner("Generating..."):
            generated_plot = model.generate_plot(prompt)[0]
        chat_message = st.chat_message("assistant")
        chat_message.markdown(generated_plot)

# Inject the app stylesheet; the path is resolved relative to the working directory.
with open('./style.css', encoding='utf-8') as f:
    css = f.read()
st.markdown(f'<style>{css}</style>', unsafe_allow_html=True)