# app.py — Hugging Face Space by Syrinx
# Last commit: 1e8f625 "Update app.py" (verified; virus scan: clean); file size 1.69 kB.
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import streamlit as st
# Load the tokenizer and model once per server process. Streamlit re-executes
# this whole script on every widget interaction, so without caching the
# weights would be re-read from disk (and re-copied to the GPU) on every click.
@st.cache_resource
def _load_model_and_tokenizer():
    """Load the fine-tuned tokenizer and model and place the model on a device.

    Returns:
        tuple: ``(tokenizer, model, device)`` where ``device`` is CUDA when a
        GPU is available and CPU otherwise; ``model`` has been moved to it.
    """
    tok = GPT2Tokenizer.from_pretrained('webtoon_tokenizer')
    lm = GPT2LMHeadModel.from_pretrained('webtoon_model')
    # Check if GPU is available
    dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    lm.to(dev)
    return tok, lm, dev


tokenizer, model, device = _load_model_and_tokenizer()
# Define the function that generates the description
def generate_description(title, *, max_length=100, num_beams=2,
                         no_repeat_ngram_size=2):
    """Generate a Webtoon description from its title with the fine-tuned model.

    Uses the module-level ``tokenizer``, ``model`` and ``device``.

    Args:
        title: Webtoon title used as the prompt (formatted with ``str``).
        max_length: Maximum total length in tokens, prompt included, passed
            to ``model.generate``. Kept small for quicker inference.
        num_beams: Beam-search width; kept small for quicker inference.
        no_repeat_ngram_size: Size of n-grams that may not repeat in the
            output.

    Returns:
        str: The decoded output with special tokens removed. For a
        decoder-only model like GPT-2, ``generate`` returns the prompt tokens
        followed by the continuation, so the text begins with the title.

    Raises:
        ValueError: If ``title`` encodes to no tokens (e.g. an empty string);
            the model needs at least one token to continue from.
    """
    # Preprocess the input
    input_text = f"{title}"
    input_ids = tokenizer.encode(input_text, return_tensors='pt').to(device)
    if input_ids.shape[-1] == 0:
        raise ValueError("title must contain at least one token")

    # A single, unpadded sequence: every position is a real token. Deriving
    # the mask from pad_token_id fails if the saved tokenizer defines no pad
    # token (stock GPT-2 does not): `tensor != None` yields a plain bool,
    # which has no `.long()`. It would also mask out genuine tokens if the
    # pad token is aliased to EOS.
    attention_mask = torch.ones_like(input_ids)
    # Fall back to EOS when there is no pad token. This mirrors what
    # generate() typically does on its own, minus the warning it logs.
    pad_token_id = (tokenizer.pad_token_id
                    if tokenizer.pad_token_id is not None
                    else tokenizer.eos_token_id)

    # Generate the output using the model
    with torch.no_grad():  # Disable gradient calculation for inference
        output = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,  # Pass attention_mask to avoid warnings
            pad_token_id=pad_token_id,
            max_length=max_length,
            num_beams=num_beams,
            early_stopping=True,
            no_repeat_ngram_size=no_repeat_ngram_size,
        )
    # Convert the output to text
    return tokenizer.decode(output[0], skip_special_tokens=True)
# Define the app
def main():
    """Render the UI: a title input, a generate button, and the result."""
    st.title('Webtoon Description Generator')
    # Get the input from the user
    title = st.text_input('Enter the title of the Webtoon:', '')
    # Generate the description
    if st.button('Generate Description'):
        # Stray leading/trailing whitespace becomes extra prompt tokens, and
        # an empty prompt gives the model nothing to continue from (generation
        # fails), so normalize first and ask for a title instead of crashing.
        cleaned_title = title.strip()
        if not cleaned_title:
            st.warning('Please enter a title first.')
            return
        with st.spinner('Generating...'):
            description = generate_description(cleaned_title)
        st.success(description)
# Script entry point: build the page only when this file is run directly
# (e.g. via `streamlit run`), not when it is imported as a module.
if __name__ == '__main__':
    main()