# Syrinx's picture
# Update app.py
# e195261
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import streamlit as st
# Load the fine-tuned tokenizer and language model from their directories.
# Both loads run at import time; the tokenizer is loaded first, then the model.
tokenizer, model = (
    GPT2Tokenizer.from_pretrained('webtoon_tokenizer'),
    GPT2LMHeadModel.from_pretrained('webtoon_model'),
)
# Define the app
def main():
    """Render the Streamlit page and generate a description on request.

    Shows a title, a text box for the Webtoon title and a button. When the
    button is pressed, the entered title is passed to
    ``generate_description`` and the result is displayed in a success box.
    A blank title is rejected with a warning instead of being sent to the
    model.
    """
    st.title('Webtoon Description Generator')

    # Get the input from the user
    title = st.text_input('Enter the title of the Webtoon:', '')

    # Nothing to do until the user asks for a description
    if not st.button('Generate Description'):
        return

    # An empty or whitespace-only prompt gives the model nothing to continue
    # from, so tell the user instead of starting generation.
    if not title.strip():
        st.warning('Please enter a title before generating a description.')
        return

    # Generate the description (the title is passed through unmodified)
    with st.spinner('Generating...'):
        description = generate_description(title)
    st.success(description)
# Pick the compute device: CUDA when PyTorch reports a usable GPU, else CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Define the function that generates the description
def generate_description(title):
    """Generate a Webtoon description using *title* as the prompt.

    Args:
        title: Prompt text. It is formatted to a string before tokenizing.

    Returns:
        The first sequence returned by beam search, decoded to text with
        special tokens removed.

    Raises:
        ValueError: If *title* is empty or contains only whitespace.
    """
    # Preprocess the input
    input_text = f"{title}"
    if not input_text.strip():
        # A blank prompt tokenizes to an empty sequence, which the model
        # cannot extend; fail early with a clear message instead.
        raise ValueError('title must not be empty or whitespace-only')

    # Inputs are moved to `device` below, so the model has to be there as
    # well. Module.to() works in place and is a no-op once the weights are
    # already on the target device, so repeating it per call is harmless.
    model.to(device)

    # Calling the tokenizer (rather than .encode) also yields the attention
    # mask, so generate() does not have to infer it.
    encoded = tokenizer(input_text, return_tensors='pt').to(device)

    # Generate the output using the model
    output = model.generate(
        input_ids=encoded['input_ids'],
        attention_mask=encoded['attention_mask'],
        max_length=200,
        num_beams=4,
        early_stopping=True,
        no_repeat_ngram_size=2,
        # Name the padding token explicitly (EOS) so generate() does not
        # have to pick a fallback itself; the decode step below drops
        # special tokens from the text.
        pad_token_id=tokenizer.eos_token_id,
    )

    # Convert the output to text
    return tokenizer.decode(output[0], skip_special_tokens=True)
# Script entry point: build the page only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()