DrDominikDellermann and nickmuchi committed commit 5fb0891 • 0 parent(s)

Duplicate from nickmuchi/Earnings-Call-Analysis-Whisperer

Co-authored-by: Nicholas Muchinguri <nickmuchi@users.noreply.huggingface.co>

.gitattributes ADDED
@@ -0,0 +1,31 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
01_🏠_Home.py ADDED
@@ -0,0 +1,72 @@
+ import whisper
+ import os
+ import pandas as pd
+ import plotly_express as px
+ import nltk
+ import plotly.graph_objects as go
+ from optimum.onnxruntime import ORTModelForSequenceClassification
+ from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification, AutoModelForTokenClassification
+ from sentence_transformers import SentenceTransformer, CrossEncoder, util
+ import streamlit as st
+ import en_core_web_lg
+
+ nltk.download('punkt')
+
+ from nltk import sent_tokenize
+
+ auth_token = os.environ.get("auth_token")
+
+ st.sidebar.header("Home")
+
+ asr_model_options = ['tiny.en', 'base.en', 'small.en']
+
+ asr_model_name = st.sidebar.selectbox("Whisper Model Options", options=asr_model_options, key="sbox")
+
+ st.markdown("## Earnings Call Analysis Whisperer")
+
+ twitter_link = """
+ [![](https://img.shields.io/twitter/follow/nickmuchi?label=@nickmuchi&style=social)](https://twitter.com/nickmuchi)
+ """
+
+ st.markdown(twitter_link)
+
+ st.markdown(
+     """
+     This app assists finance analysts with transcribing and analysing Earnings Calls by carrying out the following tasks:
+     - Transcribing earnings calls using OpenAI's Whisper API; a 1hr call under 25MB takes roughly 3 minutes to transcribe.
+     - Analysing the sentiment of the transcribed text using the quantized version of [FinBert-Tone](https://huggingface.co/nickmuchi/quantized-optimum-finbert-tone).
+     - Summarising the call with the [philschmid/flan-t5-base-samsum](https://huggingface.co/philschmid/flan-t5-base-samsum) model, with entity extraction.
+     - Question-Answering search engine powered by LangChain and [Sentence Transformers](https://huggingface.co/sentence-transformers/all-mpnet-base-v2).
+     - Knowledge-graph generation using the [Babelscape/rebel-large](https://huggingface.co/Babelscape/rebel-large) model.
+
+     **👇 Enter a YouTube Earnings Call URL below and navigate to the sidebar tabs**
+
+     """
+ )
+
+ if 'sbox' not in st.session_state:
+     st.session_state.sbox = asr_model_name
+
+ if "earnings_passages" not in st.session_state:
+     st.session_state["earnings_passages"] = ''
+
+ if "sen_df" not in st.session_state:
+     st.session_state['sen_df'] = ''
+
+ url_input = st.text_input(
+     label="Enter YouTube URL, example below is McDonalds Earnings Call Q1 2023",
+     value="https://www.youtube.com/watch?v=4p6o5kkZYyA")
+
+ if 'url' not in st.session_state:
+     st.session_state['url'] = ""
+
+ st.session_state['url'] = url_input
+
+ st.markdown(
+     "<h3 style='text-align: center; color: red;'>OR</h3>",
+     unsafe_allow_html=True
+ )
+
+ upload_wav = st.file_uploader("Upload a .wav/.mp3/.mp4 audio file", key="upload", type=['.wav', '.mp3', '.mp4'])
+
+ st.markdown("![visitors](https://visitor-badge.glitch.me/badge?page_id=nickmuchi.earnings-call-whisperer)")
README.md ADDED
@@ -0,0 +1,13 @@
+ ---
+ title: Earnings Call Analysis Whisperer
+ emoji: 📞
+ colorFrom: blue
+ colorTo: gray
+ sdk: streamlit
+ sdk_version: 1.19.0
+ app_file: 01_🏠_Home.py
+ pinned: false
+ duplicated_from: nickmuchi/Earnings-Call-Analysis-Whisperer
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
download.wav ADDED
Binary file (36 kB)
 
functions.py ADDED
@@ -0,0 +1,952 @@
+ import whisper
+ import os
+ import random
+ import openai
+ import yt_dlp
+ from pytube import YouTube, extract
+ import pandas as pd
+ import plotly_express as px
+ import nltk
+ import plotly.graph_objects as go
+ from optimum.onnxruntime import ORTModelForSequenceClassification
+ from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification, AutoModelForTokenClassification, AutoModelForSeq2SeqLM
+ from sentence_transformers import SentenceTransformer, CrossEncoder, util
+ import streamlit as st
+ import en_core_web_lg
+ import validators
+ import re
+ import itertools
+ import numpy as np
+ from bs4 import BeautifulSoup
+ import base64, time
+ from annotated_text import annotated_text
+ import pickle, math
+ import wikipedia
+ from pyvis.network import Network
+ import torch
+ from pydub import AudioSegment
+ from newspaper import Article, ArticleException  # used by the knowledge-graph helpers below; missing from the original imports
+ from GoogleNews import GoogleNews  # used by get_news_links below; missing from the original imports
+ from langchain.docstore.document import Document
+ from langchain.embeddings import HuggingFaceEmbeddings, HuggingFaceInstructEmbeddings
+ from langchain.vectorstores import FAISS
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain.chat_models import ChatOpenAI
+ from langchain.callbacks import StdOutCallbackHandler
+ from langchain.chains import ConversationalRetrievalChain, QAGenerationChain, LLMChain
+ from langchain.memory import ConversationBufferMemory
+ from langchain.chains.question_answering import load_qa_chain
+ from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT
+
+ from langchain.prompts.chat import (
+     ChatPromptTemplate,
+     SystemMessagePromptTemplate,
+     AIMessagePromptTemplate,
+     HumanMessagePromptTemplate,
+ )
+ from langchain.schema import (
+     AIMessage,
+     HumanMessage,
+     SystemMessage
+ )
+
+ from langchain.prompts import PromptTemplate
+
+ nltk.download('punkt')
+
+
+ from nltk import sent_tokenize
+
+ OPEN_AI_KEY = os.environ.get('OPEN_AI_KEY')
+ time_str = time.strftime("%d%m%Y-%H%M%S")
+ HTML_WRAPPER = """<div style="overflow-x: auto; border: 1px solid #e6e9ef; border-radius: 0.25rem; padding: 1rem;
+ margin-bottom: 2.5rem">{}</div> """
+
+ memory = ConversationBufferMemory(memory_key='chat_history', return_messages=True, output_key='answer')
+
+
+ # Stuff chain-type prompt template
+
+ @st.cache_data
+ def load_prompt():
+
+     system_template = """Use only the following pieces of earnings context to answer the user's question accurately.
+     Do not use any information not provided in the earnings context and remember you are to speak like a finance expert.
+     If you don't know the answer, just say 'There is no relevant answer in the given earnings call transcript',
+     don't try to make up an answer.
+
+     ALWAYS return a "SOURCES" part in your answer.
+     The "SOURCES" part should be a reference to the source of the document from which you got your answer.
+
+     Remember, do not reference any information not given in the context.
+
+     If the answer is not available in the given context just say 'There is no relevant answer in the given earnings call transcript'.
+
+     Follow the below format when answering:
+
+     Question: {question}
+     SOURCES: [xyz]
+
+     Begin!
+     ----------------
+     {context}"""
+
+     messages = [
+         SystemMessagePromptTemplate.from_template(system_template),
+         HumanMessagePromptTemplate.from_template("{question}")
+     ]
+     prompt = ChatPromptTemplate.from_messages(messages)
+
+     return prompt
+
+ ###################### Functions #######################################################################################
+
+ # @st.cache_data
+ # def get_yt_audio(url):
+ #     temp_audio_file = os.path.join('output', 'audio')
+
+ #     ydl_opts = {
+ #         'format': 'bestaudio/best',
+ #         'postprocessors': [{
+ #             'key': 'FFmpegExtractAudio',
+ #             'preferredcodec': 'mp3',
+ #             'preferredquality': '192',
+ #         }],
+ #         'outtmpl': temp_audio_file,
+ #         'quiet': True,
+ #     }
+
+ #     with yt_dlp.YoutubeDL(ydl_opts) as ydl:
+
+ #         info = ydl.extract_info(url, download=False)
+ #         title = info.get('title', None)
+ #         ydl.download([url])
+
+ #     #with open(temp_audio_file+'.mp3', 'rb') as file:
+ #     audio_file = os.path.join('output', 'audio.mp3')
+
+ #     return audio_file, title
+
+ # Load all required models and cache them
+ @st.cache_resource
+ def load_models():
+
+     '''Load and cache all the models to be used'''
+     q_model = ORTModelForSequenceClassification.from_pretrained("nickmuchi/quantized-optimum-finbert-tone")
+     ner_model = AutoModelForTokenClassification.from_pretrained("xlm-roberta-large-finetuned-conll03-english")
+     kg_model = AutoModelForSeq2SeqLM.from_pretrained("Babelscape/rebel-large")
+     kg_tokenizer = AutoTokenizer.from_pretrained("Babelscape/rebel-large")
+     q_tokenizer = AutoTokenizer.from_pretrained("nickmuchi/quantized-optimum-finbert-tone")
+     ner_tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-large-finetuned-conll03-english")
+     emb_tokenizer = AutoTokenizer.from_pretrained('google/flan-t5-xl')
+     sent_pipe = pipeline("text-classification", model=q_model, tokenizer=q_tokenizer)
+     sum_pipe = pipeline("summarization", model="philschmid/flan-t5-base-samsum", clean_up_tokenization_spaces=True)
+     ner_pipe = pipeline("ner", model=ner_model, tokenizer=ner_tokenizer, grouped_entities=True)
+     cross_encoder = CrossEncoder('cross-encoder/mmarco-mMiniLMv2-L12-H384-v1')  # alternative: cross-encoder/ms-marco-MiniLM-L-12-v2
+     sbert = SentenceTransformer('all-MiniLM-L6-v2')
+
+     return sent_pipe, sum_pipe, ner_pipe, cross_encoder, kg_model, kg_tokenizer, emb_tokenizer, sbert
+
+ @st.cache_resource
+ def get_spacy():
+     nlp = en_core_web_lg.load()
+     return nlp
+
+ nlp = get_spacy()
+
+ sent_pipe, sum_pipe, ner_pipe, cross_encoder, kg_model, kg_tokenizer, emb_tokenizer, sbert = load_models()
+
+ @st.cache_data
+ def get_yt_audio(url):
+
+     '''Get YT video from the given URL link'''
+     yt = YouTube(url)
+
+     title = yt.title
+
+     # Get the first available progressive mp4 stream and download it
+     audio_stream = yt.streams.filter(progressive=True, file_extension='mp4').order_by('resolution').desc().first().download()
+
+     return audio_stream, title
+
+ @st.cache_data
+ def load_whisper_api(audio):
+
+     '''Transcribe YT audio to text using the OpenAI Whisper API'''
+     with open(audio, "rb") as file:
+         transcript = openai.Audio.translate("whisper-1", file)
+
+     return transcript
+
+ @st.cache_data
+ def load_asr_model(model_name):
+
+     '''Load the open-source Whisper model for cases where the API is not working'''
+     model = whisper.load_model(model_name)
+
+     return model
+
+ @st.cache_data
+ def inference(link, _upload, _asr_model):
+     '''Convert a YouTube video or audio upload to text'''
+
+     title = "Transcribed Earnings Audio"  # fallback title in case the download fails before a title is set
+
+     try:
+
+         if validators.url(link):
+
+             st.info("`Downloading YT audio...`")
+
+             audio_file, title = get_yt_audio(link)
+
+             print(f'audio_file:{audio_file}')
+
+             st.session_state['audio'] = audio_file
+
+             print(f"audio_file_session_state:{st.session_state['audio']}")
+
+             # Get the size of the audio file in MB
+             audio_size = round(os.path.getsize(st.session_state['audio'])/(1024*1024), 1)
+
+             # The Whisper API accepts files up to 25MB; chunk anything larger
+             if audio_size <= 25:
+
+                 st.info("`Transcribing YT audio...`")
+
+                 # Use the Whisper API
+                 results = load_whisper_api(st.session_state['audio'])['text']
+
+             else:
+
+                 st.warning('File size larger than 25MB, applying chunking and transcription', icon="⚠️")
+
+                 song = AudioSegment.from_file(st.session_state['audio'], format='mp4')
+
+                 # PyDub handles time in milliseconds
+                 twenty_minutes = 20 * 60 * 1000
+
+                 chunks = song[::twenty_minutes]
+
+                 transcriptions = []
+
+                 video_id = extract.video_id(link)
+                 for i, chunk in enumerate(chunks):
+                     chunk.export(f'output/chunk_{i}_{video_id}.mp4', format='mp4')
+                     transcriptions.append(load_whisper_api(f'output/chunk_{i}_{video_id}.mp4')['text'])
+
+                 results = ','.join(transcriptions)
+
+             st.info("`YT Video transcription process complete...`")
+
+             return results, title
+
+         elif _upload:
+
+             # Get the size of the audio file in MB
+             audio_size = round(os.path.getsize(_upload)/(1024*1024), 1)
+
+             # The Whisper API accepts files up to 25MB; chunk anything larger
+             if audio_size <= 25:
+
+                 st.info("`Transcribing uploaded audio...`")
+
+                 # Use the Whisper API
+                 results = load_whisper_api(_upload)['text']
+
+             else:
+
+                 st.write('File size larger than 25MB, applying chunking and transcription')
+
+                 song = AudioSegment.from_file(_upload)
+
+                 # PyDub handles time in milliseconds
+                 twenty_minutes = 20 * 60 * 1000
+
+                 chunks = song[::twenty_minutes]
+
+                 transcriptions = []
+
+                 st.info("`Transcribing uploaded audio...`")
+
+                 for i, chunk in enumerate(chunks):
+                     chunk.export(f'output/chunk_{i}.mp4', format='mp4')
+                     transcriptions.append(load_whisper_api(f'output/chunk_{i}.mp4')['text'])
+
+                 results = ','.join(transcriptions)
+
+             st.info("`Uploaded audio transcription process complete...`")
+
+             return results, "Transcribed Earnings Audio"
+
+     except Exception as e:
+
+         st.error(f'''Whisper API Error: {e},
+                  using the open-source Whisper model instead, which might take longer than expected''', icon="🚨")
+
+         results = _asr_model.transcribe(st.session_state['audio'], task='transcribe', language='en')
+
+         return results['text'], title
+
+ @st.cache_data
+ def clean_text(text):
+     '''Clean all text after inference'''
+
+     text = text.encode("ascii", "ignore").decode()  # strip non-ascii characters
+     text = re.sub(r"https*\S+", " ", text)  # urls
+     text = re.sub(r"@\S+", " ", text)  # mentions
+     text = re.sub(r"#\S+", " ", text)  # hashtags
+     text = re.sub(r"\s{2,}", " ", text)  # extra spaces
+
+     return text
+
+ @st.cache_data
+ def chunk_long_text(text, threshold, window_size=3, stride=2):
+     '''Preprocess text and chunk it for sentiment analysis'''
+
+     # Convert cleaned text into sentences
+     sentences = sent_tokenize(text)
+     out = []
+
+     # Limit the length of each sentence to a threshold
+     for chunk in sentences:
+         if len(chunk.split()) < threshold:
+             out.append(chunk)
+         else:
+             words = chunk.split()
+             num = int(len(words)/threshold)
+             for i in range(0, num*threshold+1, threshold):
+                 out.append(' '.join(words[i:threshold+i]))
+
+     passages = []
+
+     # Combine sentences into windows of size window_size
+     for paragraph in [out]:
+         for start_idx in range(0, len(paragraph), stride):
+             end_idx = min(start_idx+window_size, len(paragraph))
+             passages.append(" ".join(paragraph[start_idx:end_idx]))
+
+     return passages
+
+ @st.cache_data
+ def sentiment_pipe(earnings_text):
+     '''Determine the sentiment of the text'''
+
+     earnings_sentences = chunk_long_text(earnings_text, 150, 1, 1)
+     earnings_sentiment = sent_pipe(earnings_sentences)
+
+     return earnings_sentiment, earnings_sentences
+
+ @st.cache_data
+ def chunk_and_preprocess_text(text, model_name='philschmid/flan-t5-base-samsum'):
+
+     '''Chunk and preprocess text for summarization'''
+
+     tokenizer = AutoTokenizer.from_pretrained(model_name)
+     sentences = sent_tokenize(text)
+
+     # initialize
+     length = 0
+     chunk = ""
+     chunks = []
+
+     for sentence in sentences:
+         combined_length = len(tokenizer.tokenize(sentence)) + length  # add the sentence's token count to the length counter
+
+         if combined_length <= tokenizer.max_len_single_sentence:  # if it doesn't exceed the model's limit
+             chunk += sentence + " "  # add the sentence to the chunk
+             length = combined_length  # update the length counter
+         else:
+             chunks.append(chunk)  # save the chunk
+             # start a new chunk with the overflow sentence
+             chunk = sentence + " "
+             length = len(tokenizer.tokenize(sentence))
+
+     # save the final chunk (the original version dropped it when the last sentence overflowed)
+     if chunk:
+         chunks.append(chunk)
+
+     return chunks
+
+ @st.cache_data
+ def summarize_text(text_to_summarize, max_len, min_len):
+     '''Summarize text with the HF summarization pipeline'''
+
+     summarized_text = sum_pipe(text_to_summarize,
+                                max_length=max_len,
+                                min_length=min_len,
+                                do_sample=False,
+                                early_stopping=True,
+                                num_beams=4)
+     summarized_text = ' '.join([summ['summary_text'] for summ in summarized_text])
+
+     return summarized_text
+
+ @st.cache_data
+ def get_all_entities_per_sentence(text):
+     doc = nlp(''.join(text))
+
+     sentences = list(doc.sents)
+
+     entities_all_sentences = []
+     for sentence in sentences:
+         entities_this_sentence = []
+
+         # spaCy entities
+         for entity in sentence.ents:
+             entities_this_sentence.append(str(entity))
+
+         # XLM-RoBERTa entities
+         entities_xlm = [entity["word"] for entity in ner_pipe(str(sentence))]
+         for entity in entities_xlm:
+             entities_this_sentence.append(str(entity))
+
+         entities_all_sentences.append(entities_this_sentence)
+
+     return entities_all_sentences
+
+ @st.cache_data
+ def get_all_entities(text):
+     all_entities_per_sentence = get_all_entities_per_sentence(text)
+     return list(itertools.chain.from_iterable(all_entities_per_sentence))
+
+ @st.cache_data
+ def get_and_compare_entities(article_content, summary_output):
+
+     all_entities_per_sentence = get_all_entities_per_sentence(article_content)
+     entities_article = list(itertools.chain.from_iterable(all_entities_per_sentence))
+
+     all_entities_per_sentence = get_all_entities_per_sentence(summary_output)
+     entities_summary = list(itertools.chain.from_iterable(all_entities_per_sentence))
+
+     matched_entities = []
+     unmatched_entities = []
+     for entity in entities_summary:
+         if any(entity.lower() in substring_entity.lower() for substring_entity in entities_article):
+             matched_entities.append(entity)
+         elif any(
+                 np.inner(sbert.encode(entity, show_progress_bar=False),
+                          sbert.encode(art_entity, show_progress_bar=False)) > 0.9 for
+                 art_entity in entities_article):
+             matched_entities.append(entity)
+         else:
+             unmatched_entities.append(entity)
+
+     matched_entities = list(dict.fromkeys(matched_entities))
+     unmatched_entities = list(dict.fromkeys(unmatched_entities))
+
+     matched_entities_to_remove = []
+     unmatched_entities_to_remove = []
+
+     for entity in matched_entities:
+         for substring_entity in matched_entities:
+             if entity != substring_entity and entity.lower() in substring_entity.lower():
+                 matched_entities_to_remove.append(entity)
+
+     for entity in unmatched_entities:
+         for substring_entity in unmatched_entities:
+             if entity != substring_entity and entity.lower() in substring_entity.lower():
+                 unmatched_entities_to_remove.append(entity)
+
+     matched_entities_to_remove = list(dict.fromkeys(matched_entities_to_remove))
+     unmatched_entities_to_remove = list(dict.fromkeys(unmatched_entities_to_remove))
+
+     for entity in matched_entities_to_remove:
+         matched_entities.remove(entity)
+     for entity in unmatched_entities_to_remove:
+         unmatched_entities.remove(entity)
+
+     return matched_entities, unmatched_entities
+
+ @st.cache_data
+ def highlight_entities(article_content, summary_output):
+
+     markdown_start_red = "<mark class=\"entity\" style=\"background: rgb(238, 135, 135);\">"
+     markdown_start_green = "<mark class=\"entity\" style=\"background: rgb(121, 236, 121);\">"
+     markdown_end = "</mark>"
+
+     matched_entities, unmatched_entities = get_and_compare_entities(article_content, summary_output)
+
+     for entity in matched_entities:
+         summary_output = re.sub(f'({entity})(?![^rgb\(]*\))', markdown_start_green + entity + markdown_end, summary_output)
+
+     for entity in unmatched_entities:
+         summary_output = re.sub(f'({entity})(?![^rgb\(]*\))', markdown_start_red + entity + markdown_end, summary_output)
+
+     soup = BeautifulSoup(summary_output, features="html.parser")
+
+     return HTML_WRAPPER.format(soup)
+
+ def summary_downloader(raw_text):
+     '''Download the generated summary'''
+
+     b64 = base64.b64encode(raw_text.encode()).decode()
+     new_filename = "new_text_file_{}_.txt".format(time_str)
+     st.markdown("#### Download Summary as a File")
+     href = f'<a href="data:file/txt;base64,{b64}" download="{new_filename}">Click to Download!!</a>'
+     st.markdown(href, unsafe_allow_html=True)
+
+ @st.cache_data
+ def generate_eval(raw_text, N, chunk):
+
+     # Generate N questions, each drawn from a chunk of `chunk` characters of the context
+     # IN: text, N questions, chunk size to draw each question from
+     # OUT: eval set as a JSON list
+
+     # raw_text = ','.join(raw_text)
+
+     update = st.empty()
+     ques_update = st.empty()
+     update.info("`Generating sample questions ...`")
+     n = len(raw_text)
+     starting_indices = [random.randint(0, n-chunk) for _ in range(N)]
+     sub_sequences = [raw_text[i:i+chunk] for i in starting_indices]
+     chain = QAGenerationChain.from_llm(ChatOpenAI(temperature=0))
+     eval_set = []
+
+     for i, b in enumerate(sub_sequences):
+         try:
+             qa = chain.run(b)
+             eval_set.append(qa)
+             ques_update.info(f"Creating Question: {i+1}")
+
+         except Exception as e:
+             print(e)
+             st.warning(f'Error in generating Question: {i+1}...', icon="⚠️")
+             continue
+
+     eval_set_full = list(itertools.chain.from_iterable(eval_set))
+
+     update.empty()
+     ques_update.empty()
+
+     return eval_set_full
+
+ @st.cache_resource
+ def gen_embeddings(embedding_model):
+
+     '''Generate embeddings for the given model'''
+
+     if 'hkunlp' in embedding_model:
+
+         embeddings = HuggingFaceInstructEmbeddings(model_name=embedding_model,
+                                                    query_instruction='Represent the Financial question for retrieving supporting paragraphs: ',
+                                                    embed_instruction='Represent the Financial paragraph for retrieval: ')
+
+     else:
+
+         embeddings = HuggingFaceEmbeddings(model_name=embedding_model)
+
+     return embeddings
+
+ @st.cache_data
+ def process_corpus(corpus, title, embedding_model, chunk_size=1000, overlap=50):
+
+     '''Process text for semantic search'''
+
+     text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=overlap)
+
+     texts = text_splitter.split_text(corpus)
+
+     embeddings = gen_embeddings(embedding_model)
+
+     vectorstore = FAISS.from_texts(texts, embeddings, metadatas=[{"source": i} for i in range(len(texts))])
+
+     return vectorstore
+
+ def embed_text(query, _docsearch):
+
+     '''Embed text and generate semantic search scores'''
+
+     # llm = OpenAI(temperature=0)
+     chat_llm = ChatOpenAI(streaming=True,
+                           model_name='gpt-4',
+                           callbacks=[StdOutCallbackHandler()],
+                           verbose=True,
+                           temperature=0
+                           )
+
+     # chain = RetrievalQA.from_chain_type(llm=chat_llm, chain_type="stuff",
+     #                                     retriever=_docsearch.as_retriever(),
+     #                                     return_source_documents=True)
+
+     question_generator = LLMChain(llm=chat_llm, prompt=CONDENSE_QUESTION_PROMPT)
+     doc_chain = load_qa_chain(llm=chat_llm, chain_type="stuff", prompt=load_prompt())
+     chain = ConversationalRetrievalChain(retriever=_docsearch.as_retriever(search_kwargs={"k": 3}),  # fixed typo: was search_kwags
+                                          question_generator=question_generator,
+                                          combine_docs_chain=doc_chain,
+                                          memory=memory,
+                                          return_source_documents=True,
+                                          get_chat_history=lambda h: h)
+
+     answer = chain({"question": query})
+
+     return answer
+
+ @st.cache_data
+ def gen_sentiment(text):
+     '''Generate the sentiment of the given text'''
+     return sent_pipe(text)[0]['label']
+
+ @st.cache_data
+ def gen_annotated_text(df):
+     '''Generate annotated text'''
+
+     tag_list = []
+     for row in df.itertuples():
+         label = row[2]
+         text = row[1]
+         if label == 'Positive':
+             tag_list.append((text, label, '#8fce00'))
+         elif label == 'Negative':
+             tag_list.append((text, label, '#f44336'))
+         else:
+             tag_list.append((text, label, '#000000'))
+
+     return tag_list
+
+
+ def display_df_as_table(model, top_k, score='score'):
+     '''Display the df with text and scores as a table (expects a module-level `passages` list)'''
+
+     df = pd.DataFrame([(hit[score], passages[hit['corpus_id']]) for hit in model[0:top_k]], columns=['Score', 'Text'])
+     df['Score'] = round(df['Score'], 2)
+
+     return df
+
+
+ def make_spans(text, results):
+     # NOTE: relies on a `sent_tokenizer` callable that is not defined in this file
+     results_list = []
+     for i in range(len(results)):
+         results_list.append(results[i]['label'])
+     facts_spans = []
+     facts_spans = list(zip(sent_tokenizer(text), results_list))
+     return facts_spans
+
+ ## Fiscal sentiment by sentence
+ def fin_ext(text):
+     # NOTE: relies on a `remote_clx` classifier that is not defined in this file
+     results = remote_clx(sent_tokenizer(text))
+     return make_spans(text, results)
+
+ ## Knowledge Graphs code
+
+ @st.cache_data
+ def extract_relations_from_model_output(text):
+     relations = []
+     relation, subject, object_ = '', '', ''
+     text = text.strip()
+     current = 'x'
+     text_replaced = text.replace("<s>", "").replace("<pad>", "").replace("</s>", "")
+     # REBEL emits triplets as: <triplet> subject <subj> object <obj> relation
+     for token in text_replaced.split():
+         if token == "<triplet>":
+             current = 't'
+             if relation != '':
+                 relations.append({
+                     'head': subject.strip(),
+                     'type': relation.strip(),
+                     'tail': object_.strip()
+                 })
+                 relation = ''
+             subject = ''
+         elif token == "<subj>":
+             current = 's'
+             if relation != '':
+                 relations.append({
+                     'head': subject.strip(),
+                     'type': relation.strip(),
+                     'tail': object_.strip()
+                 })
+             object_ = ''
+         elif token == "<obj>":
+             current = 'o'
+             relation = ''
+         else:
+             if current == 't':
+                 subject += ' ' + token
+             elif current == 's':
+                 object_ += ' ' + token
+             elif current == 'o':
+                 relation += ' ' + token
+     if subject != '' and relation != '' and object_ != '':
+         relations.append({
+             'head': subject.strip(),
+             'type': relation.strip(),
+             'tail': object_.strip()
+         })
+     return relations
+
+ def from_text_to_kb(text, model, tokenizer, article_url, span_length=128, article_title=None,
+                     article_publish_date=None, verbose=False):
+     # tokenize the whole text
+     inputs = tokenizer([text], return_tensors="pt")
+
+     # compute span boundaries
+     num_tokens = len(inputs["input_ids"][0])
+     if verbose:
+         print(f"Input has {num_tokens} tokens")
+     num_spans = math.ceil(num_tokens / span_length)
+     if verbose:
+         print(f"Input has {num_spans} spans")
+     overlap = math.ceil((num_spans * span_length - num_tokens) /
+                         max(num_spans - 1, 1))
+     spans_boundaries = []
+     start = 0
+     for i in range(num_spans):
+         spans_boundaries.append([start + span_length * i,
+                                  start + span_length * (i + 1)])
+         start -= overlap
+     if verbose:
+         print(f"Span boundaries are {spans_boundaries}")
+
+     # transform input with spans
+     tensor_ids = [inputs["input_ids"][0][boundary[0]:boundary[1]]
+                   for boundary in spans_boundaries]
+     tensor_masks = [inputs["attention_mask"][0][boundary[0]:boundary[1]]
+                     for boundary in spans_boundaries]
+     inputs = {
+         "input_ids": torch.stack(tensor_ids),
+         "attention_mask": torch.stack(tensor_masks)
+     }
+
+     # generate relations
+     num_return_sequences = 3
+     gen_kwargs = {
+         "max_length": 256,
+         "length_penalty": 0,
+         "num_beams": 3,
+         "num_return_sequences": num_return_sequences
+     }
+     generated_tokens = model.generate(
+         **inputs,
+         **gen_kwargs,
+     )
+
+     # decode relations
+     decoded_preds = tokenizer.batch_decode(generated_tokens,
+                                            skip_special_tokens=False)
+
+     # create the knowledge base
+     kb = KB()
+     i = 0
+     for sentence_pred in decoded_preds:
+         current_span_index = i // num_return_sequences
+         relations = extract_relations_from_model_output(sentence_pred)
+         for relation in relations:
+             relation["meta"] = {
+                 article_url: {
+                     "spans": [spans_boundaries[current_span_index]]
+                 }
+             }
+             kb.add_relation(relation, article_title, article_publish_date)
+         i += 1
+
+     return kb
+
+ def get_article(url):
+     article = Article(url)
+     article.download()
+     article.parse()
+     return article
+
+ def from_url_to_kb(url, model, tokenizer):
+     article = get_article(url)
+     config = {
+         "article_title": article.title,
+         "article_publish_date": article.publish_date
+     }
+     kb = from_text_to_kb(article.text, model, tokenizer, article.url, **config)
+     return kb
+
+ def get_news_links(query, lang="en", region="US", pages=1):
+     googlenews = GoogleNews(lang=lang, region=region)
+     googlenews.search(query)
+     all_urls = []
+     for page in range(pages):
+         googlenews.get_page(page)
+         all_urls += googlenews.get_links()
+     return list(set(all_urls))
+
+ def from_urls_to_kb(urls, model, tokenizer, verbose=False):
+     kb = KB()
+     if verbose:
+         print(f"{len(urls)} links to visit")
+     for url in urls:
+         if verbose:
+             print(f"Visiting {url}...")
+         try:
+             kb_url = from_url_to_kb(url, model, tokenizer)
+             kb.merge_with_kb(kb_url)
+         except ArticleException:
+             if verbose:
+                 print(f"  Couldn't download article at url {url}")
+     return kb
+
+ def save_network_html(kb, filename="network.html"):
+     # create network
+     net = Network(directed=True, width="700px", height="700px")
+
+     # nodes
+     color_entity = "#00FF00"
+     for e in kb.entities:
+         net.add_node(e, shape="circle", color=color_entity)
+
+     # edges
+     for r in kb.relations:
+         net.add_edge(r["head"], r["tail"],
+                      title=r["type"], label=r["type"])
+
+     # save network
+     net.repulsion(
+         node_distance=200,
+         central_gravity=0.2,
+         spring_length=200,
+         spring_strength=0.05,
+         damping=0.09
+     )
+     net.set_edge_smooth('dynamic')
+     net.show(filename)
+
+ def save_kb(kb, filename):
+     with open(filename, "wb") as f:
+         pickle.dump(kb, f)
+
+ class CustomUnpickler(pickle.Unpickler):
+     def find_class(self, module, name):
+         if name == 'KB':
+             return KB
+         return super().find_class(module, name)
+
+ def load_kb(filename):
+     res = None
+     with open(filename, "rb") as f:
+         res = CustomUnpickler(f).load()
+     return res
+
+ class KB():
+     def __init__(self):
+         self.entities = {}  # { entity_title: {...} }
+         self.relations = []  # [ head: entity_title, type: ..., tail: entity_title,
+                              #   meta: { article_url: { spans: [...] } } ]
+         self.sources = {}  # { article_url: {...} }
+
+     def merge_with_kb(self, kb2):
+         for r in kb2.relations:
+             article_url = list(r["meta"].keys())[0]
+             source_data = kb2.sources[article_url]
+             self.add_relation(r, source_data["article_title"],
+                               source_data["article_publish_date"])
+
+     def are_relations_equal(self, r1, r2):
+         return all(r1[attr] == r2[attr] for attr in ["head", "type", "tail"])
+
+     def exists_relation(self, r1):
+         return any(self.are_relations_equal(r1, r2) for r2 in self.relations)
+
+     def merge_relations(self, r2):
+         r1 = [r for r in self.relations
+               if self.are_relations_equal(r2, r)][0]
+
+         # if different article
+         article_url = list(r2["meta"].keys())[0]
+         if article_url not in r1["meta"]:
+             r1["meta"][article_url] = r2["meta"][article_url]
+
+         # if existing article
+         else:
+             spans_to_add = [span for span in r2["meta"][article_url]["spans"]
+                             if span not in r1["meta"][article_url]["spans"]]
+             r1["meta"][article_url]["spans"] += spans_to_add
+
+     def get_wikipedia_data(self, candidate_entity):
+         try:
+             page = wikipedia.page(candidate_entity, auto_suggest=False)
+             entity_data = {
+                 "title": page.title,
+                 "url": page.url,
+                 "summary": page.summary
+             }
+             return entity_data
+         except Exception:
+             return None
+
+     def add_entity(self, e):
+         self.entities[e["title"]] = {k: v for k, v in e.items() if k != "title"}
+
+     def add_relation(self, r, article_title, article_publish_date):
+         # check on wikipedia
+         candidate_entities = [r["head"], r["tail"]]
+         entities = [self.get_wikipedia_data(ent) for ent in candidate_entities]
+
+         # if one entity does not exist, stop
+         if any(ent is None for ent in entities):
+             return
+
+         # manage new entities
+         for e in entities:
+             self.add_entity(e)
+
+         # rename relation entities with their wikipedia titles
+         r["head"] = entities[0]["title"]
+         r["tail"] = entities[1]["title"]
+
+         # add source if not in kb
+         article_url = list(r["meta"].keys())[0]
+         if article_url not in self.sources:
+             self.sources[article_url] = {
+                 "article_title": article_title,
+                 "article_publish_date": article_publish_date
+             }
+
+         # manage new relation
+         if not self.exists_relation(r):
+             self.relations.append(r)
+         else:
+             self.merge_relations(r)
+
+     def get_textual_representation(self):
+         res = ""
+         res += "### Entities\n"
+         for e in self.entities.items():
+             # shorten the summary
+             e_temp = (e[0], {k: (v[:100] + "..." if k == "summary" else v) for k, v in e[1].items()})
+             res += f"- {e_temp}\n"
+         res += "\n"
+         res += "### Relations\n"
+         for r in self.relations:
+             res += f"- {r}\n"
+         res += "\n"
+         res += "### Sources\n"
+         for s in self.sources.items():
+             res += f"- {s}\n"
+         return res
+
+ # NOTE: this second definition overrides the save_network_html above; it only adds a background colour
+ def save_network_html(kb, filename="network.html"):
+     # create network
+     net = Network(directed=True, width="700px", height="700px", bgcolor="#eeeeee")
+
+     # nodes
+     color_entity = "#00FF00"
+     for e in kb.entities:
+         net.add_node(e, shape="circle", color=color_entity)
+
+     # edges
+     for r in kb.relations:
+         net.add_edge(r["head"], r["tail"],
+                      title=r["type"], label=r["type"])
+
+     # save network
+     net.repulsion(
+         node_distance=200,
+         central_gravity=0.2,
+         spring_length=200,
+         spring_strength=0.05,
+         damping=0.09
+     )
+     net.set_edge_smooth('dynamic')
+     net.show(filename)
output/audio.txt ADDED
File without changes
pages/1_Earnings_Sentiment_Analysis_📈_.py ADDED
@@ -0,0 +1,134 @@
+ import streamlit as st
+ import pandas as pd
+ import plotly_express as px
+ import plotly.graph_objects as go
+ from functions import *
+ import validators
+ import textwrap
+
+ #st.set_page_config(page_title="Earnings Sentiment Analysis", page_icon="📈")
+ st.sidebar.header("Sentiment Analysis")
+ st.markdown("## Earnings Sentiment Analysis with FinBert-Tone")
+
+ # load the whisper model
+ asr_model = load_asr_model(st.session_state.sbox)
+
+ if "url" not in st.session_state:
+     st.session_state.url = ''
+
+ if "title" not in st.session_state:
+     st.session_state.title = ''
+
+ try:
+
+     if st.session_state['url'] is not None or st.session_state['upload'] is not None:
+
+         results, title = inference(st.session_state.url, st.session_state.upload, asr_model)
+
+         print(f'results, page1: {results}')
+
+         st.subheader(title)
+
+         earnings_passages = clean_text(results)
+
+         st.session_state['earnings_passages'] = earnings_passages
+
+         st.session_state['title'] = title
+
+         earnings_sentiment, earnings_sentences = sentiment_pipe(earnings_passages)
+
+         with st.expander("See Transcribed Earnings Text"):
+             st.write(f"Number of Sentences: {len(earnings_sentences)}")
+
+             st.write(st.session_state['earnings_passages'])
+
+         ## Save to a dataframe for ease of visualization
+         sen_df = pd.DataFrame(earnings_sentiment)
+         sen_df['text'] = earnings_sentences
+         grouped = pd.DataFrame(sen_df['label'].value_counts()).reset_index()
+         grouped.columns = ['sentiment', 'count']
+
+         st.session_state['sen_df'] = sen_df
+
+         # Display the number of positive, negative and neutral sentiments
+         fig = px.bar(grouped, x='sentiment', y='count', color='sentiment',
+                      color_discrete_map={"Negative": "firebrick", "Neutral": "navajowhite", "Positive": "darkgreen"},
+                      title='Earnings Sentiment')
+
+         fig.update_layout(
+             showlegend=False,
+             autosize=True,
+             margin=dict(
+                 l=25,
+                 r=25,
+                 b=25,
+                 t=50,
+                 pad=2
+             )
+         )
+
+         st.plotly_chart(fig)
+
+         ## Display the sentiment score (assumes all three sentiment classes are present)
+         pos_perc = grouped[grouped['sentiment']=='Positive']['count'].iloc[0]*100/sen_df.shape[0]
+         neg_perc = grouped[grouped['sentiment']=='Negative']['count'].iloc[0]*100/sen_df.shape[0]
+         neu_perc = grouped[grouped['sentiment']=='Neutral']['count'].iloc[0]*100/sen_df.shape[0]
+
+         sentiment_score = neu_perc+pos_perc-neg_perc
+
+         fig_1 = go.Figure()
+
+         fig_1.add_trace(go.Indicator(
+             mode="delta",
+             value=sentiment_score,
+             domain={'row': 1, 'column': 1}))
+
+         fig_1.update_layout(
+             template={'data': {'indicator': [{
+                 'title': {'text': "Sentiment Score"},
+                 'mode': "number+delta+gauge",
+                 'delta': {'reference': 50}}]
+             }},
+             autosize=False,
+             width=250,
+             height=250,
+             margin=dict(
+                 l=5,
+                 r=5,
+                 b=5,
+                 pad=2
+             )
+         )
+
+         with st.sidebar:
+
+             st.plotly_chart(fig_1)
+
+         hd = sen_df.text.apply(lambda txt: '<br>'.join(textwrap.wrap(txt, width=70)))
+         ## Display negative sentence locations
+         fig = px.scatter(sen_df, y='label', color='label', size='score', hover_data=[hd],
+                          color_discrete_map={"Negative": "firebrick", "Neutral": "navajowhite", "Positive": "darkgreen"},
+                          title='Sentiment Score Distribution')
+
+         fig.update_layout(
+             showlegend=False,
+             autosize=True,
+             width=800,
+             height=500,
+             margin=dict(
+                 b=5,
+                 t=50,
+                 pad=4
+             )
+         )
+
+         st.plotly_chart(fig)
+
+     else:
+
+         st.write("No YouTube URL or file upload detected")
+
+ except (AttributeError, TypeError):
+
+     st.write("No YouTube URL or file upload detected")
pages/2_Earnings_Summarization_📖_.py ADDED
@@ -0,0 +1,51 @@
+ import streamlit as st
+ from functions import *
+
+ # st.set_page_config(page_title="Earnings Summarization", page_icon="📖")
+ st.sidebar.header("Earnings Summarization")
+ st.markdown("## Earnings Summarization with Flan-T5-Base-Samsum")
+
+ max_len = st.slider("Maximum length of the summarized text", min_value=70, max_value=200, step=10, value=100)
+ min_len = st.slider("Minimum length of the summarized text", min_value=20, max_value=200, step=10)
+
+ st.markdown("####")
+
+ st.subheader("Summarized Earnings Call with matched Entities")
+
+ if "earnings_passages" not in st.session_state:
+     st.session_state["earnings_passages"] = ''
+
+ if st.session_state['earnings_passages']:
+
+     with st.spinner("Summarizing and matching entities, this takes a few seconds..."):
+
+         try:
+             text_to_summarize = chunk_and_preprocess_text(st.session_state['earnings_passages'])
+             print(text_to_summarize)
+             summarized_text = summarize_text(text_to_summarize, max_len=max_len, min_len=min_len)
+
+         except IndexError:
+             # retry once; the summarization pipeline occasionally fails on edge-case chunks
+             text_to_summarize = chunk_and_preprocess_text(st.session_state['earnings_passages'])
+             summarized_text = summarize_text(text_to_summarize, max_len=max_len, min_len=min_len)
+
+         entity_match_html = highlight_entities(text_to_summarize, summarized_text)
+         st.markdown("####")
+
+         with st.expander(label='Summarized Earnings Call', expanded=True):
+             st.write(entity_match_html, unsafe_allow_html=True)
+
+         st.markdown("####")
+
+         summary_downloader(summarized_text)
+
+ else:
+     st.write("No text to summarize detected, please ensure you have entered the YouTube URL on the Sentiment Analysis page")
pages/3_Earnings_Semantic_Search_🔎_.py ADDED
@@ -0,0 +1,148 @@
+ import streamlit as st
+ from functions import *
+ from langchain.chains import QAGenerationChain
+ import itertools
+
+
+ st.set_page_config(page_title="Earnings Question/Answering", page_icon="🔎")
+
+ st.sidebar.header("Semantic Search")
+
+ st.markdown("## Earnings Semantic Search with LangChain, OpenAI & SBert")
+
+ st.markdown(
+     """
+     <style>
+
+     #MainMenu {visibility: hidden;
+     }
+     footer {visibility: hidden;
+     }
+     .css-card {
+         border-radius: 0px;
+         padding: 30px 10px 10px 10px;
+         background-color: black;
+         box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
+         margin-bottom: 10px;
+         font-family: "IBM Plex Sans", sans-serif;
+     }
+
+     .card-tag {
+         border-radius: 0px;
+         padding: 1px 5px 1px 5px;
+         margin-bottom: 10px;
+         position: absolute;
+         left: 0px;
+         top: 0px;
+         font-size: 0.6rem;
+         font-family: "IBM Plex Sans", sans-serif;
+         color: white;
+         background-color: green;
+     }
+
+     .css-zt5igj {left: 0;
+     }
+
+     span.css-10trblm {margin-left: 0;
+     }
+
+     div.css-1kyxreq {margin-top: -40px;
+     }
+
+     </style>
+     """,
+     unsafe_allow_html=True,
+ )
+
+ bi_enc_dict = {'mpnet-base-v2': "all-mpnet-base-v2",
+                'instructor-base': 'hkunlp/instructor-base'}
+
+ search_input = st.text_input(
+     label='Enter Your Search Query', value="What key challenges did the business face?", key='search')
+
+ # use a distinct key so this selectbox doesn't clobber the Whisper model choice stored under 'sbox' on the Home page
+ sbert_model_name = st.sidebar.selectbox("Embedding Model", options=list(bi_enc_dict.keys()), key='embed_sbox')
+
+ st.sidebar.markdown('Earnings QnA Generator')
+
+ chunk_size = 1000
+ overlap_size = 50
+
+ try:
+
+     if search_input:
+
+         if "sen_df" in st.session_state and "earnings_passages" in st.session_state:
+
+             ## Pull the stored dataframe and transcript for ease of visualization
+             sen_df = st.session_state['sen_df']
+
+             title = st.session_state['title']
+
+             earnings_text = st.session_state['earnings_passages']
+
+             print(f'earnings_to_be_embedded:{earnings_text}')
+
+             st.session_state.eval_set = generate_eval(
+                 earnings_text, 10, 3000)
+
+             # Display the question-answer pairs in the sidebar with smaller text
+             for i, qa_pair in enumerate(st.session_state.eval_set):
+                 st.sidebar.markdown(
+                     f"""
+                     <div class="css-card">
+                     <span class="card-tag">Question {i + 1}</span>
+                         <p style="font-size: 12px;">{qa_pair['question']}</p>
+                         <p style="font-size: 12px;">{qa_pair['answer']}</p>
+                     </div>
+                     """,
+                     unsafe_allow_html=True,
+                 )
+
+             embedding_model = bi_enc_dict[sbert_model_name]
+
+             with st.spinner(
+                 text=f"Loading {embedding_model} embedding model and Generating Response..."
+             ):
+
+                 docsearch = process_corpus(earnings_text, title, embedding_model)
+
+                 result = embed_text(search_input, docsearch)
+
+             references = [doc.page_content for doc in result['source_documents']]
+
+             answer = result['answer']
+
+             sentiment_label = gen_sentiment(answer)
+
+             ##### Semantic Search #####
+
+             df = pd.DataFrame.from_dict({'Text': [answer], 'Sentiment': [sentiment_label]})
+
+             text_annotations = gen_annotated_text(df)[0]
+
+             with st.expander(label='Query Result', expanded=True):
+                 annotated_text(text_annotations)
+
+             with st.expander(label='References from Corpus used to Generate Result'):
+                 for ref in references:
+                     st.write(ref)
+
+         else:
+
+             st.write('Please ensure you have entered the YouTube URL or uploaded the Earnings Call file')
+
+     else:
+
+         st.write('Please ensure you have entered the YouTube URL or uploaded the Earnings Call file')
+
+ except RuntimeError:
+
+     st.write('Please ensure you have entered the YouTube URL or uploaded the Earnings Call file')
pages/4_Earnings_Knowledge_Graph_📈_.py ADDED
@@ -0,0 +1,30 @@
+ import streamlit as st
+ from pyvis.network import Network
+ from functions import *
+ import streamlit.components.v1 as components
+ import pickle, math
+
+ st.set_page_config(page_title="Earnings Knowledge Graph", page_icon="📈")
+ st.sidebar.header("Knowledge Graph")
+ st.markdown("## Earnings Knowledge Graph")
+
+ filename = "earnings_network.html"
+
+ if "earnings_passages" in st.session_state:
+
+     with st.spinner(text='Loading Babelscape/rebel-large, which can take a few minutes to generate the graph..'):
+
+         st.session_state.kb_text = from_text_to_kb(st.session_state['earnings_passages'], kg_model, kg_tokenizer, "", verbose=True)
+         save_network_html(st.session_state.kb_text, filename=filename)
+         st.session_state.kb_chart = filename
+
+     with st.container():
+         st.subheader("Generated Knowledge Graph")
+         st.markdown("*You can interact with the graph and zoom.*")
+         html_source_code = open(st.session_state.kb_chart, 'r', encoding='utf-8').read()
+         components.html(html_source_code, width=700, height=700)
+         # render the KB's textual summary rather than the object's repr
+         st.markdown(st.session_state.kb_text.get_textual_representation())
+
+ else:
+
+     st.write('No earnings text detected, please regenerate from the Home page..')
requirements.txt ADDED
@@ -0,0 +1,24 @@
+ torch
+ git+https://github.com/openai/whisper.git
+ sentence-transformers
+ transformers
+ InstructorEmbedding
+ optimum[onnxruntime]
+ yt-dlp
+ pydub
+ validators
+ nltk==3.7
+ plotly
+ plotly-express
+ spacy
+ spacy_streamlit
+ st-annotated-text
+ en_core_web_lg @ https://huggingface.co/spacy/en_core_web_lg/resolve/main/en_core_web_lg-any-py3-none-any.whl
+ bs4==0.0.1
+ wikipedia
+ pyvis
+ newspaper3k
+ GoogleNews
+ langchain==0.0.225
+ openai
+ faiss-cpu
+ altair<5
+ git+https://github.com/oncename/pytube.git
sentence-transformers/.DS_Store ADDED
Binary file (8.2 kB)
 
sentence-transformers/NOTICE.txt ADDED
@@ -0,0 +1,5 @@
+ -------------------------------------------------------------------------------
+ Copyright 2019
+ Ubiquitous Knowledge Processing (UKP) Lab
+ Technische Universität Darmstadt
+ -------------------------------------------------------------------------------
sentence-transformers/README.md ADDED
@@ -0,0 +1,182 @@
+ <!--- BADGES: START --->
+ [![GitHub - License](https://img.shields.io/github/license/UKPLab/sentence-transformers?logo=github&style=flat&color=green)][#github-license]
+ [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/sentence-transformers?logo=pypi&style=flat&color=blue)][#pypi-package]
+ [![PyPI - Package Version](https://img.shields.io/pypi/v/sentence-transformers?logo=pypi&style=flat&color=orange)][#pypi-package]
+ [![Conda - Platform](https://img.shields.io/conda/pn/conda-forge/sentence-transformers?logo=anaconda&style=flat)][#conda-forge-package]
+ [![Conda (channel only)](https://img.shields.io/conda/vn/conda-forge/sentence-transformers?logo=anaconda&style=flat&color=orange)][#conda-forge-package]
+ [![Docs - GitHub.io](https://img.shields.io/static/v1?logo=github&style=flat&color=pink&label=docs&message=sentence-transformers)][#docs-package]
+ <!---
+ [![PyPI - Downloads](https://img.shields.io/pypi/dm/sentence-transformers?logo=pypi&style=flat&color=green)][#pypi-package]
+ [![Conda](https://img.shields.io/conda/dn/conda-forge/sentence-transformers?logo=anaconda)][#conda-forge-package]
+ --->
+
+ [#github-license]: https://github.com/UKPLab/sentence-transformers/blob/master/LICENSE
+ [#pypi-package]: https://pypi.org/project/sentence-transformers/
+ [#conda-forge-package]: https://anaconda.org/conda-forge/sentence-transformers
+ [#docs-package]: https://www.sbert.net/
+ <!--- BADGES: END --->
+
+ # Sentence Transformers: Multilingual Sentence, Paragraph, and Image Embeddings using BERT & Co.
+
+ This framework provides an easy method to compute dense vector representations for **sentences**, **paragraphs**, and **images**. The models are based on transformer networks like BERT / RoBERTa / XLM-RoBERTa etc. and achieve state-of-the-art performance in various tasks. Text is embedded in a vector space such that similar text is close and can efficiently be found using cosine similarity.
+
+ We provide an increasing number of **[state-of-the-art pretrained models](https://www.sbert.net/docs/pretrained_models.html)** for more than 100 languages, fine-tuned for various use-cases.
+
+ Further, this framework allows an easy **[fine-tuning of custom embeddings models](https://www.sbert.net/docs/training/overview.html)** to achieve maximal performance on your specific task.
+
+ For the **full documentation**, see **[www.SBERT.net](https://www.sbert.net)**.
+
+ The following publications are integrated in this framework:
+
+ - [Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks](https://arxiv.org/abs/1908.10084) (EMNLP 2019)
+ - [Making Monolingual Sentence Embeddings Multilingual using Knowledge Distillation](https://arxiv.org/abs/2004.09813) (EMNLP 2020)
+ - [Augmented SBERT: Data Augmentation Method for Improving Bi-Encoders for Pairwise Sentence Scoring Tasks](https://arxiv.org/abs/2010.08240) (NAACL 2021)
+ - [The Curse of Dense Low-Dimensional Information Retrieval for Large Index Sizes](https://arxiv.org/abs/2012.14210) (arXiv 2020)
+ - [TSDAE: Using Transformer-based Sequential Denoising Auto-Encoder for Unsupervised Sentence Embedding Learning](https://arxiv.org/abs/2104.06979) (arXiv 2021)
+ - [BEIR: A Heterogenous Benchmark for Zero-shot Evaluation of Information Retrieval Models](https://arxiv.org/abs/2104.08663) (arXiv 2021)
+
+ ## Installation
+
+ We recommend **Python 3.6** or higher, **[PyTorch 1.6.0](https://pytorch.org/get-started/locally/)** or higher and **[transformers v4.6.0](https://github.com/huggingface/transformers)** or higher. The code does **not** work with Python 2.7.
+
+ **Install with pip**
+
+ Install *sentence-transformers* with `pip`:
+
+ ```
+ pip install -U sentence-transformers
+ ```
+
+ **Install with conda**
+
+ You can install *sentence-transformers* with `conda`:
+
+ ```
+ conda install -c conda-forge sentence-transformers
+ ```
+
+ **Install from sources**
+
+ Alternatively, you can also clone the latest version from the [repository](https://github.com/UKPLab/sentence-transformers) and install it directly from the source code:
+
+ ````
+ pip install -e .
+ ````
+
+ **PyTorch with CUDA**
+
+ If you want to use a GPU / CUDA, you must install PyTorch with the matching CUDA version. See
+ [PyTorch - Get Started](https://pytorch.org/get-started/locally/) for further details on how to install PyTorch.
+
71
+ ## Getting Started
72
+
73
+ See [Quickstart](https://www.sbert.net/docs/quickstart.html) in our documentation.
74
+
75
+ [This example](https://github.com/UKPLab/sentence-transformers/tree/master/examples/applications/computing-embeddings/computing_embeddings.py) shows you how to use an already trained Sentence Transformer model to embed sentences for another task.
76
+
77
+ First download a pretrained model.
78
+
79
+ ````python
80
+ from sentence_transformers import SentenceTransformer
81
+ model = SentenceTransformer('all-MiniLM-L6-v2')
82
+ ````
83
+
84
+ Then provide some sentences to the model.
85
+
86
+ ````python
87
+ sentences = ['This framework generates embeddings for each input sentence',
88
+ 'Sentences are passed as a list of strings.',
89
+ 'The quick brown fox jumps over the lazy dog.']
90
+ sentence_embeddings = model.encode(sentences)
91
+ ````
92
+
93
+ And that's it: we now have a list of numpy arrays with the embeddings.
94
+
95
+ ````python
96
+ for sentence, embedding in zip(sentences, sentence_embeddings):
97
+ print("Sentence:", sentence)
98
+ print("Embedding:", embedding)
99
+ print("")
100
+ ````
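+
+ Because similar texts end up close together in the vector space, the embeddings can be compared with cosine similarity. Here is a minimal sketch using the library's `util.cos_sim` helper (the sentence pair is illustrative):
+
+ ````python
+ from sentence_transformers import SentenceTransformer, util
+
+ model = SentenceTransformer('all-MiniLM-L6-v2')
+ emb1 = model.encode('This framework generates embeddings for each input sentence')
+ emb2 = model.encode('Each input sentence is mapped to a dense vector')
+ # cos_sim returns a 1x1 tensor for a single pair of embeddings
+ print("Cosine similarity:", util.cos_sim(emb1, emb2).item())
+ ````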
101
+
102
+ ## Pre-Trained Models
103
+
104
+ We provide a large list of [Pretrained Models](https://www.sbert.net/docs/pretrained_models.html) for more than 100 languages. Some models are general purpose models, while others produce embeddings for specific use cases. Pre-trained models can be loaded by just passing the model name: `SentenceTransformer('model_name')`.
105
+
106
+ [Β» Full list of pretrained models](https://www.sbert.net/docs/pretrained_models.html)
107
+
108
+ ## Training
109
+
110
+ This framework allows you to fine-tune your own sentence embedding methods, so that you get task-specific sentence embeddings. You have various options to choose from to obtain sentence embeddings well suited to your specific task.
111
+
112
+ See [Training Overview](https://www.sbert.net/docs/training/overview.html) for an introduction to training your own embedding models. We provide [various examples](https://github.com/UKPLab/sentence-transformers/tree/master/examples/training) of how to train models on various datasets; a minimal fine-tuning sketch follows the list below.
113
+
114
+ Some highlights are:
115
+ - Support for various transformer networks including BERT, RoBERTa, XLM-R, DistilBERT, Electra, BART, ...
116
+ - Multilingual and multi-task learning
117
+ - Evaluation during training to find the optimal model
118
+ - [10+ loss functions](https://www.sbert.net/docs/package_reference/losses.html) that let you tune models specifically for semantic search, paraphrase mining, semantic similarity comparison, and clustering, including triplet and contrastive losses.
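+
+ As a minimal, illustrative sketch (the sentence pairs and similarity labels below are made up), fine-tuning with one of the built-in losses looks like this:
+
+ ````python
+ from torch.utils.data import DataLoader
+ from sentence_transformers import SentenceTransformer, InputExample, losses
+
+ model = SentenceTransformer('all-MiniLM-L6-v2')
+ # each InputExample pairs two texts with a target cosine similarity in [0, 1]
+ train_examples = [InputExample(texts=['First sentence', 'Second sentence'], label=0.8),
+ InputExample(texts=['Another pair', 'Unrelated text'], label=0.3)]
+ train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=2)
+ train_loss = losses.CosineSimilarityLoss(model)
+ model.fit(train_objectives=[(train_dataloader, train_loss)], epochs=1, warmup_steps=10)
+ ````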
119
+
120
+ ## Performance
121
+
122
+ Our models are evaluated extensively on 15+ datasets, including challenging domains such as tweets, Reddit posts, and emails. They achieve by far the **best performance** of all available sentence embedding methods. Further, we provide several **smaller models** that are **optimized for speed**.
123
+
124
+ [Β» Full list of pretrained models](https://www.sbert.net/docs/pretrained_models.html)
125
+
126
+ ## Application Examples
127
+
128
+ You can use this framework for:
129
+
130
+ - [Computing Sentence Embeddings](https://www.sbert.net/examples/applications/computing-embeddings/README.html)
131
+ - [Semantic Textual Similarity](https://www.sbert.net/docs/usage/semantic_textual_similarity.html)
132
+ - [Clustering](https://www.sbert.net/examples/applications/clustering/README.html)
133
+ - [Paraphrase Mining](https://www.sbert.net/examples/applications/paraphrase-mining/README.html)
134
+ - [Translated Sentence Mining](https://www.sbert.net/examples/applications/parallel-sentence-mining/README.html)
135
+ - [Semantic Search](https://www.sbert.net/examples/applications/semantic-search/README.html)
136
+ - [Retrieve & Re-Rank](https://www.sbert.net/examples/applications/retrieve_rerank/README.html)
137
+ - [Text Summarization](https://www.sbert.net/examples/applications/text-summarization/README.html)
138
+ - [Multilingual Image Search, Clustering & Duplicate Detection](https://www.sbert.net/examples/applications/image-search/README.html)
139
+
140
+ and many more use-cases.
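+
+ For instance, semantic search over a small corpus can be sketched with the library's `util.semantic_search` helper (the corpus and query below are illustrative):
+
+ ````python
+ from sentence_transformers import SentenceTransformer, util
+
+ model = SentenceTransformer('all-MiniLM-L6-v2')
+ corpus = ['Revenue grew 12% year over year.', 'The CEO discussed supply chain risks.']
+ corpus_embeddings = model.encode(corpus, convert_to_tensor=True)
+ query_embedding = model.encode('How did sales develop?', convert_to_tensor=True)
+ # returns, per query, a ranked list of {'corpus_id': ..., 'score': ...} dicts
+ hits = util.semantic_search(query_embedding, corpus_embeddings, top_k=1)
+ print(corpus[hits[0][0]['corpus_id']], hits[0][0]['score'])
+ ````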
141
+
142
+ For all examples, see [examples/applications](https://github.com/UKPLab/sentence-transformers/tree/master/examples/applications).
143
+
144
+ ## Citing & Authors
145
+
146
+ If you find this repository helpful, feel free to cite our publication [Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks](https://arxiv.org/abs/1908.10084):
147
+
148
+ ```bibtex
149
+ @inproceedings{reimers-2019-sentence-bert,
150
+ title = "Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks",
151
+ author = "Reimers, Nils and Gurevych, Iryna",
152
+ booktitle = "Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing",
153
+ month = "11",
154
+ year = "2019",
155
+ publisher = "Association for Computational Linguistics",
156
+ url = "https://arxiv.org/abs/1908.10084",
157
+ }
158
+ ```
159
+
160
+ If you use one of the multilingual models, feel free to cite our publication [Making Monolingual Sentence Embeddings Multilingual using Knowledge Distillation](https://arxiv.org/abs/2004.09813):
161
+
162
+ ```bibtex
163
+ @inproceedings{reimers-2020-multilingual-sentence-bert,
164
+ title = "Making Monolingual Sentence Embeddings Multilingual using Knowledge Distillation",
165
+ author = "Reimers, Nils and Gurevych, Iryna",
166
+ booktitle = "Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing",
167
+ month = "11",
168
+ year = "2020",
169
+ publisher = "Association for Computational Linguistics",
170
+ url = "https://arxiv.org/abs/2004.09813",
171
+ }
172
+ ```
173
+
174
+ Please have a look at [Publications](https://www.sbert.net/docs/publications.html) for our different publications that are integrated into SentenceTransformers.
175
+
176
+ Contact person: [Nils Reimers](https://www.nils-reimers.de), [info@nils-reimers.de](mailto:info@nils-reimers.de)
177
+
178
+ https://www.ukp.tu-darmstadt.de/
179
+
180
+ Don't hesitate to send us an e-mail or report an issue, if something is broken (and it shouldn't be) or if you have further questions.
181
+
182
+ > This repository contains experimental software and is published for the sole purpose of giving additional background details on the respective publication.
sentence-transformers/eval_beir.py ADDED
@@ -0,0 +1,89 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import sys
8
+ import argparse
9
+ import torch
10
+ import logging
11
+ import json
12
+ import numpy as np
13
+ import os
14
+
15
+ import src.slurm
16
+ import src.contriever
17
+ import src.beir_utils
18
+ import src.utils
19
+ import src.dist_utils
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
+ def main(args):
26
+
27
+ src.slurm.init_distributed_mode(args)
28
+ src.slurm.init_signal_handler()
29
+
30
+ os.makedirs(args.output_dir, exist_ok=True)
31
+
32
+ logger = src.utils.init_logger(args)
33
+
34
+ model, tokenizer, _ = src.contriever.load_retriever(args.model_name_or_path)
35
+ model = model.cuda()
36
+ model.eval()
37
+ query_encoder = model
38
+ doc_encoder = model
39
+
40
+ logger.info("Start indexing")
41
+
42
+ metrics = src.beir_utils.evaluate_model(
43
+ query_encoder=query_encoder,
44
+ doc_encoder=doc_encoder,
45
+ tokenizer=tokenizer,
46
+ dataset=args.dataset,
47
+ batch_size=args.per_gpu_batch_size,
48
+ norm_query=args.norm_query,
49
+ norm_doc=args.norm_doc,
50
+ is_main=src.dist_utils.is_main(),
51
+ split="dev" if args.dataset == "msmarco" else "test",
52
+ score_function=args.score_function,
53
+ beir_dir=args.beir_dir,
54
+ save_results_path=args.save_results_path,
55
+ lower_case=args.lower_case,
56
+ normalize_text=args.normalize_text,
57
+ )
58
+
59
+ if src.dist_utils.is_main():
60
+ for key, value in metrics.items():
61
+ logger.info(f"{args.dataset} : {key}: {value:.1f}")
62
+
63
+
64
+ if __name__ == "__main__":
65
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
66
+
67
+ parser.add_argument("--dataset", type=str, help="Evaluation dataset from the BEIR benchmark")
68
+ parser.add_argument("--beir_dir", type=str, default="./", help="Directory to save and load beir datasets")
69
+ parser.add_argument("--text_maxlength", type=int, default=512, help="Maximum text length")
70
+
71
+ parser.add_argument("--per_gpu_batch_size", default=128, type=int, help="Batch size per GPU/CPU for indexing.")
72
+ parser.add_argument("--output_dir", type=str, default="./my_experiment", help="Output directory")
73
+ parser.add_argument("--model_name_or_path", type=str, help="Model name or path")
74
+ parser.add_argument(
75
+ "--score_function", type=str, default="dot", help="Metric used to compute similarity between two embeddings"
76
+ )
77
+ parser.add_argument("--norm_query", action="store_true", help="Normalize query representation")
78
+ parser.add_argument("--norm_doc", action="store_true", help="Normalize document representation")
79
+ parser.add_argument("--lower_case", action="store_true", help="lowercase query and document text")
80
+ parser.add_argument(
81
+ "--normalize_text", action="store_true", help="Apply function to normalize some common characters"
82
+ )
83
+ parser.add_argument("--save_results_path", type=str, default=None, help="Path to save result object")
84
+
85
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
86
+ parser.add_argument("--main_port", type=int, default=-1, help="Main port (for multi-node SLURM jobs)")
87
+
88
+ args, _ = parser.parse_known_args()
89
+ main(args)
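+
+ # Example invocation (model id and paths are hypothetical; adjust to your setup):
+ # python eval_beir.py --model_name_or_path facebook/contriever --dataset scifact --beir_dir ./beir_data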
sentence-transformers/evaluate_retrieved_passages.py ADDED
@@ -0,0 +1,66 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import argparse
8
+ import json
9
+ import logging
10
+ import glob
11
+
12
+ import numpy as np
13
+ import torch
14
+
15
+ import src.utils
16
+
17
+ from src.evaluation import calculate_matches
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+ def validate(data, workers_num):
22
+ match_stats = calculate_matches(data, workers_num)
23
+ top_k_hits = match_stats.top_k_hits
24
+
25
+ #logger.info('Validation results: top k documents hits %s', top_k_hits)
26
+ top_k_hits = [v / len(data) for v in top_k_hits]
27
+ #logger.info('Validation results: top k documents hits accuracy %s', top_k_hits)
28
+ return top_k_hits
29
+
30
+
31
+ def main(opt):
32
+ logger = src.utils.init_logger(opt, stdout_only=True)
33
+ datapaths = glob.glob(args.data)
34
+ r20, r100 = [], []
35
+ for path in datapaths:
36
+ data = []
37
+ with open(path, 'r') as fin:
38
+ for line in fin:
39
+ data.append(json.loads(line))
40
+ #data = json.load(fin)
41
+ answers = [ex['answers'] for ex in data]
42
+ top_k_hits = validate(data, args.validation_workers)
43
+ message = f"Evaluate results from {path}:"
44
+ for k in [5, 10, 20, 100]:
45
+ if k <= len(top_k_hits):
46
+ recall = 100 * top_k_hits[k-1]
47
+ if k == 20:
48
+ r20.append(f"{recall:.1f}")
49
+ if k == 100:
50
+ r100.append(f"{recall:.1f}")
51
+ message += f' R@{k}: {recall:.1f}'
52
+ logger.info(message)
53
+ print(datapaths)
54
+ print('\t'.join(r20))
55
+ print('\t'.join(r100))
56
+
57
+
58
+ if __name__ == '__main__':
59
+ parser = argparse.ArgumentParser()
60
+
61
+ parser.add_argument('--data', required=True, type=str, default=None)
62
+ parser.add_argument('--validation_workers', type=int, default=16,
63
+ help="Number of parallel processes to validate results")
64
+
65
+ args = parser.parse_args()
66
+ main(args)
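+
+ # Example invocation (path is hypothetical; adjust to your setup; --data accepts a glob):
+ # python evaluate_retrieved_passages.py --data "./retrieved/*.jsonl" --validation_workers 16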
sentence-transformers/finetuning.py ADDED
@@ -0,0 +1,249 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+
3
+ import pdb
4
+ import os
5
+ import time
6
+ import sys
7
+ import torch
8
+ from torch.utils.tensorboard import SummaryWriter
9
+ import logging
10
+ import json
11
+ import numpy as np
12
+ import torch.distributed as dist
13
+ from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
14
+
15
+ from src.options import Options
16
+ from src import data, beir_utils, slurm, dist_utils, utils, contriever, finetuning_data, inbatch
17
+
18
+ import train
19
+
20
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
+ def finetuning(opt, model, optimizer, scheduler, tokenizer, step):
26
+
27
+ run_stats = utils.WeightedAvgStats()
28
+
29
+ tb_logger = utils.init_tb_logger(opt.output_dir)
30
+
31
+ if hasattr(model, "module"):
32
+ eval_model = model.module
33
+ else:
34
+ eval_model = model
35
+ eval_model = eval_model.get_encoder()
36
+
37
+ train_dataset = finetuning_data.Dataset(
38
+ datapaths=opt.train_data,
39
+ negative_ctxs=opt.negative_ctxs,
40
+ negative_hard_ratio=opt.negative_hard_ratio,
41
+ negative_hard_min_idx=opt.negative_hard_min_idx,
42
+ normalize=opt.eval_normalize_text,
43
+ global_rank=dist_utils.get_rank(),
44
+ world_size=dist_utils.get_world_size(),
45
+ maxload=opt.maxload,
46
+ training=True,
47
+ )
48
+ collator = finetuning_data.Collator(tokenizer, passage_maxlength=opt.chunk_length)
49
+ train_sampler = RandomSampler(train_dataset)
50
+ train_dataloader = DataLoader(
51
+ train_dataset,
52
+ sampler=train_sampler,
53
+ batch_size=opt.per_gpu_batch_size,
54
+ drop_last=True,
55
+ num_workers=opt.num_workers,
56
+ collate_fn=collator,
57
+ )
58
+
59
+ train.eval_model(opt, eval_model, None, tokenizer, tb_logger, step)
60
+ evaluate(opt, eval_model, tokenizer, tb_logger, step)
61
+
62
+ epoch = 1
63
+
64
+ model.train()
65
+ prev_ids, prev_mask = None, None
66
+ while step < opt.total_steps:
67
+ logger.info(f"Start epoch {epoch}, number of batches: {len(train_dataloader)}")
68
+ for i, batch in enumerate(train_dataloader):
69
+ batch = {key: value.cuda() if isinstance(value, torch.Tensor) else value for key, value in batch.items()}
70
+ step += 1
71
+
72
+ train_loss, iter_stats = model(**batch, stats_prefix="train")
73
+ train_loss.backward()
74
+
75
+ if opt.optim == "sam" or opt.optim == "asam":
76
+ optimizer.first_step(zero_grad=True)
77
+
78
+ sam_loss, _ = model(**batch, stats_prefix="train/sam_opt")
79
+ sam_loss.backward()
80
+ optimizer.second_step(zero_grad=True)
81
+ else:
82
+ optimizer.step()
83
+ scheduler.step()
84
+ optimizer.zero_grad()
85
+
86
+ run_stats.update(iter_stats)
87
+
88
+ if step % opt.log_freq == 0:
89
+ log = f"{step} / {opt.total_steps}"
90
+ for k, v in sorted(run_stats.average_stats.items()):
91
+ log += f" | {k}: {v:.3f}"
92
+ if tb_logger:
93
+ tb_logger.add_scalar(k, v, step)
94
+ log += f" | lr: {scheduler.get_last_lr()[0]:0.3g}"
95
+ log += f" | Memory: {torch.cuda.max_memory_allocated()//1e9} GiB"
96
+
97
+ logger.info(log)
98
+ run_stats.reset()
99
+
100
+ if step % opt.eval_freq == 0:
101
+
102
+ train.eval_model(opt, eval_model, None, tokenizer, tb_logger, step)
103
+ evaluate(opt, eval_model, tokenizer, tb_logger, step)
104
+
105
+ if step % opt.save_freq == 0 and dist_utils.get_rank() == 0:
106
+ utils.save(
107
+ eval_model,
108
+ optimizer,
109
+ scheduler,
110
+ step,
111
+ opt,
112
+ opt.output_dir,
113
+ f"step-{step}",
114
+ )
115
+ model.train()
116
+
117
+ if step >= opt.total_steps:
118
+ break
119
+
120
+ epoch += 1
121
+
122
+
123
+ def evaluate(opt, model, tokenizer, tb_logger, step):
124
+ dataset = finetuning_data.Dataset(
125
+ datapaths=opt.eval_data,
126
+ normalize=opt.eval_normalize_text,
127
+ global_rank=dist_utils.get_rank(),
128
+ world_size=dist_utils.get_world_size(),
129
+ maxload=opt.maxload,
130
+ training=False,
131
+ )
132
+ collator = finetuning_data.Collator(tokenizer, passage_maxlength=opt.chunk_length)
133
+ sampler = SequentialSampler(dataset)
134
+ dataloader = DataLoader(
135
+ dataset,
136
+ sampler=sampler,
137
+ batch_size=opt.per_gpu_batch_size,
138
+ drop_last=False,
139
+ num_workers=opt.num_workers,
140
+ collate_fn=collator,
141
+ )
142
+
143
+ model.eval()
144
+ if hasattr(model, "module"):
145
+ model = model.module
146
+ correct_samples, total_samples, total_step = 0, 0, 0
147
+ all_q, all_g, all_n = [], [], []
148
+ with torch.no_grad():
149
+ for i, batch in enumerate(dataloader):
150
+ batch = {key: value.cuda() if isinstance(value, torch.Tensor) else value for key, value in batch.items()}
151
+
152
+ all_tokens = torch.cat([batch["g_tokens"], batch["n_tokens"]], dim=0)
153
+ all_mask = torch.cat([batch["g_mask"], batch["n_mask"]], dim=0)
154
+
155
+ q_emb = model(input_ids=batch["q_tokens"], attention_mask=batch["q_mask"], normalize=opt.norm_query)
156
+ all_emb = model(input_ids=all_tokens, attention_mask=all_mask, normalize=opt.norm_doc)
157
+
158
+ g_emb, n_emb = torch.split(all_emb, [len(batch["g_tokens"]), len(batch["n_tokens"])])
159
+
160
+ all_q.append(q_emb)
161
+ all_g.append(g_emb)
162
+ all_n.append(n_emb)
163
+
164
+ all_q = torch.cat(all_q, dim=0)
165
+ all_g = torch.cat(all_g, dim=0)
166
+ all_n = torch.cat(all_n, dim=0)
167
+
168
+ labels = torch.arange(0, len(all_q), device=all_q.device, dtype=torch.long)
169
+
170
+ all_sizes = dist_utils.get_varsize(all_g)
171
+ all_g = dist_utils.varsize_gather_nograd(all_g)
172
+ all_n = dist_utils.varsize_gather_nograd(all_n)
173
+ labels = labels + sum(all_sizes[: dist_utils.get_rank()])
174
+
175
+ scores_pos = torch.einsum("id, jd->ij", all_q, all_g)
176
+ scores_neg = torch.einsum("id, jd->ij", all_q, all_n)
177
+ scores = torch.cat([scores_pos, scores_neg], dim=-1)
178
+
179
+ argmax_idx = torch.argmax(scores, dim=1)
180
+ sorted_scores, indices = torch.sort(scores, descending=True)
181
+ isrelevant = indices == labels[:, None]
182
+ rs = [r.cpu().numpy().nonzero()[0] for r in isrelevant]
183
+ mrr = np.mean([1.0 / (r[0] + 1) if r.size else 0.0 for r in rs])
184
+
185
+ acc = (argmax_idx == labels).sum() / all_q.size(0)
186
+ acc, total = dist_utils.weighted_average(acc, all_q.size(0))
187
+ mrr, _ = dist_utils.weighted_average(mrr, all_q.size(0))
188
+ acc = 100 * acc
189
+
190
+ message = []
191
+ if dist_utils.is_main():
192
+ message = [f"eval acc: {acc:.2f}%", f"eval mrr: {mrr:.3f}"]
193
+ logger.info(" | ".join(message))
194
+ if tb_logger is not None:
195
+ tb_logger.add_scalar("eval_acc", acc, step)
196
+ tb_logger.add_scalar("mrr", mrr, step)
197
+
198
+
199
+ def main():
200
+ logger.info("Start")
201
+
202
+ options = Options()
203
+ opt = options.parse()
204
+
205
+ torch.manual_seed(opt.seed)
206
+ slurm.init_distributed_mode(opt)
207
+ slurm.init_signal_handler()
208
+
209
+ directory_exists = os.path.isdir(opt.output_dir)
210
+ if dist.is_initialized():
211
+ dist.barrier()
212
+ os.makedirs(opt.output_dir, exist_ok=True)
213
+ if not directory_exists and dist_utils.is_main():
214
+ options.print_options(opt)
215
+ if dist.is_initialized():
216
+ dist.barrier()
217
+ utils.init_logger(opt)
218
+
219
+ step = 0
220
+
221
+ retriever, tokenizer, retriever_model_id = contriever.load_retriever(opt.model_path, opt.pooling, opt.random_init)
222
+ opt.retriever_model_id = retriever_model_id
223
+ model = inbatch.InBatch(opt, retriever, tokenizer)
224
+
225
+ model = model.cuda()
226
+
227
+ optimizer, scheduler = utils.set_optim(opt, model)
228
+ # if dist_utils.is_main():
229
+ # utils.save(model, optimizer, scheduler, global_step, 0., opt, opt.output_dir, f"step-{0}")
230
+ logger.info(utils.get_parameters(model))
231
+
232
+ for name, module in model.named_modules():
233
+ if isinstance(module, torch.nn.Dropout):
234
+ module.p = opt.dropout
235
+
236
+ if torch.distributed.is_initialized():
237
+ model = torch.nn.parallel.DistributedDataParallel(
238
+ model,
239
+ device_ids=[opt.local_rank],
240
+ output_device=opt.local_rank,
241
+ find_unused_parameters=False,
242
+ )
243
+
244
+ logger.info("Start training")
245
+ finetuning(opt, model, optimizer, scheduler, tokenizer, step)
246
+
247
+
248
+ if __name__ == "__main__":
249
+ main()
sentence-transformers/generate_passage_embeddings.py ADDED
@@ -0,0 +1,124 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import os
8
+
9
+ import argparse
10
+ import csv
11
+ import logging
12
+ import pickle
13
+
14
+ import numpy as np
15
+ import torch
16
+
17
+ import transformers
18
+
19
+ import src.slurm
20
+ import src.contriever
21
+ import src.utils
22
+ import src.data
23
+ import src.normalize_text
24
+
25
+
26
+ def embed_passages(args, passages, model, tokenizer):
27
+ total = 0
28
+ allids, allembeddings = [], []
29
+ batch_ids, batch_text = [], []
30
+ with torch.no_grad():
31
+ for k, p in enumerate(passages):
32
+ batch_ids.append(p["id"])
33
+ if args.no_title or "title" not in p:
34
+ text = p["text"]
35
+ else:
36
+ text = p["title"] + " " + p["text"]
37
+ if args.lowercase:
38
+ text = text.lower()
39
+ if args.normalize_text:
40
+ text = src.normalize_text.normalize(text)
41
+ batch_text.append(text)
42
+
43
+ if len(batch_text) == args.per_gpu_batch_size or k == len(passages) - 1:
44
+
45
+ encoded_batch = tokenizer.batch_encode_plus(
46
+ batch_text,
47
+ return_tensors="pt",
48
+ max_length=args.passage_maxlength,
49
+ padding=True,
50
+ truncation=True,
51
+ )
52
+
53
+ encoded_batch = {k: v.cuda() for k, v in encoded_batch.items()}
54
+ embeddings = model(**encoded_batch)
55
+
56
+ embeddings = embeddings.cpu()
57
+ total += len(batch_ids)
58
+ allids.extend(batch_ids)
59
+ allembeddings.append(embeddings)
60
+
61
+ batch_text = []
62
+ batch_ids = []
63
+ if k % 100000 == 0 and k > 0:
64
+ print(f"Encoded passages {total}")
65
+
66
+ allembeddings = torch.cat(allembeddings, dim=0).numpy()
67
+ return allids, allembeddings
68
+
69
+
70
+ def main(args):
71
+ model, tokenizer, _ = src.contriever.load_retriever(args.model_name_or_path)
72
+ print(f"Model loaded from {args.model_name_or_path}.", flush=True)
73
+ model.eval()
74
+ model = model.cuda()
75
+ if not args.no_fp16:
76
+ model = model.half()
77
+
78
+ passages = src.data.load_passages(args.passages)
79
+
80
+ shard_size = len(passages) // args.num_shards
81
+ start_idx = args.shard_id * shard_size
82
+ end_idx = start_idx + shard_size
83
+ if args.shard_id == args.num_shards - 1:
84
+ end_idx = len(passages)
85
+
86
+ passages = passages[start_idx:end_idx]
87
+ print(f"Embedding generation for {len(passages)} passages from idx {start_idx} to {end_idx}.")
88
+
89
+ allids, allembeddings = embed_passages(args, passages, model, tokenizer)
90
+
91
+ save_file = os.path.join(args.output_dir, args.prefix + f"_{args.shard_id:02d}")
92
+ os.makedirs(args.output_dir, exist_ok=True)
93
+ print(f"Saving {len(allids)} passage embeddings to {save_file}.")
94
+ with open(save_file, mode="wb") as f:
95
+ pickle.dump((allids, allembeddings), f)
96
+
97
+ print(f"Total passages processed {len(allids)}. Written to {save_file}.")
98
+
99
+
100
+ if __name__ == "__main__":
101
+ parser = argparse.ArgumentParser()
102
+
103
+ parser.add_argument("--passages", type=str, default=None, help="Path to passages (.tsv file)")
104
+ parser.add_argument("--output_dir", type=str, default="wikipedia_embeddings", help="dir path to save embeddings")
105
+ parser.add_argument("--prefix", type=str, default="passages", help="prefix path to save embeddings")
106
+ parser.add_argument("--shard_id", type=int, default=0, help="Id of the current shard")
107
+ parser.add_argument("--num_shards", type=int, default=1, help="Total number of shards")
108
+ parser.add_argument(
109
+ "--per_gpu_batch_size", type=int, default=512, help="Batch size for the passage encoder forward pass"
110
+ )
111
+ parser.add_argument("--passage_maxlength", type=int, default=512, help="Maximum number of tokens in a passage")
112
+ parser.add_argument(
113
+ "--model_name_or_path", type=str, help="path to directory containing model weights and config file"
114
+ )
115
+ parser.add_argument("--no_fp16", action="store_true", help="inference in fp32")
116
+ parser.add_argument("--no_title", action="store_true", help="title not added to the passage body")
117
+ parser.add_argument("--lowercase", action="store_true", help="lowercase text before encoding")
118
+ parser.add_argument("--normalize_text", action="store_true", help="normalize text before encoding")
119
+
120
+ args = parser.parse_args()
121
+
122
+ src.slurm.init_distributed_mode(args)
123
+
124
+ main(args)
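+
+ # Example invocation (model id and paths are hypothetical; adjust to your setup):
+ # python generate_passage_embeddings.py --model_name_or_path facebook/contriever \
+ #     --passages passages.tsv --output_dir wikipedia_embeddings --shard_id 0 --num_shards 4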
sentence-transformers/index.rst ADDED
@@ -0,0 +1,189 @@
1
+ SentenceTransformers Documentation
2
+ =================================================
3
+
4
+ SentenceTransformers is a Python framework for state-of-the-art sentence, text and image embeddings. The initial work is described in our paper `Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks <https://arxiv.org/abs/1908.10084>`_.
5
+
6
+ You can use this framework to compute sentence / text embeddings for more than 100 languages. These embeddings can then be compared, e.g., with cosine similarity to find sentences with a similar meaning. This can be useful for `semantic textual similarity <docs/usage/semantic_textual_similarity.html>`_, `semantic search <examples/applications/semantic-search/README.html>`_, or `paraphrase mining <examples/applications/paraphrase-mining/README.html>`_.
7
+
8
+ The framework is based on `PyTorch <https://pytorch.org/>`_ and `Transformers <https://huggingface.co/transformers/>`_ and offers a large collection of `pre-trained models <docs/pretrained_models.html>`_ tuned for various tasks. Further, it is easy to `fine-tune your own models <docs/training/overview.html>`_.
9
+
10
+
11
+ Installation
12
+ =================================================
13
+
14
+ You can install it using pip:
15
+
16
+ .. code-block:: bash
17
+
18
+ pip install -U sentence-transformers
19
+
20
+
21
+ We recommend **Python 3.6** or higher, and at least **PyTorch 1.6.0**. See `installation <docs/installation.html>`_ for further installation options, especially if you want to use a GPU.
22
+
23
+
24
+
25
+ Usage
26
+ =================================================
27
+ The usage is as simple as:
28
+
29
+ .. code-block:: python
30
+
31
+ from sentence_transformers import SentenceTransformer
32
+ model = SentenceTransformer('all-MiniLM-L6-v2')
33
+
34
+ #Our sentences we like to encode
35
+ sentences = ['This framework generates embeddings for each input sentence',
36
+ 'Sentences are passed as a list of strings.',
37
+ 'The quick brown fox jumps over the lazy dog.']
38
+
39
+ #Sentences are encoded by calling model.encode()
40
+ embeddings = model.encode(sentences)
41
+
42
+ #Print the embeddings
43
+ for sentence, embedding in zip(sentences, embeddings):
44
+ print("Sentence:", sentence)
45
+ print("Embedding:", embedding)
46
+ print("")
47
+
48
+
49
+
50
+
51
+ Performance
52
+ =========================
53
+
54
+ Our models are evaluated extensively and achieve state-of-the-art performance on various tasks. Further, the code is tuned to provide the highest possible speed. Have a look at `Pre-Trained Models <https://www.sbert.net/docs/pretrained_models.html#sentence-embedding-models/>`_ for an overview of available models and the respective performance on different tasks.
55
+
56
+
57
+
58
+
59
+
60
+
61
+ Contact
62
+ =========================
63
+
64
+ Contact person: Nils Reimers, info@nils-reimers.de
65
+
66
+ https://www.ukp.tu-darmstadt.de/
67
+
68
+
69
+ Don't hesitate to send us an e-mail or report an issue, if something is broken (and it shouldn't be) or if you have further questions.
70
+
71
+ *This repository contains experimental software and is published for the sole purpose of giving additional background details on the respective publication.*
72
+
73
+
74
+ Citing & Authors
75
+ =========================
76
+
77
+ If you find this repository helpful, feel free to cite our publication `Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks <https://arxiv.org/abs/1908.10084>`_:
78
+
79
+ .. code-block:: bibtex
80
+
81
+ @inproceedings{reimers-2019-sentence-bert,
82
+ title = "Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks",
83
+ author = "Reimers, Nils and Gurevych, Iryna",
84
+ booktitle = "Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing",
85
+ month = "11",
86
+ year = "2019",
87
+ publisher = "Association for Computational Linguistics",
88
+ url = "https://arxiv.org/abs/1908.10084",
89
+ }
90
+
91
+
92
+
93
+ If you use one of the multilingual models, feel free to cite our publication `Making Monolingual Sentence Embeddings Multilingual using Knowledge Distillation <https://arxiv.org/abs/2004.09813>`_:
94
+
95
+ .. code-block:: bibtex
96
+
97
+ @inproceedings{reimers-2020-multilingual-sentence-bert,
98
+ title = "Making Monolingual Sentence Embeddings Multilingual using Knowledge Distillation",
99
+ author = "Reimers, Nils and Gurevych, Iryna",
100
+ booktitle = "Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing",
101
+ month = "11",
102
+ year = "2020",
103
+ publisher = "Association for Computational Linguistics",
104
+ url = "https://arxiv.org/abs/2004.09813",
105
+ }
106
+
107
+
108
+
109
+ If you use the code for `data augmentation <https://github.com/UKPLab/sentence-transformers/tree/master/examples/training/data_augmentation>`_, feel free to cite our publication `Augmented SBERT: Data Augmentation Method for Improving Bi-Encoders for Pairwise Sentence Scoring Tasks <https://arxiv.org/abs/2010.08240>`_:
110
+
111
+ .. code-block:: bibtex
112
+
113
+ @inproceedings{thakur-2020-AugSBERT,
114
+ title = "Augmented {SBERT}: Data Augmentation Method for Improving Bi-Encoders for Pairwise Sentence Scoring Tasks",
115
+ author = "Thakur, Nandan and Reimers, Nils and Daxenberger, Johannes and Gurevych, Iryna",
116
+ booktitle = "Proceedings of the 2021 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies",
117
+ month = jun,
118
+ year = "2021",
119
+ address = "Online",
120
+ publisher = "Association for Computational Linguistics",
121
+ url = "https://www.aclweb.org/anthology/2021.naacl-main.28",
122
+ pages = "296--310",
123
+ }
124
+
125
+
126
+
127
+ .. toctree::
128
+ :maxdepth: 2
129
+ :caption: Overview
130
+
131
+ docs/installation
132
+ docs/quickstart
133
+ docs/pretrained_models
134
+ docs/pretrained_cross-encoders
135
+ docs/publications
136
+ docs/hugging_face
137
+
138
+ .. toctree::
139
+ :maxdepth: 2
140
+ :caption: Usage
141
+
142
+ examples/applications/computing-embeddings/README
143
+ docs/usage/semantic_textual_similarity
144
+ examples/applications/semantic-search/README
145
+ examples/applications/retrieve_rerank/README
146
+ examples/applications/clustering/README
147
+ examples/applications/paraphrase-mining/README
148
+ examples/applications/parallel-sentence-mining/README
149
+ examples/applications/cross-encoder/README
150
+ examples/applications/image-search/README
151
+
152
+ .. toctree::
153
+ :maxdepth: 2
154
+ :caption: Training
155
+
156
+ docs/training/overview
157
+ examples/training/multilingual/README
158
+ examples/training/distillation/README
159
+ examples/training/cross-encoder/README
160
+ examples/training/data_augmentation/README
161
+
162
+ .. toctree::
163
+ :maxdepth: 2
164
+ :caption: Training Examples
165
+
166
+ examples/training/sts/README
167
+ examples/training/nli/README
168
+ examples/training/paraphrases/README
169
+ examples/training/quora_duplicate_questions/README
170
+ examples/training/ms_marco/README
171
+
172
+ .. toctree::
173
+ :maxdepth: 2
174
+ :caption: Unsupervised Learning
175
+
176
+ examples/unsupervised_learning/README
177
+ examples/domain_adaptation/README
178
+
179
+ .. toctree::
180
+ :maxdepth: 1
181
+ :caption: Package Reference
182
+
183
+ docs/package_reference/SentenceTransformer
184
+ docs/package_reference/util
185
+ docs/package_reference/models
186
+ docs/package_reference/losses
187
+ docs/package_reference/evaluation
188
+ docs/package_reference/datasets
189
+ docs/package_reference/cross_encoder
sentence-transformers/passage_retrieval.py ADDED
@@ -0,0 +1,249 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import os
8
+ import argparse
9
+ import csv
10
+ import json
11
+ import logging
12
+ import pickle
13
+ import time
14
+ import glob
15
+ from pathlib import Path
16
+
17
+ import numpy as np
18
+ import torch
19
+ import transformers
20
+
21
+ import src.index
22
+ import src.contriever
23
+ import src.utils
24
+ import src.slurm
25
+ import src.data
26
+ from src.evaluation import calculate_matches
27
+ import src.normalize_text
28
+
29
+ os.environ["TOKENIZERS_PARALLELISM"] = "true"
30
+
31
+
32
+ def embed_queries(args, queries, model, tokenizer):
33
+ model.eval()
34
+ embeddings, batch_question = [], []
35
+ with torch.no_grad():
36
+
37
+ for k, q in enumerate(queries):
38
+ if args.lowercase:
39
+ q = q.lower()
40
+ if args.normalize_text:
41
+ q = src.normalize_text.normalize(q)
42
+ batch_question.append(q)
43
+
44
+ if len(batch_question) == args.per_gpu_batch_size or k == len(queries) - 1:
45
+
46
+ encoded_batch = tokenizer.batch_encode_plus(
47
+ batch_question,
48
+ return_tensors="pt",
49
+ max_length=args.question_maxlength,
50
+ padding=True,
51
+ truncation=True,
52
+ )
53
+ encoded_batch = {k: v.cuda() for k, v in encoded_batch.items()}
54
+ output = model(**encoded_batch)
55
+ embeddings.append(output.cpu())
56
+
57
+ batch_question = []
58
+
59
+ embeddings = torch.cat(embeddings, dim=0)
60
+ print(f"Questions embeddings shape: {embeddings.size()}")
61
+
62
+ return embeddings.numpy()
63
+
64
+
65
+ def index_encoded_data(index, embedding_files, indexing_batch_size):
66
+ allids = []
67
+ allembeddings = np.array([])
68
+ for i, file_path in enumerate(embedding_files):
69
+ print(f"Loading file {file_path}")
70
+ with open(file_path, "rb") as fin:
71
+ ids, embeddings = pickle.load(fin)
72
+
73
+ allembeddings = np.vstack((allembeddings, embeddings)) if allembeddings.size else embeddings
74
+ allids.extend(ids)
75
+ while allembeddings.shape[0] > indexing_batch_size:
76
+ allembeddings, allids = add_embeddings(index, allembeddings, allids, indexing_batch_size)
77
+
78
+ while allembeddings.shape[0] > 0:
79
+ allembeddings, allids = add_embeddings(index, allembeddings, allids, indexing_batch_size)
80
+
81
+ print("Data indexing completed.")
82
+
83
+
84
+ def add_embeddings(index, embeddings, ids, indexing_batch_size):
85
+ end_idx = min(indexing_batch_size, embeddings.shape[0])
86
+ ids_toadd = ids[:end_idx]
87
+ embeddings_toadd = embeddings[:end_idx]
88
+ ids = ids[end_idx:]
89
+ embeddings = embeddings[end_idx:]
90
+ index.index_data(ids_toadd, embeddings_toadd)
91
+ return embeddings, ids
92
+
93
+
94
+ def validate(data, workers_num):
95
+ match_stats = calculate_matches(data, workers_num)
96
+ top_k_hits = match_stats.top_k_hits
97
+
98
+ print(f"Validation results: top k documents hits {top_k_hits}")
99
+ top_k_hits = [v / len(data) for v in top_k_hits]
100
+ message = ""
101
+ for k in [5, 10, 20, 100]:
102
+ if k <= len(top_k_hits):
103
+ message += f"R@{k}: {top_k_hits[k-1]} "
104
+ print(message)
105
+ return match_stats.questions_doc_hits
106
+
107
+
108
+ def add_passages(data, passages, top_passages_and_scores):
109
+ # add passages to original data
110
+ merged_data = []
111
+ assert len(data) == len(top_passages_and_scores)
112
+ for i, d in enumerate(data):
113
+ results_and_scores = top_passages_and_scores[i]
114
+ docs = [passages[doc_id] for doc_id in results_and_scores[0]]
115
+ scores = [str(score) for score in results_and_scores[1]]
116
+ ctxs_num = len(docs)
117
+ d["ctxs"] = [
118
+ {
119
+ "id": results_and_scores[0][c],
120
+ "title": docs[c]["title"],
121
+ "text": docs[c]["text"],
122
+ "score": scores[c],
123
+ }
124
+ for c in range(ctxs_num)
125
+ ]
126
+
127
+
128
+ def add_hasanswer(data, hasanswer):
129
+ # add hasanswer to data
130
+ for i, ex in enumerate(data):
131
+ for k, d in enumerate(ex["ctxs"]):
132
+ d["hasanswer"] = hasanswer[i][k]
133
+
134
+
135
+ def load_data(data_path):
136
+ if data_path.endswith(".json"):
137
+ with open(data_path, "r") as fin:
138
+ data = json.load(fin)
139
+ elif data_path.endswith(".jsonl"):
140
+ data = []
141
+ with open(data_path, "r") as fin:
142
+ for k, example in enumerate(fin):
143
+ example = json.loads(example)
144
+ data.append(example)
145
+ else:
+ raise ValueError(f"Unsupported data file: {data_path}")
+ return data
146
+
147
+
148
+ def main(args):
149
+
150
+ print(f"Loading model from: {args.model_name_or_path}")
151
+ model, tokenizer, _ = src.contriever.load_retriever(args.model_name_or_path)
152
+ model.eval()
153
+ model = model.cuda()
154
+ if not args.no_fp16:
155
+ model = model.half()
156
+
157
+ index = src.index.Indexer(args.projection_size, args.n_subquantizers, args.n_bits)
158
+
159
+ # index all passages
160
+ input_paths = glob.glob(args.passages_embeddings)
161
+ input_paths = sorted(input_paths)
162
+ embeddings_dir = os.path.dirname(input_paths[0])
163
+ index_path = os.path.join(embeddings_dir, "index.faiss")
164
+ if args.save_or_load_index and os.path.exists(index_path):
165
+ index.deserialize_from(embeddings_dir)
166
+ else:
167
+ print(f"Indexing passages from files {input_paths}")
168
+ start_time_indexing = time.time()
169
+ index_encoded_data(index, input_paths, args.indexing_batch_size)
170
+ print(f"Indexing time: {time.time()-start_time_indexing:.1f} s.")
171
+ if args.save_or_load_index:
172
+ index.serialize(embeddings_dir)
173
+
174
+ # load passages
175
+ passages = src.data.load_passages(args.passages)
176
+ passage_id_map = {x["id"]: x for x in passages}
177
+
178
+ data_paths = glob.glob(args.data)
179
+ alldata = []
180
+ for path in data_paths:
181
+ data = load_data(path)
182
+ output_path = os.path.join(args.output_dir, os.path.basename(path))
183
+
184
+ queries = [ex["question"] for ex in data]
185
+ questions_embedding = embed_queries(args, queries, model, tokenizer)
186
+
187
+ # get top k results
188
+ start_time_retrieval = time.time()
189
+ top_ids_and_scores = index.search_knn(questions_embedding, args.n_docs)
190
+ print(f"Search time: {time.time()-start_time_retrieval:.1f} s.")
191
+
192
+ add_passages(data, passage_id_map, top_ids_and_scores)
193
+ hasanswer = validate(data, args.validation_workers)
194
+ add_hasanswer(data, hasanswer)
195
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
196
+ with open(output_path, "w") as fout:
197
+ for ex in data:
198
+ json.dump(ex, fout, ensure_ascii=False)
199
+ fout.write("\n")
200
+ print(f"Saved results to {output_path}")
201
+
202
+
203
+ if __name__ == "__main__":
204
+ parser = argparse.ArgumentParser()
205
+
206
+ parser.add_argument(
207
+ "--data",
208
+ required=True,
209
+ type=str,
210
+ default=None,
211
+ help=".json file containing question and answers, similar format to reader data",
212
+ )
213
+ parser.add_argument("--passages", type=str, default=None, help="Path to passages (.tsv file)")
214
+ parser.add_argument("--passages_embeddings", type=str, default=None, help="Glob path to encoded passages")
215
+ parser.add_argument(
216
+ "--output_dir", type=str, default=None, help="Results are written to outputdir with data suffix"
217
+ )
218
+ parser.add_argument("--n_docs", type=int, default=100, help="Number of documents to retrieve per questions")
219
+ parser.add_argument(
220
+ "--validation_workers", type=int, default=32, help="Number of parallel processes to validate results"
221
+ )
222
+ parser.add_argument("--per_gpu_batch_size", type=int, default=64, help="Batch size for question encoding")
223
+ parser.add_argument(
224
+ "--save_or_load_index", action="store_true", help="If enabled, save index and load index if it exists"
225
+ )
226
+ parser.add_argument(
227
+ "--model_name_or_path", type=str, help="path to directory containing model weights and config file"
228
+ )
229
+ parser.add_argument("--no_fp16", action="store_true", help="inference in fp32")
230
+ parser.add_argument("--question_maxlength", type=int, default=512, help="Maximum number of tokens in a question")
231
+ parser.add_argument(
232
+ "--indexing_batch_size", type=int, default=1000000, help="Batch size of the number of passages indexed"
233
+ )
234
+ parser.add_argument("--projection_size", type=int, default=768)
235
+ parser.add_argument(
236
+ "--n_subquantizers",
237
+ type=int,
238
+ default=0,
239
+ help="Number of subquantizer used for vector quantization, if 0 flat index is used",
240
+ )
241
+ parser.add_argument("--n_bits", type=int, default=8, help="Number of bits per subquantizer")
242
+ parser.add_argument("--lang", nargs="+")
243
+ parser.add_argument("--dataset", type=str, default="none")
244
+ parser.add_argument("--lowercase", action="store_true", help="lowercase text before encoding")
245
+ parser.add_argument("--normalize_text", action="store_true", help="normalize text")
246
+
247
+ args = parser.parse_args()
248
+ src.slurm.init_distributed_mode(args)
249
+ main(args)
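+
+ # Example invocation (model id and paths are hypothetical; adjust to your setup):
+ # python passage_retrieval.py --model_name_or_path facebook/contriever --passages passages.tsv \
+ #     --passages_embeddings "wikipedia_embeddings/passages_*" --data qa_data.json --output_dir ./retrieval_out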
sentence-transformers/preprocess.py ADDED
@@ -0,0 +1,68 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+
3
+ import os
4
+ import argparse
5
+ import torch
6
+
7
+ import transformers
8
+ from src.normalize_text import normalize
9
+
10
+
11
+ def save(tensor, split_path):
12
+ if not os.path.exists(os.path.dirname(split_path)):
13
+ os.makedirs(os.path.dirname(split_path))
14
+ with open(split_path, 'wb') as fout:
15
+ torch.save(tensor, fout)
16
+
17
+ def apply_tokenizer(path, tokenizer, normalize_text=False):
18
+ alltokens = []
19
+ lines = []
20
+ with open(path, "r", encoding="utf-8") as fin:
21
+ for k, line in enumerate(fin):
22
+ if normalize_text:
23
+ line = normalize(line)
24
+
25
+ lines.append(line)
26
+ if len(lines) > 1000000:
27
+ tokens = tokenizer.batch_encode_plus(lines, add_special_tokens=False)['input_ids']
28
+ tokens = [torch.tensor(x, dtype=torch.int) for x in tokens]
29
+ alltokens.extend(tokens)
30
+ lines = []
31
+
32
+ tokens = tokenizer.batch_encode_plus(lines, add_special_tokens=False)['input_ids']
33
+ tokens = [torch.tensor(x, dtype=torch.int) for x in tokens]
34
+ alltokens.extend(tokens)
35
+
36
+ alltokens = torch.cat(alltokens)
37
+ return alltokens
38
+
39
+ def tokenize_file(args):
40
+ filename = os.path.basename(args.datapath)
41
+ savepath = os.path.join(args.outdir, f"{filename}.pkl")
42
+ if os.path.exists(savepath):
43
+ if args.overwrite:
44
+ print(f"File {savepath} already exists, overwriting")
45
+ else:
46
+ print(f"File {savepath} already exists, exiting")
47
+ return
48
+ try:
49
+ tokenizer = transformers.AutoTokenizer.from_pretrained(args.tokenizer, local_files_only=True)
50
+ except:
51
+ tokenizer = transformers.AutoTokenizer.from_pretrained(args.tokenizer, local_files_only=False)
52
+ print(f"Encoding {args.datapath}...")
53
+ tokens = apply_tokenizer(args.datapath, tokenizer, normalize_text=args.normalize_text)
54
+
55
+ print(f"Saving at {savepath}...")
56
+ save(tokens, savepath)
57
+
58
+
59
+ if __name__ == '__main__':
60
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
61
+ parser.add_argument("--datapath", type=str)
62
+ parser.add_argument("--outdir", type=str)
63
+ parser.add_argument("--tokenizer", type=str)
64
+ parser.add_argument("--overwrite", action="store_true")
65
+ parser.add_argument("--normalize_text", action="store_true")
66
+
67
+ args, _ = parser.parse_known_args()
68
+ tokenize_file(args)
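+
+ # Example invocation (paths and tokenizer are hypothetical; adjust to your setup):
+ # python preprocess.py --datapath ./corpus.txt --outdir ./encoded --tokenizer bert-base-uncased --normalize_text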
sentence-transformers/requirements.txt ADDED
@@ -0,0 +1,11 @@
1
+ transformers>=4.6.0,<5.0.0
2
+ tokenizers>=0.10.3
3
+ tqdm
4
+ torch>=1.6.0
5
+ torchvision
6
+ numpy
7
+ scikit-learn
8
+ scipy
9
+ nltk
10
+ sentencepiece
11
+ huggingface-hub
sentence-transformers/setup.cfg ADDED
@@ -0,0 +1,2 @@
1
+ [metadata]
2
+ description-file = README.md
sentence-transformers/setup.py ADDED
@@ -0,0 +1,41 @@
1
+ from setuptools import setup, find_packages
2
+
3
+ with open("README.md", mode="r", encoding="utf-8") as readme_file:
4
+ readme = readme_file.read()
5
+
6
+
7
+
8
+ setup(
9
+ name="sentence-transformers",
10
+ version="2.2.2",
11
+ author="Nils Reimers",
12
+ author_email="info@nils-reimers.de",
13
+ description="Multilingual text embeddings",
14
+ long_description=readme,
15
+ long_description_content_type="text/markdown",
16
+ license="Apache License 2.0",
17
+ url="https://www.SBERT.net",
18
+ download_url="https://github.com/UKPLab/sentence-transformers/",
19
+ packages=find_packages(),
20
+ python_requires=">=3.6.0",
21
+ install_requires=[
22
+ 'transformers>=4.6.0,<5.0.0',
23
+ 'tqdm',
24
+ 'torch>=1.6.0',
25
+ 'torchvision',
26
+ 'numpy',
27
+ 'scikit-learn',
28
+ 'scipy',
29
+ 'nltk',
30
+ 'sentencepiece',
31
+ 'huggingface-hub>=0.4.0'
32
+ ],
33
+ classifiers=[
34
+ "Development Status :: 5 - Production/Stable",
35
+ "Intended Audience :: Science/Research",
36
+ "License :: OSI Approved :: Apache Software License",
37
+ "Programming Language :: Python :: 3.6",
38
+ "Topic :: Scientific/Engineering :: Artificial Intelligence"
39
+ ],
40
+ keywords="Transformer Networks BERT XLNet sentence embedding PyTorch NLP deep learning"
41
+ )
sentence-transformers/train.py ADDED
@@ -0,0 +1,195 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+
3
+ import os
4
+ import time
5
+ import sys
6
+ import torch
7
+ import logging
8
+ import json
9
+ import numpy as np
10
+ import random
11
+ import pickle
12
+
13
+ import torch.distributed as dist
14
+ from torch.utils.data import DataLoader, RandomSampler
15
+
16
+ from src.options import Options
17
+ from src import data, beir_utils, slurm, dist_utils, utils
18
+ from src import moco, inbatch
19
+
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+
24
+ def train(opt, model, optimizer, scheduler, step):
25
+
26
+ run_stats = utils.WeightedAvgStats()
27
+
28
+ tb_logger = utils.init_tb_logger(opt.output_dir)
29
+
30
+ logger.info("Data loading")
31
+ if isinstance(model, torch.nn.parallel.DistributedDataParallel):
32
+ tokenizer = model.module.tokenizer
33
+ else:
34
+ tokenizer = model.tokenizer
35
+ collator = data.Collator(opt=opt)
36
+ train_dataset = data.load_data(opt, tokenizer)
37
+ logger.warning(f"Data loading finished for rank {dist_utils.get_rank()}")
38
+
39
+ train_sampler = RandomSampler(train_dataset)
40
+ train_dataloader = DataLoader(
41
+ train_dataset,
42
+ sampler=train_sampler,
43
+ batch_size=opt.per_gpu_batch_size,
44
+ drop_last=True,
45
+ num_workers=opt.num_workers,
46
+ collate_fn=collator,
47
+ )
48
+
49
+ epoch = 1
50
+
51
+ model.train()
52
+ while step < opt.total_steps:
53
+ train_dataset.generate_offset()
54
+
55
+ logger.info(f"Start epoch {epoch}")
56
+ for i, batch in enumerate(train_dataloader):
57
+ step += 1
58
+
59
+ batch = {key: value.cuda() if isinstance(value, torch.Tensor) else value for key, value in batch.items()}
60
+ train_loss, iter_stats = model(**batch, stats_prefix="train")
61
+
62
+ train_loss.backward()
63
+ optimizer.step()
64
+
65
+ scheduler.step()
66
+ model.zero_grad()
67
+
68
+ run_stats.update(iter_stats)
69
+
70
+ if step % opt.log_freq == 0:
71
+ log = f"{step} / {opt.total_steps}"
72
+ for k, v in sorted(run_stats.average_stats.items()):
73
+ log += f" | {k}: {v:.3f}"
74
+ if tb_logger:
75
+ tb_logger.add_scalar(k, v, step)
76
+ log += f" | lr: {scheduler.get_last_lr()[0]:0.3g}"
77
+ log += f" | Memory: {torch.cuda.max_memory_allocated() / 2**30:.1f} GiB"
78
+
79
+ logger.info(log)
80
+ run_stats.reset()
81
+
82
+ if step % opt.eval_freq == 0:
83
+ if isinstance(model, torch.nn.parallel.DistributedDataParallel):
84
+ encoder = model.module.get_encoder()
85
+ else:
86
+ encoder = model.get_encoder()
87
+ eval_model(
88
+ opt, query_encoder=encoder, doc_encoder=encoder, tokenizer=tokenizer, tb_logger=tb_logger, step=step
89
+ )
90
+
91
+ if dist_utils.is_main():
92
+ utils.save(model, optimizer, scheduler, step, opt, opt.output_dir, "lastlog")
93
+
94
+ model.train()
95
+
96
+ if dist_utils.is_main() and step % opt.save_freq == 0:
97
+ utils.save(model, optimizer, scheduler, step, opt, opt.output_dir, f"step-{step}")
98
+
99
+ if step >= opt.total_steps:
100
+ break
101
+ epoch += 1
102
+
103
+
104
+ def eval_model(opt, query_encoder, doc_encoder, tokenizer, tb_logger, step):
105
+ for datasetname in opt.eval_datasets:
106
+ metrics = beir_utils.evaluate_model(
107
+ query_encoder,
108
+ doc_encoder,
109
+ tokenizer,
110
+ dataset=datasetname,
111
+ batch_size=opt.per_gpu_eval_batch_size,
112
+ norm_doc=opt.norm_doc,
113
+ norm_query=opt.norm_query,
114
+ beir_dir=opt.eval_datasets_dir,
115
+ score_function=opt.score_function,
116
+ lower_case=opt.lower_case,
117
+ normalize_text=opt.eval_normalize_text,
118
+ )
119
+
120
+ message = []
121
+ if dist_utils.is_main():
122
+ for metric in ["NDCG@10", "Recall@10", "Recall@100"]:
123
+ message.append(f"{datasetname}/{metric}: {metrics[metric]:.2f}")
124
+ if tb_logger is not None:
125
+ tb_logger.add_scalar(f"{datasetname}/{metric}", metrics[metric], step)
126
+ logger.info(" | ".join(message))
127
+
128
+
129
+ if __name__ == "__main__":
130
+ logger.info("Start")
131
+
132
+ options = Options()
133
+ opt = options.parse()
134
+
135
+ torch.manual_seed(opt.seed)
136
+ slurm.init_distributed_mode(opt)
137
+ slurm.init_signal_handler()
138
+
139
+ directory_exists = os.path.isdir(opt.output_dir)
140
+ if dist.is_initialized():
141
+ dist.barrier()
142
+ os.makedirs(opt.output_dir, exist_ok=True)
143
+ if not directory_exists and dist_utils.is_main():
144
+ options.print_options(opt)
145
+ if dist.is_initialized():
146
+ dist.barrier()
147
+ utils.init_logger(opt)
148
+
149
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
150
+
151
+ if opt.contrastive_mode == "moco":
152
+ model_class = moco.MoCo
153
+ elif opt.contrastive_mode == "inbatch":
154
+ model_class = inbatch.InBatch
155
+ else:
156
+ raise ValueError(f"contrastive mode: {opt.contrastive_mode} not recognised")
157
+
158
+ if not directory_exists and opt.model_path == "none":
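+ # three start-up cases below: fresh run (no checkpoint and no model_path),
+ # resume from an existing output_dir checkpoint, or warm-start from --model_path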
159
+ model = model_class(opt)
160
+ model = model.cuda()
161
+ optimizer, scheduler = utils.set_optim(opt, model)
162
+ step = 0
163
+ elif directory_exists:
164
+ model_path = os.path.join(opt.output_dir, "checkpoint", "latest")
165
+ model, optimizer, scheduler, opt_checkpoint, step = utils.load(
166
+ model_class,
167
+ model_path,
168
+ opt,
169
+ reset_params=False,
170
+ )
171
+ logger.info(f"Model loaded from {opt.output_dir}")
172
+ else:
173
+ model, optimizer, scheduler, opt_checkpoint, step = utils.load(
174
+ model_class,
175
+ opt.model_path,
176
+ opt,
177
+ reset_params=False if opt.continue_training else True,
178
+ )
179
+ if not opt.continue_training:
180
+ step = 0
181
+ logger.info(f"Model loaded from {opt.model_path}")
182
+
183
+ logger.info(utils.get_parameters(model))
184
+
185
+ if dist.is_initialized():
186
+ model = torch.nn.parallel.DistributedDataParallel(
187
+ model,
188
+ device_ids=[opt.local_rank],
189
+ output_device=opt.local_rank,
190
+ find_unused_parameters=False,
191
+ )
192
+ dist.barrier()
193
+
194
+ logger.info("Start training")
195
+ train(opt, model, optimizer, scheduler, step)