# Page header captured when this file was scraped (not part of the program):
# nickmuchi's picture
# Update functions.py
# e7363fe
# raw
# history blame -- 13.3 kB
import whisper
import os
from pytube import YouTube
import pandas as pd
import plotly_express as px
import nltk
import plotly.graph_objects as go
from optimum.onnxruntime import ORTModelForSequenceClassification
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification, AutoModelForTokenClassification
from sentence_transformers import SentenceTransformer, CrossEncoder, util
import streamlit as st
import en_core_web_lg
import validators
import re
import itertools
import numpy as np
from bs4 import BeautifulSoup
import base64, time
from annotated_text import annotated_text
nltk.download('punkt')
from nltk import sent_tokenize
# Timestamp (day, month, year - hours, minutes, seconds) captured once at import
# time; summary_downloader() embeds it in the download file name.
time_str = time.strftime("%d%m%Y-%H%M%S")
# HTML container with a single `{}` placeholder that is filled via str.format();
# highlight_entities() uses it to wrap the annotated summary markup.
HTML_WRAPPER = """<div style="overflow-x: auto; border: 1px solid #e6e9ef; border-radius: 0.25rem; padding: 1rem;
margin-bottom: 2.5rem">{}</div> """
@st.experimental_singleton(suppress_st_warning=True)
def load_models():
    '''Build the sentiment, summarization and NER pipelines plus the cross-encoder.

    Returns:
        A 4-tuple ``(sent_pipe, sum_pipe, ner_pipe, cross_encoder)``.
    '''
    sentiment_checkpoint = "nickmuchi/quantized-optimum-finbert-tone"
    ner_checkpoint = "xlm-roberta-large-finetuned-conll03-english"
    summarizer_checkpoint = "facebook/bart-large-cnn"

    # Sentiment classifier: model loaded through Optimum's ONNX Runtime class.
    sentiment = pipeline(
        "text-classification",
        model=ORTModelForSequenceClassification.from_pretrained(sentiment_checkpoint),
        tokenizer=AutoTokenizer.from_pretrained(sentiment_checkpoint),
    )

    # Summarizer: the pipeline resolves model and tokenizer from the checkpoint name.
    summarizer = pipeline(
        "summarization",
        model=summarizer_checkpoint,
        tokenizer=summarizer_checkpoint,
        clean_up_tokenization_spaces=True,
    )

    # Named-entity recognizer with grouped entities enabled.
    entity_tagger = pipeline(
        "ner",
        model=AutoModelForTokenClassification.from_pretrained(ner_checkpoint),
        tokenizer=AutoTokenizer.from_pretrained(ner_checkpoint),
        grouped_entities=True,
    )

    reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-12-v2')

    return sentiment, summarizer, entity_tagger, reranker
@st.experimental_singleton(suppress_st_warning=True)
def load_asr_model(asr_model_name):
    '''Return the Whisper model loaded for the given model name.'''
    return whisper.load_model(asr_model_name)
@st.experimental_singleton(suppress_st_warning=True)
def load_sbert(model_name):
    '''Return a SentenceTransformer instance created from ``model_name``.'''
    return SentenceTransformer(model_name)
@st.experimental_memo(suppress_st_warning=True)
def embed_text(query,corpus,embedding_model):
    '''Embed text, run semantic search and re-rank the hits with the cross-encoder.

    Args:
        query: The question to search for.
        corpus: List of passage strings to search in.
        embedding_model: Model name; it only selects which query/passage
            prefixes are applied. Encoding itself is done by the module-level
            ``sbert`` and re-ranking by the module-level ``cross_encoder``.

    Returns:
        The top-2 semantic-search hits (dicts) with an added ``'cross-score'``
        key, sorted by that score in descending order.
    '''
    uses_instructor = embedding_model == 'hkunlp/instructor-base'

    # If model is e5 then apply prefixes to query and passage
    if embedding_model == 'intfloat/e5-base':
        search_input = 'query: ' + query
        passages_emb = ['passage: ' + sentence for sentence in corpus]
    elif uses_instructor:
        search_input = [['Represent the Financial question; Input: ', query, 0]]
        passages_emb = [['Represent the Financial statement for retrieval; Input: ', sentence, 0]
                        for sentence in corpus]
    else:
        search_input = query
        passages_emb = corpus

    # Embed corpus and question, then move both to CPU for the search
    corpus_embedding = sbert.encode(passages_emb, convert_to_tensor=True)
    question_embedding = sbert.encode(search_input, convert_to_tensor=True)
    question_embedding = question_embedding.cpu()
    corpus_embedding = corpus_embedding.cpu()

    # Calculate similarity scores and keep the hits for the first (only) query
    hits = util.semantic_search(question_embedding, corpus_embedding, top_k=2)[0]

    # ##### Re-Ranking #####
    # The cross-encoder scores plain (question, passage) text pairs. For the
    # instructor model ``search_input`` is an instruction triple, so the raw
    # query is used instead. The passage is always the full corpus string: the
    # earlier code indexed ``[1]`` into that string and passed a single
    # character to the re-ranker.
    rerank_query = query if uses_instructor else search_input
    cross_inp = [[rerank_query, corpus[hit['corpus_id']]] for hit in hits]
    cross_scores = cross_encoder.predict(cross_inp)

    # Attach the cross-encoder score to each hit and sort by it
    for hit, cross_score in zip(hits, cross_scores):
        hit['cross-score'] = cross_score

    return sorted(hits, key=lambda hit: hit['cross-score'], reverse=True)
@st.experimental_singleton(suppress_st_warning=True)
def get_spacy():
    '''Load and return the ``en_core_web_lg`` spaCy pipeline.'''
    return en_core_web_lg.load()
@st.experimental_memo(suppress_st_warning=True)
def inference(link, upload, _asr_model):
    '''Convert Youtube video or Audio upload to text.

    Args:
        link: Candidate YouTube URL; used when ``validators.url`` accepts it.
        upload: Audio input handed to the ASR model when ``link`` is not a URL.
        _asr_model: Object exposing ``transcribe``. The leading underscore
            follows Streamlit's convention for arguments that the memo cache
            should not hash.

    Returns:
        ``(transcript_text, title)``. Implicitly returns ``None`` when ``link``
        is not a valid URL and ``upload`` is falsy.
    '''
    if validators.url(link):
        yt = YouTube(link)
        # Download the first audio-only stream to a fixed local file name.
        path = yt.streams.filter(only_audio=True)[0].download(filename="audio.mp4")
        results = _asr_model.transcribe(path, task='transcribe', language='en')
        return results['text'], yt.title
    elif upload:
        # Was misspelled ``trasncribe``, which raised AttributeError for uploads.
        results = _asr_model.transcribe(upload, task='transcribe', language='en')
        return results['text'], "Transcribed Earnings Audio"
@st.experimental_memo(suppress_st_warning=True)
def sentiment_pipe(earnings_text):
    '''Classify the sentiment of each chunk of the earnings text.

    Returns:
        ``(sentiment_results, chunks)`` where ``chunks`` come from
        ``chunk_long_text(earnings_text, 150, 1, 1)`` and the results are what
        the module-level ``sent_pipe`` returns for them.
    '''
    chunks = chunk_long_text(earnings_text, 150, 1, 1)
    return sent_pipe(chunks), chunks
@st.experimental_memo(suppress_st_warning=True)
def summarize_text(text_to_summarize,max_len,min_len):
    '''Summarize text with the module-level summarization pipeline.

    Args:
        text_to_summarize: Input passed straight to ``sum_pipe``.
        max_len: Forwarded as ``max_length``.
        min_len: Forwarded as ``min_length``.

    Returns:
        The ``summary_text`` of every pipeline output joined with single spaces.
    '''
    generation_options = dict(
        max_length=max_len,
        min_length=min_len,
        clean_up_tokenization_spaces=True,
        no_repeat_ngram_size=4,
        encoder_no_repeat_ngram_size=3,
        repetition_penalty=3.5,
        num_beams=4,
        early_stopping=True,
    )
    outputs = sum_pipe(text_to_summarize, **generation_options)
    return ' '.join(output['summary_text'] for output in outputs)
@st.experimental_memo(suppress_st_warning=True)
def clean_text(text):
    '''Strip non-ASCII characters, URLs, mentions, hashtags and repeated whitespace.'''
    # Dropping everything that cannot be encoded as ASCII removes unicode symbols.
    cleaned = text.encode("ascii", "ignore").decode()

    # Applied in order; every match is replaced by a single space.
    patterns = (
        r"https*\S+",  # url
        r"@\S+",       # mentions
        r"#\S+",       # hashtags
        r"\s{2,}",     # runs of whitespace
    )
    for pattern in patterns:
        cleaned = re.sub(pattern, " ", cleaned)
    return cleaned
@st.experimental_memo(suppress_st_warning=True)
def chunk_long_text(text,threshold,window_size=3,stride=2):
    '''Preprocess text and chunk for semantic search and sentiment analysis.

    Args:
        text: Text to split into sentences with ``sent_tokenize``.
        threshold: Maximum number of words per piece (must be positive);
            sentences with at least this many words are cut into consecutive
            pieces of ``threshold`` words.
        window_size: Number of consecutive pieces joined into one passage.
        stride: Step between the starts of consecutive windows.

    Returns:
        List of passage strings.
    '''
    # Limit the length of each sentence to a threshold
    pieces = []
    for sentence in sent_tokenize(text):
        words = sentence.split()
        if len(words) < threshold:
            pieces.append(sentence)
        else:
            # Step over the actual word count. The earlier upper bound of
            # ``num*threshold+1`` produced an extra empty-string piece whenever
            # the word count was an exact multiple of ``threshold``.
            pieces.extend(
                ' '.join(words[start:start + threshold])
                for start in range(0, len(words), threshold)
            )

    # Combine pieces into windows of size window_size; slicing clamps the last
    # window to the end of the list.
    return [
        " ".join(pieces[start:start + window_size])
        for start in range(0, len(pieces), stride)
    ]
@st.experimental_memo(suppress_st_warning=True)
def chunk_and_preprocess_text(text,thresh=500):
    """Group sentences into chunks of at most ``thresh`` space-separated words.

    A sentence is added to the current chunk when the combined word count stays
    within ``thresh``; otherwise it starts a new chunk (so a single sentence
    longer than ``thresh`` becomes a chunk on its own).

    Returns:
        List of chunk strings, words re-joined with single spaces.
    """
    word_groups = []
    for sentence in sent_tokenize(text):
        words = sentence.split(" ")
        fits_in_last = bool(word_groups) and len(word_groups[-1]) + len(words) <= thresh
        if fits_in_last:
            word_groups[-1].extend(words)
        else:
            word_groups.append(words)
    return [" ".join(group) for group in word_groups]
def summary_downloader(raw_text):
    '''Render a Streamlit download link that carries ``raw_text`` as a base64 data URI.

    The file name embeds the module-level ``time_str`` timestamp.
    '''
    encoded = base64.b64encode(raw_text.encode()).decode()
    download_name = f"new_text_file_{time_str}_.txt"
    link_html = f'<a href="data:file/txt;base64,{encoded}" download="{download_name}">Click to Download!!</a>'

    st.markdown("#### Download Summary as a File ###")
    st.markdown(link_html, unsafe_allow_html=True)
@st.experimental_memo(suppress_st_warning=True)
def get_all_entities_per_sentence(text):
    '''Collect named entities for every sentence of ``text``.

    The text is joined into one string and parsed with the module-level spaCy
    ``nlp``. For each sentence, the spaCy entities come first, followed by the
    ``"word"`` of every entity returned by the module-level ``ner_pipe``.

    Returns:
        A list with one list of entity strings per sentence.
    '''
    doc = nlp(''.join(text))

    entities_all_sentences = []
    for sentence in doc.sents:
        # spaCy entities
        spacy_entities = [str(entity) for entity in sentence.ents]
        # XLM-RoBERTa entities from the transformers NER pipeline
        xlm_entities = [str(item["word"]) for item in ner_pipe(str(sentence))]
        entities_all_sentences.append(spacy_entities + xlm_entities)
    return entities_all_sentences
@st.experimental_memo(suppress_st_warning=True)
def get_all_entities(text):
    '''Return every entity found in ``text`` as one flat list, in sentence order.'''
    flattened = []
    for sentence_entities in get_all_entities_per_sentence(text):
        flattened.extend(sentence_entities)
    return flattened
@st.experimental_memo(suppress_st_warning=True)
def get_and_compare_entities(article_content,summary_output):
    '''Split the summary's entities into those supported by the article and those not.

    A summary entity counts as matched when it is a case-insensitive substring
    of any article entity, or when the inner product of its ``sbert`` embedding
    with some article entity's embedding exceeds 0.9. Each result list is
    de-duplicated (first occurrence kept), and entities that are a
    case-insensitive substring of another, different entity in the same list
    are dropped.

    Returns:
        ``(matched_entities, unmatched_entities)``.
    '''
    def _flatten(per_sentence):
        # One flat list from the per-sentence entity lists.
        return list(itertools.chain.from_iterable(per_sentence))

    def _is_supported(entity, article_entities):
        # Cheap substring test first; embedding similarity only as a fallback.
        needle = entity.lower()
        if any(needle in candidate.lower() for candidate in article_entities):
            return True
        return any(
            np.inner(sbert.encode(entity, show_progress_bar=False),
                     sbert.encode(candidate, show_progress_bar=False)) > 0.9
            for candidate in article_entities
        )

    def _drop_contained(entities):
        # De-duplicate in order, then remove entities contained in another one.
        unique = list(dict.fromkeys(entities))
        contained = {
            inner
            for inner in unique
            for outer in unique
            if inner != outer and inner.lower() in outer.lower()
        }
        return [entity for entity in unique if entity not in contained]

    entities_article = _flatten(get_all_entities_per_sentence(article_content))
    entities_summary = _flatten(get_all_entities_per_sentence(summary_output))

    matched, unmatched = [], []
    for entity in entities_summary:
        bucket = matched if _is_supported(entity, entities_article) else unmatched
        bucket.append(entity)

    return _drop_contained(matched), _drop_contained(unmatched)
@st.experimental_memo(suppress_st_warning=True)
def highlight_entities(article_content,summary_output):
    '''Wrap the summary's entities in colored ``<mark>`` tags.

    Entities that ``get_and_compare_entities`` reports as matched are marked
    green and unmatched ones red.

    Returns:
        The marked-up summary parsed with BeautifulSoup and inserted into
        ``HTML_WRAPPER``.
    '''
    markdown_start_red = "<mark class=\"entity\" style=\"background: rgb(238, 135, 135);\">"
    markdown_start_green = "<mark class=\"entity\" style=\"background: rgb(121, 236, 121);\">"
    markdown_end = "</mark>"

    matched_entities, unmatched_entities = get_and_compare_entities(article_content,summary_output)

    def _mark(text, entities, opening_tag):
        # Wrap every occurrence of each entity in ``opening_tag`` ... </mark>.
        for entity in entities:
            if not entity:
                # An empty pattern would match at every position of the text.
                continue
            # Escape the entity so characters such as '.', '+', '(' or '$' are
            # matched literally and cannot raise re.error. The lookahead is
            # kept from the earlier code; presumably it avoids touching the
            # "rgb(...)" text of tags inserted by previous replacements.
            pattern = re.escape(entity) + r'(?![^rgb\(]*\))'
            replacement = opening_tag + entity + markdown_end
            # A callable replacement is inserted verbatim, so backslashes in
            # the entity are not interpreted as template escapes.
            text = re.sub(pattern, lambda _match: replacement, text)
        return text

    summary_output = _mark(summary_output, matched_entities, markdown_start_green)
    summary_output = _mark(summary_output, unmatched_entities, markdown_start_red)

    soup = BeautifulSoup(summary_output, features="html.parser")
    return HTML_WRAPPER.format(soup)
def display_df_as_table(model,top_k,score='score',corpus=None):
    '''Display the df with text and scores as a table.

    Args:
        model: List of hit dicts, each with a ``'corpus_id'`` key and the key
            named by ``score``.
        top_k: Number of leading hits to include.
        score: Key of the hit dict to read the score from.
        corpus: Sequence of passage texts indexed by ``'corpus_id'``. When
            omitted, the module-level ``passages`` is used as before (and must
            exist), so existing callers keep working.

    Returns:
        A DataFrame with columns ``Score`` (rounded to 2 decimals) and ``Text``.
    '''
    # ``passages`` is only looked up when no corpus is passed explicitly.
    texts = passages if corpus is None else corpus
    rows = [(hit[score], texts[hit['corpus_id']]) for hit in model[0:top_k]]
    df = pd.DataFrame(rows, columns=['Score', 'Text'])
    df['Score'] = df['Score'].round(2)
    return df
def make_spans(text,results):
    '''Pair each sentence of ``text`` with the label of the corresponding result.

    Args:
        text: Text that is split into sentences with ``sent_tokenize``.
        results: Sequence of dicts with a ``'label'`` key, one per sentence.

    Returns:
        List of ``(sentence, label)`` tuples; ``zip`` stops at the shorter input.
    '''
    labels = [result['label'] for result in results]
    # ``sent_tokenize`` is the name imported from nltk at the top of the file;
    # the earlier ``sent_tokenizer`` was undefined and raised NameError.
    return list(zip(sent_tokenize(text), labels))
##Fiscal Sentiment by Sentence
def fin_ext(text):
    '''Classify every sentence of ``text`` and return ``(sentence, label)`` pairs.'''
    # NOTE(review): ``remote_clx`` is not defined in the visible part of this
    # file. It presumably should be a text-classification callable such as
    # ``sent_pipe`` -- confirm before relying on this function.
    # ``sent_tokenize`` replaces the undefined name ``sent_tokenizer``.
    results = remote_clx(sent_tokenize(text))
    return make_spans(text,results)
# Module-level model handles. The functions above read these names as globals,
# and the loaders they call are defined earlier in this file, so these
# assignments have to run after the definitions.
# spaCy pipeline used by get_all_entities_per_sentence().
nlp = get_spacy()
# Sentiment, summarization and NER pipelines plus the cross-encoder re-ranker.
sent_pipe, sum_pipe, ner_pipe, cross_encoder = load_models()
# Sentence-embedding model used by embed_text() and get_and_compare_entities().
sbert = load_sbert('all-MiniLM-L12-v2')