|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
import streamlit as st |
|
import streamlit_antd_components as sac |
|
|
|
# Render the app across the full browser width rather than Streamlit's
# default centered column.
_PAGE_LAYOUT = "wide"
st.set_page_config(layout=_PAGE_LAYOUT)
|
|
|
|
|
@st.cache_resource
class AppModel:
    """Tokenizer + fine-tuned causal language model used to continue plot text.

    Because ``st.cache_resource`` decorates the class itself, calling
    ``AppModel(...)`` goes through Streamlit's resource cache: the model is
    loaded once and the same instance is reused across script reruns for the
    same constructor arguments.
    """

    def __init__(
        self,
        tokenizer_name: str = "gpt2-medium",
        model_path: str = "./output/model/checkpoint-1000/",
    ):
        """Load the tokenizer and the model checkpoint.

        Args:
            tokenizer_name: Hub id or local path of the tokenizer to load.
            model_path: Path (relative to the current working directory) or
                hub id of the model checkpoint to load.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        # GPT-2 tokenizers have no pad token; reuse EOS so padding is defined.
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model = AutoModelForCausalLM.from_pretrained(model_path)

    def generate_plot(
        self,
        prompt: str,
        *,
        max_new_tokens: int = 100,
        top_k: int = 5,
        top_p: float = 0.35,
        temperature: float = 0.2,
        num_return_sequences: int = 1,
    ) -> list[str]:
        """Sample continuations of ``prompt``.

        Args:
            prompt: Beginning of the plot to continue; must be non-empty.
            max_new_tokens: Maximum number of tokens to generate per sequence.
            top_k: Sample only from the ``top_k`` most likely tokens.
            top_p: Nucleus-sampling probability mass.
            temperature: Softmax temperature for sampling.
            num_return_sequences: Number of continuations to return.

        Returns:
            A list of ``num_return_sequences`` decoded strings (prompt plus
            continuation) with special tokens such as ``<|endoftext|>`` removed.

        Raises:
            ValueError: If ``prompt`` is empty; there is nothing to continue.
        """
        if not prompt:
            raise ValueError("prompt must be a non-empty string")

        # Follow the model's device instead of assuming CPU.
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        outputs = self.model.generate(
            # Pass input_ids *and* attention_mask: pad == eos, so the mask
            # cannot be inferred reliably from the ids alone.
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            top_k=top_k,
            top_p=top_p,
            temperature=temperature,
            num_return_sequences=num_return_sequences,
            # Explicit pad id avoids the "Setting pad_token_id" fallback warning.
            pad_token_id=self.tokenizer.pad_token_id,
        )
        # Strip special tokens so they are not rendered in the UI.
        return self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
|
|
# Build the model wrapper; AppModel is wrapped by st.cache_resource, so this
# returns the cached instance on reruns instead of reloading the weights.
model = AppModel()

# --- Input widgets ---------------------------------------------------------
st.title("Welcome to the GPT Olympics generator")
prompt = st.text_area("Enter the beginning of your plot...")
# st.button returns True only on the rerun triggered by pressing it.
clicked = st.button("Generate my movie")

# NOTE(review): label="label" and icon="house" look like placeholder values
# copied from the component's example — confirm the intended divider text.
sac.divider(
    label="label",
    icon="house",
    align="center",
    size="md",
    color="yellow",
)
|
|
|
# Run generation only on the rerun triggered by the button press.
if clicked:
    if not prompt.strip():
        # An empty prompt gives the model nothing to continue from; tell the
        # user instead of surfacing a generation error.
        st.warning("Please enter the beginning of your plot first.")
    else:
        # Sampling can take a while; show progress feedback meanwhile.
        with st.spinner("Generating..."):
            generated_plot = model.generate_plot(prompt)[0]
        chat_message = st.chat_message("assistant")
        chat_message.markdown(generated_plot)
|
|
|
# Inject the app's custom stylesheet. The path is resolved against the
# current working directory. The CSS comes from a local file, which is why
# raw HTML is allowed here.
with open("./style.css", encoding="utf-8") as css_file:
    css = css_file.read()
# Render after the file is closed; the handle is no longer needed.
st.markdown(f"<style>{css}</style>", unsafe_allow_html=True)
|
|
|
|
|
|
|
|
|
|