# Hugging Face Space "otrotest" — app.py
# Author: juan9
# Last commit: 5c083c3 ("Update app.py", verified)
from pathlib import Path

import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM
# Use the full browser width for the app layout.
# NOTE: keep this ahead of every other st.* call — Streamlit's docs require
# set_page_config to be the first Streamlit command in the script.
st.set_page_config(layout="wide")
class AppModel:
    """Wrap a causal language model and its tokenizer for plot generation.

    Both are loaded with ``from_pretrained`` when the instance is created,
    so construction is expensive and should happen as rarely as possible.
    """

    def __init__(self, model_name: str = "gpt2-medium"):
        """Load the tokenizer and model identified by *model_name*.

        Args:
            model_name: Hugging Face model identifier; defaults to the
                GPT-2 medium checkpoint this app was written for.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # GPT-2 has no pad token; reuse EOS so padding-aware calls work.
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model = AutoModelForCausalLM.from_pretrained(model_name)

    def generate_plot(
        self, prompt: str, *, max_new_tokens: int = 100
    ) -> list[str]:
        """Continue *prompt* with sampled text and return the decoded result.

        The tokenized input, raw generated ids and decoded text are echoed to
        the Streamlit page for debugging.

        Args:
            prompt: Beginning of the plot to continue.
            max_new_tokens: Upper bound on the number of generated tokens.

        Returns:
            A list holding one decoded string (the prompt followed by its
            continuation), or an empty list when *prompt* is empty or
            whitespace-only.
        """
        # An empty prompt tokenizes to zero tokens, which the model cannot
        # start generating from; report "nothing generated" instead of crashing.
        if not (prompt and prompt.strip()):
            return []

        # Keep the inputs on whatever device the model lives on (CPU by default).
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        st.write("Input tensor:", inputs)
        outputs = self.model.generate(
            inputs.input_ids,
            # pad == eos for this tokenizer, so pass the mask and pad id
            # explicitly instead of letting transformers guess (it warns
            # about unreliable results otherwise).
            attention_mask=inputs.attention_mask,
            pad_token_id=self.tokenizer.pad_token_id,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            top_k=5,
            top_p=0.35,
            temperature=0.2,
            num_return_sequences=1,
        )
        st.write("Generated output:", outputs)
        # Drop special tokens such as the EOS marker that ends generation,
        # so they do not show up in the rendered plot.
        output_string = self.tokenizer.batch_decode(
            outputs, skip_special_tokens=True
        )
        st.write("Decoded output:", output_string)
        return output_string
@st.cache_resource
def _load_model() -> AppModel:
    """Return the shared AppModel, constructing it only on first use.

    Streamlit re-executes this whole script on every widget interaction;
    st.cache_resource keeps a single model instance across those reruns
    instead of reloading the weights each time.
    """
    return AppModel()


# Inject the custom stylesheet before anything slow runs, so styling is in
# place while the model loads / generates. Resolve it next to this script so
# the app does not depend on the working directory it is launched from.
css = Path(__file__).with_name("style.css").read_text(encoding="utf-8")
st.markdown(f'<style>{css}</style>', unsafe_allow_html=True)

model = _load_model()
st.title("Welcome to the GPT Olympics generator")
prompt = st.text_area("Enter the beginning of your plot...")
clicked = st.button("Generate my movie")
if clicked:
    st.write("Clicked!")
    # Skip generation for an empty prompt: there is nothing to continue and
    # the model cannot start from zero tokens.
    generated_plot = model.generate_plot(prompt) if prompt.strip() else []
    st.write("Generated plot:", generated_plot)
    if generated_plot:
        st.write("Assistant:")
        st.markdown(generated_plot[0])
    else:
        st.write("No plot generated.")