# Streamlit app: Webtoon Description Generator (GPT-2 based)
import torch
import streamlit as st
from transformers import GPT2Tokenizer, GPT2LMHeadModel

# Load the fine-tuned tokenizer and model once at import time.
# Paths are relative to the working directory — presumably the Space's
# repo root contains these folders; verify at deploy time.
tokenizer = GPT2Tokenizer.from_pretrained('webtoon_tokenizer')
model = GPT2LMHeadModel.from_pretrained('webtoon_model')
# Define the Streamlit app
def main():
    """Render the UI: read a webtoon title and display a generated description."""
    st.title('Webtoon Description Generator')

    # Get the input from the user
    title = st.text_input('Enter the title of the Webtoon:', '')

    # Generate the description only when the button is pressed
    if st.button('Generate Description'):
        with st.spinner('Generating...'):
            description = generate_description(title)
            st.success(description)
# Select the compute device once at import time: prefer the GPU when available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Define the function that generates the description
def generate_description(title):
    """Generate a webtoon description for ``title`` with the fine-tuned GPT-2 model.

    Args:
        title: Webtoon title used as the generation prompt.

    Returns:
        The decoded text (GPT-2 echoes the prompt back at the start of the output).
    """
    # Bug fix: the model was never moved to `device`, so on a CUDA machine the
    # input ids ended up on the GPU while the weights stayed on the CPU and
    # `generate` raised a device-mismatch error. `.to()` is a no-op when the
    # model already lives on the right device.
    model.to(device)

    # Encode the prompt and place it on the same device as the model
    input_ids = tokenizer.encode(title, return_tensors='pt').to(device)

    # Beam search; no_repeat_ngram_size curbs verbatim phrase repetition
    output = model.generate(
        input_ids=input_ids,
        max_length=200,
        num_beams=4,
        early_stopping=True,
        no_repeat_ngram_size=2,
        # GPT-2 has no pad token; use EOS to silence the generate() warning
        pad_token_id=tokenizer.eos_token_id,
    )

    # Convert the generated token ids back to text
    description = tokenizer.decode(output[0], skip_special_tokens=True)
    return description
# Script entry point
if __name__ == '__main__':
    main()