crocidoc committed
Commit cc83a1d
1 Parent(s): 0f9a5e8

initial commit

Files changed (3)
  1. app.py +111 -0
  2. requirements.txt +6 -0
  3. text_transformation_tools.py +55 -0
app.py ADDED
@@ -0,0 +1,111 @@
+ import streamlit as st
+ import text_transformation_tools as ttt
+ from transformers import pipeline
+ import plotly.express as px
+
+
+ def read_pdf(file):
+     # Extract the paragraphs of the uploaded PDF via the helper module.
+     text = ttt.pdf_to_text(file)
+     return text
+
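+ # analyze_text runs a Hugging Face zero-shot classification pipeline over the
+ # extracted paragraphs (or, in 'sentences' mode, their individual sentences)
+ # and collects every passage whose score for a topic reaches the threshold.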
+ def analyze_text(paragraphs, topics, model, mode, min_chars, prob):
+     with st.spinner('Loading model...'):
+         classifier = pipeline('zero-shot-classification', model=model)
+
+     # One bucket of relevant passages per topic.
+     relevant_parts = {}
+     for topic in topics:
+         relevant_parts[topic] = []
+
+     # In 'sentences' mode, naively split each paragraph at full stops.
+     if mode == 'paragraphs':
+         text = paragraphs
+     elif mode == 'sentences':
+         text = []
+         for paragraph in paragraphs:
+             for sentence in paragraph.split('.'):
+                 text.append(sentence)
+
+     min_score = prob
+
+     with st.spinner('Analyzing text...'):
+         counter = 0
+         counter_rel = 0
+         counter_tot = len(text)
+
+         with st.empty():
+             for sequence_to_classify in text:
+                 # Drop line breaks and collapse double spaces before classifying.
+                 cleansed_sequence = sequence_to_classify.replace('\n', '').replace('  ', ' ')
+
+                 if len(cleansed_sequence) >= min_chars:
+                     classified = classifier(cleansed_sequence, topics, multi_label=True)
+
+                     for idx in range(len(classified['scores'])):
+                         if classified['scores'][idx] >= min_score:
+                             relevant_parts[classified['labels'][idx]].append(sequence_to_classify)
+                             counter_rel += 1
+
+                 counter += 1
+                 st.write('Analyzed {} of {} {}. Found {} relevant {} so far.'.format(counter, counter_tot, mode, counter_rel, mode))
+
+     return relevant_parts
+
+
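+ # Hugging Face model ids offered in the select box, mapped to the
+ # human-readable labels (with a rough speed/language hint) shown to the user.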
+ CHOICES = {
+     'facebook/bart-large-mnli': 'bart-large-mnli (very slow, english)',
+     'valhalla/distilbart-mnli-12-1': 'distilbart-mnli-12-1 (slow, english)',
+     'BaptisteDoyen/camembert-base-xnli': 'camembert-base-xnli (fast, french)',
+     'typeform/mobilebert-uncased-mnli': 'mobilebert-uncased-mnli (very fast, english)',
+     'Sahajtomar/German_Zeroshot': 'German_Zeroshot (slow, german)',
+     'MoritzLaurer/mDeBERTa-v3-base-xnli-multilingual-nli-2mil7': 'mDeBERTa-v3-base-xnli-multilingual-nli-2mil7 (fast, multilingual)'}
+
+
+ def format_func(option):
+     return CHOICES[option]
+
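+ # --- Streamlit page: file upload, topics, settings, and the analysis trigger ---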
+ st.header('File and topics')
+ uploaded_file = st.file_uploader('Choose your .pdf file', type='pdf')
+ topics = st.text_input(label='Enter comma-separated sustainability topics of interest.', value='human rights, sustainability')
+
+
+ st.header('Settings')
+ col1, col2 = st.columns(2)
+
+ with col1:
+     model = st.selectbox('Select the model used to analyze the pdf.', options=list(CHOICES.keys()), format_func=format_func, index=3)
+     mode = st.selectbox(label='Choose whether to detect relevant paragraphs or sentences.', options=['paragraphs', 'sentences'])
+ with col2:
+     min_chars = st.number_input(label='Minimum number of characters a text must have to be analyzed', min_value=0, max_value=500, value=20)
+     probability = st.number_input(label='Minimum probability of being relevant to accept (in percent)', min_value=0, max_value=100, value=90) / 100
+
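+ # Turn the comma-separated topics input into a list of stripped topic strings.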
+ topics = topics.split(',')
+ topics = [topic.strip() for topic in topics]
+
+ st.header('Analyze PDF')
+
+ if st.button('Analyze PDF'):
+     with st.spinner('Reading PDF...'):
+         text = read_pdf(uploaded_file)
+         page_count = ttt.count_pages(uploaded_file)
+         language = ttt.detect_language(' '.join(text))[0]
+         st.subheader('Overview')
+         st.write('Our pdf reader detected {} pages and {} paragraphs. We assume that the language of this text is "{}".'.format(page_count, len(text), language))
+
+     st.subheader('Analysis')
+     relevant_parts = analyze_text(text, topics, model, mode, min_chars, probability)
+
+     counts = [len(relevant_parts[topic]) for topic in topics]
+     fig = px.bar(x=topics, y=counts, title='Found {} of Relevance'.format(mode))
+     st.plotly_chart(fig)
+
+     st.subheader('Relevant Passages')
+     st.write(relevant_parts)
+
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ transformers[torch]
+ pdfminer.six
+ langid
+ pandas
+ streamlit
+ plotly
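+ # install with: pip install -r requirements.txt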
text_transformation_tools.py ADDED
@@ -0,0 +1,55 @@
+ '''
+ This module contains helper functions to load pdfs, extract their text and generate additional metadata.
+
+ It was initially created for the businessresponsibility.ch project of the Prototype Fund. For more
+ information visit https://github.com/bizres
+ '''
+ from pdfminer.high_level import extract_pages, extract_text
+
+ import langid
+ langid.set_languages(['en', 'de', 'fr', 'it'])
+
+ import pandas as pd
+
+ def pdf_to_text(file):
+     '''
+     This function extracts text from a pdf and splits it into paragraphs.
+
+     Parameters:
+         file: file object or path of the pdf
+     '''
+     text = extract_text(file)
+     paragraphs = text.split('\n\n')
+     return paragraphs
+
+
+ def detect_language(text):
+     '''
+     This function detects the language of a text using langid and returns a
+     (language, score) tuple.
+     '''
+     return langid.classify(text)
+
+
+ def count_pages(pdf_file):
+     '''
+     This function counts the pages of a pdf using pdfminer.
+     '''
+     return len(list(extract_pages(pdf_file)))
+
+
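+ # pdf_text_to_sections assumes the extracted text separates pages with blank
+ # lines; page numbers derived this way are therefore approximate.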
+ def pdf_text_to_sections(text):
+     '''
+     This function generates a pandas DataFrame from the extracted text. Each section
+     is provided with the page it is on and a section_index.
+     '''
+     sections = []
+     page_nr = 0
+     section_index = 0
+     for page in text.split('\n\n'):
+         page_nr += 1
+         for section in page.split('\n'):
+             sections.append([page_nr, section_index, section])
+             section_index += 1
+
+     return pd.DataFrame(sections, columns=['page', 'section_index', 'section_text'])
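+
+
+ # Minimal usage sketch (with a hypothetical file 'report.pdf'):
+ #
+ #     with open('report.pdf', 'rb') as f:
+ #         paragraphs = pdf_to_text(f)
+ #     language, score = detect_language(' '.join(paragraphs))
+ #     sections = pdf_text_to_sections('\n\n'.join(paragraphs))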