import whisper
import os
from pytube import YouTube
import pandas as pd
import plotly_express as px
import nltk
import plotly.graph_objects as go
from optimum.onnxruntime import ORTModelForSequenceClassification
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification, AutoModelForTokenClassification
from sentence_transformers import SentenceTransformer, CrossEncoder, util
import streamlit as st
import en_core_web_lg
import validators
import re
import itertools
import numpy as np
from bs4 import BeautifulSoup
import base64, time
from annotated_text import annotated_text

nltk.download('punkt')

from nltk import sent_tokenize

time_str = time.strftime("%d%m%Y-%H%M%S")
HTML_WRAPPER = """<div style="overflow-x: auto; border: 1px solid #e6e9ef; border-radius: 0.25rem; padding: 1rem;
margin-bottom: 2.5rem">{}</div> """

@st.experimental_singleton(suppress_st_warning=True)
def load_models():
    '''Load and cache the sentiment, summarization and NER pipelines plus the cross-encoder'''

    q_model = ORTModelForSequenceClassification.from_pretrained("nickmuchi/quantized-optimum-finbert-tone")
    ner_model = AutoModelForTokenClassification.from_pretrained("xlm-roberta-large-finetuned-conll03-english")
    q_tokenizer = AutoTokenizer.from_pretrained("nickmuchi/quantized-optimum-finbert-tone")
    ner_tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-large-finetuned-conll03-english")

    sent_pipe = pipeline("text-classification", model=q_model, tokenizer=q_tokenizer)
    sum_pipe = pipeline("summarization", model="facebook/bart-large-cnn", tokenizer="facebook/bart-large-cnn", clean_up_tokenization_spaces=True)
    ner_pipe = pipeline("ner", model=ner_model, tokenizer=ner_tokenizer, grouped_entities=True)
    cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-12-v2')

    return sent_pipe, sum_pipe, ner_pipe, cross_encoder

@st.experimental_singleton(suppress_st_warning=True)
def load_asr_model(asr_model_name):
    '''Load and cache the Whisper speech-to-text model'''

    asr_model = whisper.load_model(asr_model_name)

    return asr_model

@st.experimental_singleton(suppress_st_warning=True)
def load_sbert(model_name):
    '''Load and cache the SentenceTransformer embedding model'''

    sbert = SentenceTransformer(model_name)

    return sbert

@st.experimental_memo(suppress_st_warning=True)
def embed_text(query, corpus, embedding_model):
    '''Embed text and generate semantic search scores'''

    # Some embedding models expect a specific prompt format for queries and passages
    if embedding_model == 'intfloat/e5-base':
        search_input = 'query: ' + query
        passages_emb = ['passage: ' + sentence for sentence in corpus]

    elif embedding_model == 'hkunlp/instructor-base':
        search_input = [['Represent the Financial question; Input: ', query, 0]]
        passages_emb = [['Represent the Financial statement for retrieval; Input: ', sentence, 0] for sentence in corpus]

    else:
        search_input = query
        passages_emb = corpus

    # Bi-encoder: embed the query and the corpus, then retrieve the closest passages
    corpus_embedding = sbert.encode(passages_emb, convert_to_tensor=True)
    question_embedding = sbert.encode(search_input, convert_to_tensor=True)
    question_embedding = question_embedding.cpu()
    corpus_embedding = corpus_embedding.cpu()

    hits = util.semantic_search(question_embedding, corpus_embedding, top_k=2)
    hits = hits[0]

    # Cross-encoder: re-rank the retrieved passages against the query
    cross_inp = [[search_input, corpus[hit['corpus_id']]] for hit in hits]

    if embedding_model == 'hkunlp/instructor-base':
        # Strip the instructor prompt so the cross-encoder sees plain (query, passage) pairs
        result = []

        for sublist in cross_inp:
            question = sublist[0][0][1]
            document = sublist[1]
            result.append([question, document])

        cross_inp = result

    cross_scores = cross_encoder.predict(cross_inp)

    for idx in range(len(cross_scores)):
        hits[idx]['cross-score'] = cross_scores[idx]

    hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)

    return hits

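# Illustrative usage of embed_text (a sketch, not part of the app flow): it assumes the
# module-level `sbert` and `cross_encoder` loaded at the bottom of this file, and that
# `passages` came from chunk_long_text(); the query string and model name are placeholders.
#
#   passages = chunk_long_text(clean_text(transcript), 150, 1, 1)
#   hits = embed_text("What was the revenue guidance?", passages, 'all-MiniLM-L12-v2')
#   best_passage = passages[hits[0]['corpus_id']]
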
@st.experimental_singleton(suppress_st_warning=True)
def get_spacy():
    '''Load and cache the spaCy en_core_web_lg pipeline'''
    nlp = en_core_web_lg.load()
    return nlp

@st.experimental_memo(suppress_st_warning=True)
def inference(link, upload, _asr_model):
    '''Convert a YouTube video or audio upload to text'''

    if validators.url(link):
        yt = YouTube(link)
        title = yt.title
        path = yt.streams.filter(only_audio=True)[0].download(filename="audio.mp4")
        results = _asr_model.transcribe(path, task='transcribe', language='en')

        return results['text'], yt.title

    elif upload:
        results = _asr_model.transcribe(upload, task='transcribe', language='en')

        return results['text'], "Transcribed Earnings Audio"

@st.experimental_memo(suppress_st_warning=True)
def sentiment_pipe(earnings_text):
    '''Determine the sentiment of the text'''

    earnings_sentences = chunk_long_text(earnings_text, 150, 1, 1)
    earnings_sentiment = sent_pipe(earnings_sentences)

    return earnings_sentiment, earnings_sentences

@st.experimental_memo(suppress_st_warning=True)
def summarize_text(text_to_summarize, max_len, min_len):
    '''Summarize text with HF model'''

    summarized_text = sum_pipe(text_to_summarize,
                               max_length=max_len,
                               min_length=min_len,
                               clean_up_tokenization_spaces=True,
                               no_repeat_ngram_size=4,
                               encoder_no_repeat_ngram_size=3,
                               repetition_penalty=3.5,
                               num_beams=4,
                               early_stopping=True)
    summarized_text = ' '.join([summ['summary_text'] for summ in summarized_text])

    return summarized_text

@st.experimental_memo(suppress_st_warning=True)
def clean_text(text):
    '''Clean all text'''

    text = text.encode("ascii", "ignore").decode()
    text = re.sub(r"https*\S+", " ", text)
    text = re.sub(r"@\S+", " ", text)
    text = re.sub(r"#\S+", " ", text)
    text = re.sub(r"\s{2,}", " ", text)

    return text

@st.experimental_memo(suppress_st_warning=True)
def chunk_long_text(text, threshold, window_size=3, stride=2):
    '''Preprocess text and chunk for semantic search and sentiment analysis'''

    sentences = sent_tokenize(text)
    out = []

    # Split any sentence longer than `threshold` words into threshold-sized pieces
    for chunk in sentences:
        if len(chunk.split()) < threshold:
            out.append(chunk)
        else:
            words = chunk.split()
            num = int(len(words) / threshold)
            for i in range(0, num * threshold + 1, threshold):
                out.append(' '.join(words[i:threshold + i]))

    passages = []

    # Combine the pieces into overlapping passages with a sliding window
    for paragraph in [out]:
        for start_idx in range(0, len(paragraph), stride):
            end_idx = min(start_idx + window_size, len(paragraph))
            passages.append(" ".join(paragraph[start_idx:end_idx]))

    return passages

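# For reference: with the defaults window_size=3 and stride=2, chunk_long_text turns sentences
# [s0, s1, s2, s3, s4] into overlapping passages "s0 s1 s2", "s2 s3 s4", "s4". The sentiment
# pipeline above calls it with window_size=1 and stride=1, which keeps one chunk per passage.
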
@st.experimental_memo(suppress_st_warning=True)
def chunk_and_preprocess_text(text, thresh=500):
    """Chunk text longer than n tokens for summarization"""

    sentences = sent_tokenize(text)

    current_chunk = 0
    chunks = []

    for sentence in sentences:
        if len(chunks) == current_chunk + 1:
            if len(chunks[current_chunk]) + len(sentence.split(" ")) <= thresh:
                chunks[current_chunk].extend(sentence.split(" "))
            else:
                current_chunk += 1
                chunks.append(sentence.split(" "))
        else:
            chunks.append(sentence.split(" "))

    for chunk_id in range(len(chunks)):
        chunks[chunk_id] = " ".join(chunks[chunk_id])

    return chunks

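# Illustrative summarization flow (a sketch; the length values are placeholders):
#
#   chunks = chunk_and_preprocess_text(clean_text(transcript), thresh=500)
#   summary = summarize_text(chunks, max_len=130, min_len=30)
#
# sum_pipe accepts a list of chunks, and summarize_text joins the per-chunk summaries.
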
def summary_downloader(raw_text):
    '''Offer the summary text as a downloadable file'''

    b64 = base64.b64encode(raw_text.encode()).decode()
    new_filename = "new_text_file_{}_.txt".format(time_str)
    st.markdown("#### Download Summary as a File ###")
    href = f'<a href="data:file/txt;base64,{b64}" download="{new_filename}">Click to Download!!</a>'
    st.markdown(href, unsafe_allow_html=True)

@st.experimental_memo(suppress_st_warning=True)
def get_all_entities_per_sentence(text):
    '''Extract named entities per sentence using both spaCy and the XLM-RoBERTa NER pipeline'''

    doc = nlp(''.join(text))

    sentences = list(doc.sents)

    entities_all_sentences = []
    for sentence in sentences:
        entities_this_sentence = []

        # spaCy entities
        for entity in sentence.ents:
            entities_this_sentence.append(str(entity))

        # XLM-RoBERTa entities
        entities_xlm = [entity["word"] for entity in ner_pipe(str(sentence))]
        for entity in entities_xlm:
            entities_this_sentence.append(str(entity))

        entities_all_sentences.append(entities_this_sentence)

    return entities_all_sentences

@st.experimental_memo(suppress_st_warning=True)
def get_all_entities(text):
    all_entities_per_sentence = get_all_entities_per_sentence(text)
    return list(itertools.chain.from_iterable(all_entities_per_sentence))

@st.experimental_memo(suppress_st_warning=True)
def get_and_compare_entities(article_content, summary_output):
    '''Compare entities in the summary against entities in the source article'''

    all_entities_per_sentence = get_all_entities_per_sentence(article_content)
    entities_article = list(itertools.chain.from_iterable(all_entities_per_sentence))

    all_entities_per_sentence = get_all_entities_per_sentence(summary_output)
    entities_summary = list(itertools.chain.from_iterable(all_entities_per_sentence))

    matched_entities = []
    unmatched_entities = []
    for entity in entities_summary:
        # Exact (substring) match against the article entities
        if any(entity.lower() in substring_entity.lower() for substring_entity in entities_article):
            matched_entities.append(entity)
        # Otherwise accept a close semantic match between the entity embeddings
        elif any(
                np.inner(sbert.encode(entity, show_progress_bar=False),
                         sbert.encode(art_entity, show_progress_bar=False)) > 0.9 for
                art_entity in entities_article):
            matched_entities.append(entity)
        else:
            unmatched_entities.append(entity)

    # De-duplicate while preserving order
    matched_entities = list(dict.fromkeys(matched_entities))
    unmatched_entities = list(dict.fromkeys(unmatched_entities))

    # Drop entities that are substrings of other entities in the same list
    matched_entities_to_remove = []
    unmatched_entities_to_remove = []

    for entity in matched_entities:
        for substring_entity in matched_entities:
            if entity != substring_entity and entity.lower() in substring_entity.lower():
                matched_entities_to_remove.append(entity)

    for entity in unmatched_entities:
        for substring_entity in unmatched_entities:
            if entity != substring_entity and entity.lower() in substring_entity.lower():
                unmatched_entities_to_remove.append(entity)

    matched_entities_to_remove = list(dict.fromkeys(matched_entities_to_remove))
    unmatched_entities_to_remove = list(dict.fromkeys(unmatched_entities_to_remove))

    for entity in matched_entities_to_remove:
        matched_entities.remove(entity)
    for entity in unmatched_entities_to_remove:
        unmatched_entities.remove(entity)

    return matched_entities, unmatched_entities

@st.experimental_memo(suppress_st_warning=True)
def highlight_entities(article_content, summary_output):
    '''Mark summary entities: green if supported by the article, red if not'''

    markdown_start_red = "<mark class=\"entity\" style=\"background: rgb(238, 135, 135);\">"
    markdown_start_green = "<mark class=\"entity\" style=\"background: rgb(121, 236, 121);\">"
    markdown_end = "</mark>"

    matched_entities, unmatched_entities = get_and_compare_entities(article_content, summary_output)

    for entity in matched_entities:
        summary_output = re.sub(fr'({entity})(?![^rgb\(]*\))', markdown_start_green + entity + markdown_end, summary_output)

    for entity in unmatched_entities:
        summary_output = re.sub(fr'({entity})(?![^rgb\(]*\))', markdown_start_red + entity + markdown_end, summary_output)

    soup = BeautifulSoup(summary_output, features="html.parser")

    return HTML_WRAPPER.format(soup)

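# The string returned by highlight_entities is raw HTML (HTML_WRAPPER around the marked-up
# summary); it is meant to be rendered with something like st.markdown(html, unsafe_allow_html=True).
# Green marks are entities found in the source article, red marks are entities the summary added.
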
def display_df_as_table(model, top_k, score='score'):
    '''Display the df with text and scores as a table'''

    # Note: `passages` is expected to be defined in the calling scope (the app page sets it)
    df = pd.DataFrame([(hit[score], passages[hit['corpus_id']]) for hit in model[0:top_k]], columns=['Score', 'Text'])
    df['Score'] = round(df['Score'], 2)

    return df

def make_spans(text, results):
    '''Pair each sentence with its predicted sentiment label for annotated_text display'''
    results_list = []
    for i in range(len(results)):
        results_list.append(results[i]['label'])
    facts_spans = []
    facts_spans = list(zip(sent_tokenize(text), results_list))
    return facts_spans

def fin_ext(text):
    '''Classify each sentence and return annotated (sentence, label) spans'''
    # `remote_clx` is assumed to be a sentence-level classifier defined elsewhere in the app
    results = remote_clx(sent_tokenize(text))
    return make_spans(text, results)

# Load the shared models once at import time
nlp = get_spacy()
sent_pipe, sum_pipe, ner_pipe, cross_encoder = load_models()
sbert = load_sbert('all-MiniLM-L12-v2')
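
# Illustrative end-to-end flow (a sketch; the Streamlit pages drive the real app, and the
# URL and Whisper model size here are placeholders):
#
#   asr = load_asr_model('base')
#   text, title = inference("https://www.youtube.com/watch?v=<video-id>", None, asr)
#   sentiment, passages = sentiment_pipe(clean_text(text))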