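"""
Raccoon Search: a Streamlit app that queries the Google Custom Search API,
scrapes the result pages, filters their sentences by the query keywords with
spaCy, and summarizes the filtered content with a DistilBART model.

Search results, scraped page content, and generated summaries are cached on
disk under search-results/, content/, and summaries/ respectively. The app
expects `google_search_api_key` and `google_search_engine_id` in st.secrets.
"""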
from os import makedirs, remove
from os.path import exists, dirname
from functools import cache
import json
import streamlit as st
from googleapiclient.discovery import build
from slugify import slugify
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
import uuid
import spacy
from spacy.matcher import PhraseMatcher
from beautiful_soup.beautiful_soup import get_url_content


@cache
def google_search_api_request( query ):
"""
Request Google Search API with query and return results.
"""
service = build(
"customsearch",
"v1",
developerKey=st.secrets["google_search_api_key"],
cache_discovery=False
)
# Exclude PDFs from search results.
query = query + ' -filetype:pdf'
return service.cse().list(
q=query,
cx=st.secrets["google_search_engine_id"],
num=5,
lr='lang_en', # lang_de
fields='items(title,link),searchInformation(totalResults)'
).execute()


def search_results( query ):
"""
Request Google Search API with query and return results. Results are cached in files.
"""
file_path = 'search-results/' + slugify( query ) + '.json'
# Create cache directory if it doesn't exist.
makedirs(dirname(file_path), exist_ok=True)
results = []
# Check if cache file exists.
if exists( file_path ):
with open( file_path, 'r' ) as results_file:
results = json.load( results_file )
else:
search_result = google_search_api_request( query )
# Check if search contains results.
if int( search_result['searchInformation']['totalResults'] ) > 0:
results = search_result['items']
# Save results to cache file.
with open( file_path, 'w' ) as results_file:
json.dump( results, results_file )
if len( results ) == 0:
raise Exception('No results found.')
return results


def get_summary( url, keywords ):
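    """
    Return a summary for the page at the given URL, filtered by the query keywords.

    Page content and the resulting summary are cached on disk (content/ and
    summaries/) keyed by a UUID derived from the URL. Content longer than
    roughly 200 words is summarized; shorter content is returned as-is.
    """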
url_id = uuid.uuid5( uuid.NAMESPACE_URL, url ).hex
file_path = 'summaries/' + url_id + '.json'
# Create cache directory if it doesn't exist.
makedirs(dirname(file_path), exist_ok=True)
# Check if cache file exists.
if exists( file_path ):
with open( file_path, 'r' ) as file:
summary = json.load( file )
else:
try:
strings = get_url_content( url )
content_cache = 'content/' + url_id + '.txt'
# Create cache directory if it doesn't exist.
makedirs(dirname(content_cache), exist_ok=True)
# Check if content cache file exists.
if exists( content_cache ):
with open( content_cache, 'r' ) as file:
content = file.read().rstrip()
else:
content = prep_chunks_summary( strings, keywords )
# Save content to cache file.
with open( content_cache, 'w' ) as file:
print(content.strip(), file=file)
            max_length = 200
            # Rudimentary word count as a rough proxy for token count.
            word_count = len( content.split(' ') )
            # If the content is longer than 200 words, summarize it.
            if word_count > max_length:
                # Generate a summary from the compiled content.
                summary = generate_summary( content, max_length )
else:
summary = [ { "summary_text": content } ]
except Exception as exception:
raise exception
# Save results to cache file.
with open( file_path, 'w' ) as file:
json.dump( summary, file )
return summary


def generate_summary( content, max_length ):
"""
Generate summary for content.
"""
try:
summarizer = pipeline("summarization", model="sshleifer/distilbart-cnn-12-6")
# https://huggingface.co/docs/transformers/v4.18.0/en/main_classes/pipelines#transformers.SummarizationPipeline
        summary = summarizer(content, max_length=max_length, min_length=30, do_sample=False, truncation=True)
except Exception as exception:
raise exception
return summary


def exception_notice( exception ):
"""
Helper function for exception notices.
"""
query_params = st.experimental_get_query_params()
# If debug mode is enabled, show exception else show warning.
if 'debug' in query_params.keys() and query_params['debug'][0] == 'true':
st.exception(exception)
else:
st.warning(str(exception))


# Unused function.
def is_keyword_in_string( keywords, string ):
"""
Checks if string contains keyword.
"""
for keyword in keywords:
if keyword in string:
return True
return False


def filter_sentences_by_keywords( strings, keywords ):
"""
Filter sentences by keywords using spacy.
"""
nlp = spacy.load("en_core_web_sm")
matcher = PhraseMatcher(nlp.vocab)
# Add keywords to matcher.
patterns = [nlp(keyword) for keyword in keywords]
matcher.add("QueryList", patterns)
sentences = []
for string in strings:
        # Skip strings shorter than 5 words.
string_length = len( string.split(' ') )
if string_length < 5:
continue
# Loop through sentences and check if any of the keywords are in the sentence.
doc = nlp(string)
for sentence in doc.sents:
matches = matcher(nlp(sentence.text))
for match_id, start, end in matches:
# If keyword is in sentence, add sentence to list.
if nlp.vocab.strings[match_id] in ["QueryList"]:
sentences.append(sentence.text)
if ( len(sentences) == 0 ):
raise Exception('No sentences with keywords found.')
return sentences


def split_content_into_chunks( sentences, tokenizer ):
"""
Split content into chunks.
"""
combined_length = 0
chunk = ""
chunks = []
for sentence in sentences:
        # Number of tokens in the sentence.
        length = len( tokenizer.tokenize( sentence ) )
        # If the combined token length plus the current sentence exceeds the max length, start a new chunk.
if combined_length + length > tokenizer.max_len_single_sentence:
chunks.append(chunk)
chunk = '' # Reset chunk.
combined_length = 0 # Reset token length.
# Add sentence to chunk.
combined_length += length
chunk += sentence + ' '
chunks.append(chunk)
return chunks


def prep_chunks_summary( strings, keywords ):
    """
    Filter the page sentences by keywords, split them into tokenizer-sized
    chunks, and summarize each chunk so the combined text fits the model input.
    """
try:
checkpoint = "sshleifer/distilbart-cnn-12-6"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
sentences = filter_sentences_by_keywords( strings, keywords )
chunks = split_content_into_chunks( sentences, tokenizer )
content = ''
number_of_chunks = len( chunks )
# Loop through chunks if there are more than one.
if number_of_chunks > 1:
# Calculate the max summary length based on the number of chunks so that the final combined text is not longer than max tokens.
max_length = int( tokenizer.max_len_single_sentence / number_of_chunks )
            # Create the summarization pipeline once and reuse it for every chunk.
            summarizer = pipeline("summarization", model=model, tokenizer=tokenizer)
            # https://huggingface.co/docs/transformers/v4.18.0/en/main_classes/pipelines#transformers.SummarizationPipeline
            # Loop through chunks and generate a summary for each.
            for chunk in chunks:
                # Number of tokens in the chunk.
                chunk_length = len( tokenizer.tokenize( chunk ) )
                # If the chunk is shorter than the overall max length, cap its summary at half the chunk length.
                chunk_max_length = max_length
                if chunk_length < max_length:
                    chunk_max_length = int( chunk_length / 2 )
                # Generate a summary for the chunk.
                chunk_summary = summarizer(chunk, max_length=chunk_max_length, min_length=10, do_sample=False, truncation=True)
for summary in chunk_summary:
content += summary['summary_text'] + ' '
elif number_of_chunks == 1:
content = chunks[0]
return content
except Exception as exception:
raise exception


def main():
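    """
    Render the search form, fetch results for the query, and show a summary,
    cache-management buttons, and a link for each result.
    """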
    st.title('Raccoon Search')
query = st.text_input('Search query')
query_params = st.experimental_get_query_params()
if query :
with st.spinner('Loading search results...'):
try:
results = search_results( query )
except Exception as exception:
exception_notice(exception)
return
# Count results.
number_of_results = len( results )
st.success( 'Found {} results for "{}".'.format( number_of_results, query ) )
# If debug mode is enabled, show search results in JSON.
if 'debug' in query_params.keys() and query_params['debug'][0] == 'true':
with st.expander("Search results JSON"):
if st.button('Delete search result cache', key=query + 'cache'):
remove( 'search-results/' + slugify( query ) + '.json' )
st.json( results )
progress_bar = st.progress(0)
st.header('Search results')
st.markdown('---')
for index, result in enumerate(results):
with st.container():
st.markdown('### ' + result['title'])
# Create a unique id for the result.
url_id = uuid.uuid5( uuid.NAMESPACE_URL, result['link'] ).hex
# List of query keywords.
keywords = query.split(' ')
try :
# Create summary of summarized content.
summary = get_summary( result['link'], keywords )
st.markdown(summary[0]['summary_text'])
except Exception as exception:
exception_notice(exception)
progress_bar.progress( ( index + 1 ) / number_of_results )
# Show links and buttons.
col1, col2, col3 = st.columns(3)
with col1:
st.markdown('[Website Link]({})'.format(result['link']))
with col2:
if st.button('Delete content from cache', key=url_id + 'content'):
                        remove( 'content/' + url_id + '.txt' )
with col3:
if st.button('Delete summary from cache', key=url_id + 'summary'):
remove( 'summaries/' + url_id + '.json' )
st.markdown('---')


if __name__ == '__main__':
main()