dup_demo / app.py
yeelou's picture
Upload 2 files
5c76941 verified
# -*- coding: utf-8 -*-
"""dup_demo.ipynb
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/16I3l4MYW9xUqidexTK9NLHRih2imPxoE
"""
# !pip install -qqq gradio datasketch vibrato zstandard
# !pip install --no-deps sentence-transformers
# !wget https://github.com/daac-tools/vibrato/releases/download/v0.5.0/ipadic-mecab-2_7_0.tar.xz
# !tar xf ipadic-mecab-2_7_0.tar.xz
import os
import gradio as gr
import vibrato
import zstandard
from datasketch import MinHash, MinHashLSH
from sentence_transformers import SentenceTransformer, util
# import wget
# import tarfile
if not os.path.exists("./ipadic-mecab-2_7_0.tar.xz"):
# wget.download("https://github.com/daac-tools/vibrato/releases/download/v0.5.0/ipadic-mecab-2_7_0.tar.xz")
# with tarfile.open('ipadic-mecab-2_7_0.tar.xz', 'r:xf') as tar:
# tar.extractall('ipadic-mecab-2_7_0')
os.system(
"wget https://github.com/daac-tools/vibrato/releases/download/v0.5.0/ipadic-mecab-2_7_0.tar.xz"
)
os.system("tar xf ipadic-mecab-2_7_0.tar.xz")
dctx = zstandard.ZstdDecompressor()
with open("ipadic-mecab-2_7_0/system.dic.zst", "rb") as fp:
with dctx.stream_reader(fp) as dict_reader:
tokenizer = vibrato.Vibrato(dict_reader.read())
dup_model_dict = {
"cased": SentenceTransformer(
"sentence-transformers/distiluse-base-multilingual-cased-v2"
),
"cased_train": SentenceTransformer(
"yeelou/news-demo-trainingdup-multilingual-cased-v2"
),
}
exp_list = [
"金融庁は、朝日火災海上保険(以下、朝日火災)が20年以上にわたり無資格で保険を販売していたとして、保険業法に基づく業務改善命令を発出したことが読売新聞と日本経済新聞の報道で明らかになりました。この業務改善命令は、2020年12月28日に行われました。朝日火災は、長年にわたり保険業法を遵守しない行為を続けてきたことが問題視されています。",
"読売新聞・日本経済新聞によると、金融庁は12月28日(UTC+9、以下同様)に、20年以上に亘り無資格で保険を販売していたとして、朝日火災海上保険(以下、「朝日火災」)に対し、保険業法に基づく業務改善命令を出した。また、同社の所属代理店であるヤマト運輸と沖縄ヤマト運輸に対しても、2010年1月15日から1週間の保険販売停止と業務改善命令を出した。",
"インド洋北東部のベンガル湾で発生したサイクロンが、沿岸部のインドとバングラデシュに甚大な被害をもたらしたと、ロイターと朝日新聞が報じています。この自然災害は、両国の多くの地域に深刻な影響を与え、人々の生活に大きな打撃を与えました。被害の全容はまだ明らかになっていませんが、地方によっては被害状況の把握と報告が進んでいるところもあります。一方で、被害者数に関する情報は混乱しており、一部報道機関は当初報じた数字を訂正しています。",
"ロイター、朝日新聞によると、インド洋北東部のベンガル湾でサイクロンが発生し、沿岸のインドとバングラデシュに甚大な被害を及ぼした。被害者についての情報は混乱している。地方によって被害の把握により報告が増すところがある一方、一部の報道機関は当初伝えた被害者の数を訂正している。",
"沖縄気象台は9日(UTC+9)、同日ごろに沖縄地方が梅雨明けしたとみられると発表した。平年の6月23日より14日、昨年の6月19日より10日早い。琉球新報によると、1951年の気象庁の統計開始以降、史上最も早い梅雨明けとなる。",
"沖縄気象台は、2024年6月9日(UTC+9)に沖縄地方が梅雨明けしたと見られると発表しました。これは平年の6月23日よりも14日早く、また前年の6月19日よりも10日早い梅雨明けとなります。琉球新報の報道によると、1951年に気象庁が統計を開始して以来、これが史上最も早い梅雨明けと記録されました。",
"警視庁八王子警察署は、1995年7月30日に八王子市内のスーパーで起きた拳銃殺害事件で、犯人逮捕のための有力な情報に対し、300万円の懸賞金を支払うことを明らかにした。",
"警視庁八王子警察署は、1995年7月30日に八王子市内のスーパーで発生した拳銃による殺害事件に関連して、犯人の逮捕につながる重要な情報提供者に対し、300万円の懸賞金の支払いを発表しました。",
"対話型AIサービス「チャットGPT」を開発した米オープンAIは8日、サム・アルトマン最高経営責任者が取締役に復帰すると発表した。",
"対話型AI(人工知能)サービス「チャットGPT」を開発した米オープンAIは8日、サム・アルトマン最高経営責任者(CEO)が取締役に復帰すると発表した。",
]
minh_threshold = 0.8
minh_num_perm = 128
mk_output = """
* 本文_{idx} : {text}
* 類似文 : {dup_text}
* スコア : {score}
---
"""
def tokenized(text):
tokenized = []
tokens = tokenizer.tokenize(text)
tokens = [token.surface() for token in tokens]
tokenized.extend(tokens)
return tokenized
def get_minhash(text):
text_tokenized = tokenized(text)
m_tmp = MinHash(num_perm=minh_num_perm)
for d in text_tokenized:
m_tmp.update(d.encode("utf8"))
return m_tmp
def lsh_query(text, texts, lsh):
q_min = get_minhash(text)
return [(i, q_min.jaccard(get_minhash(texts[int(i)]))) for i in lsh.query(q_min)]
def minhash(texts):
lsh = MinHashLSH(threshold=minh_threshold, num_perm=minh_num_perm)
for idx, val in enumerate(texts):
lsh.insert(f"{idx}", get_minhash(val))
dup_dict = {}
dup_list = []
for idx, val in enumerate(texts):
if val not in dup_list:
res_query = lsh_query(val, texts, lsh)
res_query = [
[texts[int(i[0])], i[1]] for i in res_query if int(i[0]) != idx
]
if res_query == []:
dup_dict[val] = [("no duplicate news", 0)]
else:
dup_dict[val] = res_query
for x in res_query:
dup_list.append(x[0])
else:
pass
return dup_dict
def run_dup_model(model, sim, texts):
paraphrases = util.paraphrase_mining(model, texts)
news_paraphrases = paraphrases[: len(texts)]
res_cased = [p for p in news_paraphrases if p[0] >= sim]
res_cased_dict = {}
for i in res_cased:
# print(i)
t = texts[i[1]]
v = texts[i[2]]
s = i[0]
if t not in res_cased_dict.keys():
res_cased_dict[t] = []
if v not in res_cased_dict.keys():
res_cased_dict[v] = []
res_cased_dict[t].append([v, s])
dup_dict = {}
for text in texts:
if text not in res_cased_dict.keys():
dup_dict[text] = [("no duplicate news", 0)]
elif res_cased_dict[text] == []:
pass
else:
dup_dict[text] = res_cased_dict[text]
return dup_dict
# dup_model_dict.keys()
# res = minhash(exp_list)
# res = run_dup_model(dup_model_dict['cased'], minh_threshold, exp_list)
# mk_all = ""
# for key, vals in res.items():
# for val in vals:
# mk_tmp = mk_output.format(
# text = key,
# dup_text = val[0],
# score = val[1],
# )
# mk_all+=mk_tmp
def dup_report(choice, *args):
texts = list(args)
print("choice = ", choice)
print("texts = ", texts)
print("texts nums = ", len(texts))
texts = [i for i in texts if i != ""]
if choice == "minhash":
res = minhash(texts)
else:
res = run_dup_model(dup_model_dict[choice], minh_threshold, exp_list)
mk_all = ""
for idx, (key, vals) in enumerate(res.items()):
for val in vals:
mk_tmp = mk_output.format(
idx=idx,
text=key,
dup_text=val[0],
score=val[1],
)
mk_all += mk_tmp
return mk_all
with gr.Blocks(title="類似文検索POC", theme="bethecloud/storj_theme") as demo:
gr.Markdown("# 類似度比較POC")
gr.Markdown(
"""
* POC内容:重複処理
* 以下三つの方法で例文の中から類似文を比較する。
* minhash: ハッシュ計算で類似度を比較。(一般的に文字の類似度比較に使用される)
* cased: 既存モデルで、Cos方法で類似度比較。
* cased_train: 60件のトレーニングデータを用いて再学習してからCos方法で類似度比較。
"""
)
choice = gr.Radio(
choices=["minhash", "cased", "cased_train"], label="検索方法", value="minhash"
)
input_dict = {}
for idx, exp in enumerate(exp_list):
input_dict[f"{idx}"] = gr.Textbox(label=f"ニュース {idx+1}", value=exp)
gen_btn = gr.Button("比較")
gr.Markdown("アウトプット")
output = gr.Markdown(label="レポート")
# gen_btn.click(fn=dup_report, inputs=[input_dict, choice], outputs=output)
gen_btn.click(
fn=dup_report,
inputs=[choice] + [input_dict[i] for i in input_dict.keys()],
outputs=output,
)
demo.launch(inline=False, share=True, debug=True)