|
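"""Streamlit demo: generate a movie plot with GPT-2-medium.

Run with: streamlit run <path-to-this-script>
"""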
import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM

st.set_page_config(layout="wide")


class AppModel:
|
    def __init__(self):
        self.tokenizer = AutoTokenizer.from_pretrained("gpt2-medium")
        # GPT-2 ships without a pad token; reuse the EOS token so padding is defined.
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model = AutoModelForCausalLM.from_pretrained("gpt2-medium")

    def generate_plot(self, prompt: str):
        inputs = self.tokenizer(prompt, return_tensors="pt").to("cpu")
        st.write("Input tensor:", inputs)  # debug: show the tokenized prompt
        outputs = self.model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            pad_token_id=self.tokenizer.eos_token_id,  # silences the pad-token warning
            max_new_tokens=100,
            do_sample=True,
            top_k=5,  # sample only from the 5 most likely next tokens
            top_p=0.35,
            temperature=0.2,  # low temperature keeps the output conservative
            num_return_sequences=1,
        )
        st.write("Generated output:", outputs)  # debug: raw output token ids
        # skip_special_tokens drops the EOS/padding tokens from the decoded text.
        output_string = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
        st.write("Decoded output:", output_string)  # debug: decoded strings
        return output_string
|
|
|
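# Streamlit re-runs this script from top to bottom on every interaction, so an
# unguarded AppModel() would reload GPT-2-medium on each button click.
# st.cache_resource keeps a single cached instance alive across reruns
# (load_model is a small wrapper added here for that purpose).
@st.cache_resource
def load_model() -> AppModel:
    return AppModel()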
model = load_model()
|
|
|
st.title("Welcome to the GPT Olympics generator")
prompt = st.text_area("Enter the beginning of your plot...")
clicked = st.button("Generate my movie")
|
|
|
if clicked:
    st.write("Clicked!")
    generated_plot = model.generate_plot(prompt)
    st.write("Generated plot:", generated_plot)
    if generated_plot:
        st.write("Assistant:")
        st.markdown(generated_plot[0])
    else:
        st.write("No plot generated.")
|
|
|
# Inject the app's custom stylesheet (assumes style.css sits next to this script).
with open('./style.css') as f:
    css = f.read()
st.markdown(f'<style>{css}</style>', unsafe_allow_html=True)