import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import gradio as gr

# Prepare the tokenizer and the model
tokenizer = AutoTokenizer.from_pretrained('sonoisa/t5-base-japanese')
model = AutoModelForSeq2SeqLM.from_pretrained('models/')
def summary(text):
    # Convert the text into a tensor of token IDs
    input_ids = tokenizer.encode(text, return_tensors='pt', max_length=512, truncation=True)
    # Inference
    model.eval()
    with torch.no_grad():
        summary_ids = model.generate(input_ids)
    # Strip the leading pad token and trailing </s> before decoding
    return tokenizer.decode(summary_ids[0][1:-1])
descriptions = "Text summarization with T5. Enter a passage and the model outputs a summary of it."
demo = gr.Interface(fn=summary, inputs=gr.Textbox(lines=5, placeholder="Please enter text"), outputs=gr.Textbox(lines=5), title="Sentence Summary", description=descriptions)
demo.launch()