AI_prog3 / app.py
morizon's picture
Add application file
deb58b0
import gradio as gr
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer, AutoTokenizer, AutoModelForSeq2SeqLM, RobertaForQuestionAnswering
# 0.モデルのロード, Examplesの準備
# 評価対象の要約モデル
tokenizer_sum = AutoTokenizer.from_pretrained("tsmatz/mt5_summarize_japanese")
model_sum = AutoModelForSeq2SeqLM.from_pretrained("tsmatz/mt5_summarize_japanese")
# 質問文の生成
tokenizer_gen_q = T5Tokenizer.from_pretrained("sonoisa/t5-base-japanese-question-generation")
model_gen_q = T5ForConditionalGeneration.from_pretrained("sonoisa/t5-base-japanese-question-generation")
# 回答の生成
tokenizer_qa = AutoTokenizer.from_pretrained("tsmatz/roberta_qa_japanese")
model_qa = RobertaForQuestionAnswering.from_pretrained("tsmatz/roberta_qa_japanese")
# Example 1
eg_text_1 = """
ポケットモンスターの原点は、1996年2月27日に発売されたゲームボーイ用ソフト『ポケットモンスター 赤・緑』である。
開発元はゲームフリーク。コンセプトメーカーにしてディレクターを務めたのは、同社代表取締役でもある田尻智。
この作品が小学生を中心に、口コミから火が点き大ヒットとなり、以降も多くの続編が発売されている(詳しくは「ポケットモンスター(ゲーム)」を参照)。
ゲーム本編作品だけでなく、派生作品や関連作品が数多く発売されている(詳しくはポケットモンスターの関連ゲームを参照)。
ポケモンはゲームのみならず、アニメ化、キャラクター商品化、カードゲーム、アーケードゲームと様々なメディアミックス展開がなされ、日本国外でも人気を獲得している。
ポケモン関連ゲームソフトの累計出荷数は、全世界で2017年11月時点で3億本以上[1]、2022年3月時点で4億4000万本以上に達している[2]。
その中で、メインシリーズの累計販売本数は2016年2月時点での最新作、ニンテンドー3DS『オメガルビー・アルファサファイア』までの25作品で2億100万本となる[3]。
"""
eg_ans_1_1 = "2月27日"
eg_ans_1_2 = "ポケットモンスター 赤・緑"
# Example 2
eg_text_2 = """
アンパンマンの生みの親であるやなせたかしの作品で1968年に「バラの花とジョー」、
「チリンの鈴」の絵本や映画にいち早くアンパンマンが登場しているが、この時はまだ人間の姿。
この童話は一年間連載された。[5]アンパンマン、やなせたかしの作品としての、「アンパンマン」は、
PHP研究所が発行する青年向け雑誌『PHP』の通巻第257号に当たる、『こどものえほん』の1969年10月号[6](同年10月1日刊行)に掲載された青年向け読物、
やなせたかし(絵と文)「アンパンマン」という形が初出である[7][8][9]。
この時期、やなせが『こどものえほん』のために執筆した読物は連載12本の短編で、「アンパンマン」はその6本目の作品であった。
これら12篇は、株式会社山梨シルクセンター(※3年後、株式会社サンリオへ社名変更)より単行本『十二の真珠』名義で1970年に刊行された。
空腹に喘ぐ人の所へ駆け付けて、自らの大事な持ち物であるパンを差し出して食べるよう勧めるという、のちのアンパンマンに通じる物語の骨組みが、
この作品のおいて早くも整えられている[10][6]。
絵本・漫画・アニメなど、のちに描かれるアンパンマンとの大きな違いと言えば、第一に主人公のアンパンマンが普通の人間のおじさんであり[10][6]、
パンは所有物に過ぎなかったことである。
"""
eg_ans_2_1 = "アンパンマン"
eg_ans_2_2 = "やなせたかし"
# Example 3
eg_text_3 = """
企業の宇宙進出が進み、宇宙移民者スペーシアンと地球居住者アーシアンの対立が激化する時代。小惑星フロント「フォールクヴァング」にあるヴァナディース機関のラボでは、
カルド・ナボの主導のもと、地球のオックス・アース・コーポレーションのMSガンダムが開発されていた。しかしガンダムに採用されたGUNDフォーマットの健全性を
試作機のガンダム・ルブリスは証明できず、テストパイロットのエルノラ・サマヤは焦りを感じる。そんな中、GUNDフォーマットを危険視するデリング・レンブランの差し金で、
MS開発評議会はオックス社への企業行政法による強制執行を決定。評議会配下の特殊部隊「ドミニコス隊」がフォールクヴァングに派遣され、エルノラの夫ナディム・サマヤたちの
抵抗も空しく制圧作戦が進められる。カルドを始め多くの人命が犠牲となる中、ルブリスが偶然乗り込んでいたエルノラの娘エリクト・サマヤの生体情報を認証して起動する事態が発生し、
敵MSを撃墜したエリクトに驚愕しながらもエルノラはルブリスで脱出を図る。ナディムの自己犠牲によってエルノラたちは宙域からの離脱には成功するが、フォールクヴァングは破壊され、
ガンダムの開発計画はすべて凍結される。
"""
eg_ans_3_1 = "カルド・ナボ"
eg_ans_3_2 = "GUNDフォーマット"
# 1. イベント用の関数
def summy(text):
"""要約
Args
text: str
要約対象のテキスト
Returns
summarize_text: str
要約結果のテキスト
"""
inputs = tokenizer_sum("summarize: " + text, return_tensors="pt")
output = model_sum.generate(
inputs["input_ids"],
max_new_tokens=300, # 生成数の上限
min_length=150, # 生成数の下限
num_beams=5 # ビームサーチの設定
)
summarize_text = tokenizer_sum.decode(output[0], skip_special_tokens=True)
return summarize_text
def generate_questions(answer_1, answer_2, text):
"""質問生成
Args
answers: list[str]
質問生成のための正解単語のリスト
text: str
質問文を生成する際に参照するテキスト
Returns
generated_questions: list[str]
生成された質問文のリスト
"""
answer_context_list = [(answer, text) for answer in [answer_1, answer_2]] # 解答を質問生成する元となる文(要約結果)とセットにする。
generated_questions = []
for answer, context in answer_context_list:
# モデルに入力可能な形式に変換する
# 「answer: 」と「context: 」を使った形式に変換にする
input = tokenizer_gen_q(f"answer: {answer} context: {context}", return_tensors="pt")
# 質問文を生成する
output = model_gen_q.generate(
input['input_ids'],
max_new_tokens=100,
num_beams=4 # ビームサーチの設定
)
# 生成された問題文のトークン列を文字列に変換する。
output = tokenizer_gen_q.decode(output[0], skip_special_tokens=True)
generated_questions.append(output)
return generated_questions
def extract_answer(question, text):
"""質問応答
Args
question: str
質問文のテキスト
text: str
質問に回答するために参照するテキスト
Returns
answer: str
回答のテキスト
"""
inputs = tokenizer_qa(question, text, return_tensors="pt") # tokenizerには複数のテキストを与える
# 正解箇所の予測
outputs = model_qa(**inputs)
answer_start_scores = outputs.start_logits
answer_end_scores = outputs.end_logits
# 予測結果の開始と終了のインデックスを取得
answer_start = torch.argmax(answer_start_scores)
answer_end = torch.argmax(answer_end_scores) + 1
# tokenizerの結果から正解を抽出する
input_ids = inputs["input_ids"].tolist()[0]
answer = tokenizer_qa.decode(input_ids[answer_start:answer_end])
return answer
def extract_answer_all(gen_q_1, gen_q_2, source_text, sum_text):
"""extract_answer()をまとめて実行する
"""
a_source_1 = extract_answer(gen_q_1, source_text)
a_sum_1 = extract_answer(gen_q_1, sum_text)
a_source_2 = extract_answer(gen_q_2, source_text)
a_sum_2 = extract_answer(gen_q_2, sum_text)
return a_source_1, a_sum_1, a_source_2, a_sum_2
# 2. UIの定義
with gr.Blocks() as demo:
gr.Markdown("### 1. 要約生成")
source_text = gr.Textbox(label="要約対象")
btn_summy = gr.Button("要約生成")
sum_text = gr.Textbox(label="要約結果")
gr.Markdown("### 2. 質問生成")
with gr.Row():
with gr.Column():
answer_1 = gr.Text(label="正解1")
with gr.Column():
answer_2 = gr.Text(label="正解2")
btn_generate_questions = gr.Button("質問生成")
gr.Markdown("### 3. 回答生成")
with gr.Row():
with gr.Column():
gen_q_1 = gr.Text(label="1番目の質問")
with gr.Column():
gen_q_2 = gr.Text(label="2番目の質問")
btn_extract_answer = gr.Button("回答生成")
with gr.Row():
with gr.Column():
a_source_1 = gr.Text(label="sourceからの答え")
a_sum_1 = gr.Text(label="sumからの答え")
with gr.Column():
a_source_2 = gr.Text(label="sourceからの答え")
a_sum_2 = gr.Text(label="sumからの答え")
# 2. イベント発火
btn_summy.click(summy, inputs=[source_text], outputs=[sum_text])
btn_generate_questions.click(generate_questions, inputs=[answer_1, answer_2, sum_text], outputs=[gen_q_1, gen_q_2])
btn_extract_answer.click(extract_answer_all,
inputs=[gen_q_1, gen_q_2, source_text, sum_text],
outputs=[a_source_1, a_sum_1, a_source_2, a_sum_2]
)
# Examplesの定義
gr.Markdown("## Examples")
gr.Examples(
[[eg_text_1, eg_ans_1_1, eg_ans_1_2], [eg_text_2, eg_ans_2_1, eg_ans_2_2],[eg_text_3, eg_ans_3_1, eg_ans_3_2]],
[source_text, answer_1, answer_2],
)
demo.launch()