yeelou commited on
Commit
528e6f6
1 Parent(s): 4ee97d9

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +186 -0
  2. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """dup_demo.ipynb
3
+
4
+ Automatically generated by Colaboratory.
5
+
6
+ Original file is located at
7
+ https://colab.research.google.com/drive/16I3l4MYW9xUqidexTK9NLHRih2imPxoE
8
+ """
9
+
10
+ # !pip install -qqq gradio datasketch vibrato zstandard
11
+ # !pip install --no-deps sentence-transformers
12
+ # !wget https://github.com/daac-tools/vibrato/releases/download/v0.5.0/ipadic-mecab-2_7_0.tar.xz
13
+ # !tar xf ipadic-mecab-2_7_0.tar.xz
14
+
15
+ import gradio as gr
16
+ import pandas as pd
17
+ import vibrato
18
+ import zstandard
19
+ import numpy as np
20
+ import os
21
+ import wget
22
+ import tarfile
23
+
24
+ from datasketch import MinHash, MinHashLSH
25
+ from sentence_transformers import SentenceTransformer, util
26
+
27
+ if not os.path.exists("./ipadic-mecab-2_7_0.tar.xz"):
28
+ wget.download("https://github.com/daac-tools/vibrato/releases/download/v0.5.0/ipadic-mecab-2_7_0.tar.xz")
29
+ with tarfile.open('ipadic-mecab-2_7_0.tar.xz', 'r:xf') as tar:
30
+ tar.extractall('ipadic-mecab-2_7_0')
31
+
32
+
33
+ dctx = zstandard.ZstdDecompressor()
34
+ with open('ipadic-mecab-2_7_0/system.dic.zst', 'rb') as fp:
35
+ with dctx.stream_reader(fp) as dict_reader:
36
+ tokenizer = vibrato.Vibrato(dict_reader.read())
37
+
38
+ dup_model_dict = {
39
+ "cased":SentenceTransformer("sentence-transformers/distiluse-base-multilingual-cased-v2"),
40
+ "cased_train":SentenceTransformer("yeelou/news-demo-trainingdup-multilingual-cased-v2"),
41
+ }
42
+
43
+ exp_list = ["金融庁は、朝日火災海上保険(以下、朝日火災)が20年以上にわたり無資格で保険を販売していたとして、保険業法に基づく業務改善命令を発出したことが読売新聞と日本経済新聞の報道で明らかになりました。この業務改善命令は、2020年12月28日に行われました。朝日火災は、長年にわたり保険業法を遵守しない行為を続けてきたことが問題視されています。",
44
+ "読売新聞・日本経済新聞によると、金融庁は12月28日(UTC+9、以下同様)に、20年以上に亘り無資格で保険を販売していたとして、朝日火災海上保険(以下、「朝日火災」)に対し、保険業法に基づく業務改善命令を出した。また、同社の所属代理店であるヤマト運輸と沖縄ヤマト運輸に対しても、2010年1月15日から1週間の保険販売停止と業務改善命令を出した。",
45
+ "インド洋北東部のベンガル湾で発生したサイクロンが、沿岸部のインドとバングラデシュに甚大な被害をもたらしたと、ロイターと朝日新聞が報じています。この自然災害は、両国の多くの地域に深刻な影響を与え、人々の生活に大きな打撃を与えました。被害の全容はまだ明らかになっていませんが、地方によっては被害状況の把握と報告が進んでいるところもあります。一方で、被害者数に関する情報は混乱しており、一部報道機関は当初報じた数字を訂正しています。",
46
+ "ロイター、朝日新聞によると、インド洋北東部のベンガル湾でサイクロンが発生し、沿岸のインドとバングラデシュに甚大な被害を及ぼした。被害者についての情報は混乱している。地方によって被害の把握により報告が増すところがある一方、一部の報道機関は当初伝えた被害者の数を訂正している。",
47
+ "沖縄気象台は9日(UTC+9)、同日ごろに沖縄地方が梅雨明けしたとみられると発表した。平年の6月23日より14日、昨年の6月19日より10日早い。琉球新報によると、1951年の気象庁の統計開始以降、史上最も早い梅雨明けとなる。",
48
+ "沖縄気象台は、2024年6月9日(UTC+9)に沖縄地方が梅雨明けしたと見られると発表しました。これは平年の6月23日よりも14日早く、また前年の6月19日よりも10日早い梅雨明けとなります。琉球新報の報道によると、1951年に気象庁が統計を開始して以来、これが史上最も早い梅雨明けと記録されました。",
49
+ "警視庁八王子警察署は、1995年7月30日に八王子市内のスーパーで起きた拳銃殺害事件で、犯人逮捕のための有力な情報に対し、300万円の懸賞金を支払うことを明らかにした。",
50
+ "警視庁八王子警察署は、1995年7月30日に八王子市内のスーパーで発生した拳銃による殺害事件に関連して、犯人の逮捕につながる重要な情報提供者に対し、300万円の懸賞金の支払いを発表しました。",
51
+ "対話型AIサービス「チャットGPT」を開発した米オープンAIは8日、サム・アルトマン最高経営責任者が取締役に復帰すると発表した。",
52
+ "対話型AI(人工知能)サービス「チャットGPT」を開発した米オープンAIは8日、サム・アルトマン最高経営責任者(CEO)が取締役に復帰す��と発表した。"
53
+ ]
54
+
55
+ minh_threshold = 0.8
56
+ minh_num_perm = 128
57
+
58
+ mk_output = """
59
+ * 本文_{idx} : {text}
60
+ * 類似文 : {dup_text}
61
+ * スコア : {score}
62
+
63
+ ---
64
+
65
+ """
66
+
67
+ def tokenized(text):
68
+ tokenized=[]
69
+ tokens = tokenizer.tokenize(text)
70
+ tokens = [token.surface() for token in tokens]
71
+ tokenized.extend(tokens)
72
+ return tokenized
73
+
74
+ def get_minhash(text):
75
+ text_tokenized = tokenized(text)
76
+ m_tmp = MinHash(num_perm=minh_num_perm)
77
+ for d in text_tokenized:
78
+ m_tmp.update(d.encode('utf8'))
79
+ return m_tmp
80
+
81
+ def lsh_query(text,texts,lsh):
82
+ q_min = get_minhash(text)
83
+ return [(i,q_min.jaccard(get_minhash(texts[int(i)]))) for i in lsh.query(q_min)]
84
+
85
+ def minhash(texts):
86
+ lsh = MinHashLSH(threshold=minh_threshold, num_perm=minh_num_perm)
87
+ for idx, val in enumerate(texts):
88
+ lsh.insert(f"{idx}", get_minhash(val))
89
+
90
+ dup_dict = {}
91
+ dup_list = []
92
+ for idx, val in enumerate(texts):
93
+ if val not in dup_list:
94
+ res_query = lsh_query(val,texts,lsh)
95
+ res_query = [[texts[int(i[0])],i[1]] for i in res_query if int(i[0]) != idx]
96
+ if res_query == []:
97
+ dup_dict[val] = [("no duplicate news",0)]
98
+ else:
99
+ dup_dict[val] = res_query
100
+ for x in res_query:
101
+ dup_list.append(x[0])
102
+ else:
103
+ pass
104
+ return dup_dict
105
+
106
+ def run_dup_model(model, sim, texts):
107
+ paraphrases = util.paraphrase_mining(model, texts)
108
+ news_paraphrases = paraphrases[:len(texts)]
109
+ res_cased = [p for p in paraphrases if p[0]>=sim]
110
+ res_cased_dict = {}
111
+ for i in res_cased:
112
+ # print(i)
113
+ t = texts[i[1]]
114
+ v = texts[i[2]]
115
+ s = i[0]
116
+ if t not in res_cased_dict.keys():
117
+ res_cased_dict[t]=[]
118
+ if v not in res_cased_dict.keys():
119
+ res_cased_dict[v]=[]
120
+ res_cased_dict[t].append([v,s])
121
+ dup_dict = {}
122
+ for text in texts:
123
+ if text not in res_cased_dict.keys():
124
+ dup_dict[text] = [("no duplicate news",0)]
125
+ elif res_cased_dict[text] == []:
126
+ pass
127
+ else:
128
+ dup_dict[text] = res_cased_dict[text]
129
+ return dup_dict
130
+
131
+ # dup_model_dict.keys()
132
+
133
+ # res = minhash(exp_list)
134
+ # res = run_dup_model(dup_model_dict['cased'], minh_threshold, exp_list)
135
+ # mk_all = ""
136
+ # for key, vals in res.items():
137
+ # for val in vals:
138
+ # mk_tmp = mk_output.format(
139
+ # text = key,
140
+ # dup_text = val[0],
141
+ # score = val[1],
142
+ # )
143
+ # mk_all+=mk_tmp
144
+
145
+ def dup_report(choice, *args):
146
+ texts = list(args)
147
+ print("choice = ", choice)
148
+ print("texts = ", texts)
149
+ print("texts nums = ", len(texts))
150
+ texts = [i for i in texts if i != ""]
151
+ if choice == "minhash":
152
+ res = minhash(texts)
153
+ else:
154
+ res = run_dup_model(dup_model_dict[choice], minh_threshold, exp_list)
155
+
156
+ mk_all = ""
157
+ for idx, (key, vals) in enumerate(res.items()):
158
+ for val in vals:
159
+ mk_tmp = mk_output.format(
160
+ idx = idx,
161
+ text = key,
162
+ dup_text = val[0],
163
+ score = val[1],
164
+ )
165
+ mk_all+=mk_tmp
166
+ return mk_all
167
+
168
+ with gr.Blocks(title="類似文検索demo", theme="bethecloud/storj_theme") as demo:
169
+ gr.Markdown("# 類似文検索demo")
170
+ gr.Markdown("""
171
+ * 三つモデルで、例文中に類似文を検索する
172
+ * minhash: 類似度として Jaccard 係数を考えるときに高速化する手法になります
173
+ * cased: 学習済みベクトル化モデルで、Cos 類似度方法です
174
+ * cased_train: 少しデータ(60件ぐらい)でモデル再学習して、Cos 類似度方法です
175
+ """)
176
+ choice = gr.Radio(choices = ["minhash", "cased", "cased_train"], label = "検索方法", value = "minhash")
177
+ input_dict={}
178
+ for idx, exp in enumerate(exp_list):
179
+ input_dict[f"{idx}"] = gr.Textbox(label=f"NEWS_{idx}",value=exp)
180
+ gen_btn = gr.Button("検索")
181
+ gr.Markdown("OUTPUT")
182
+ output = gr.Markdown(label="report")
183
+ # gen_btn.click(fn=dup_report, inputs=[input_dict, choice], outputs=output)
184
+ gen_btn.click(fn=dup_report, inputs=[choice] + [input_dict[i] for i in input_dict.keys()], outputs=output)
185
+
186
+ demo.launch(inline=False, share=True, debug=True)
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ gradio
2
+ datasketch
3
+ vibrato
4
+ zstandard
5
+ wget
6
+ transformers
7
+ sentence-transformers