crocidoc committed
Commit 04de999
Parent: dae0599

initial commit

Files changed (3)
  1. app.py +111 -0
  2. requirements.txt +62 -0
  3. text_transformation_tools.py +55 -0
app.py ADDED
@@ -0,0 +1,111 @@
+ import streamlit as st
+ import text_transformation_tools as ttt
+ from transformers import pipeline
+ import plotly.express as px
+
+
+ def read_pdf(file):
+     # use the function argument rather than the global uploaded_file
+     text = ttt.pdf_to_text(file)
+     return text
+
+
+ def analyze_text(paragraphs, topics, model, mode, min_chars, prob):
+     with st.spinner('Loading model'):
+         classifier = pipeline('zero-shot-classification', model=model)
+
+     relevant_parts = {}
+     for topic in topics:
+         relevant_parts[topic] = []
+
+     if mode == 'paragraphs':
+         text = paragraphs
+     elif mode == 'sentences':
+         # naive sentence segmentation: split each paragraph on periods
+         text = []
+         for paragraph in paragraphs:
+             for sentence in paragraph.split('.'):
+                 text.append(sentence)
+
+     min_score = prob
+
+     with st.spinner('Analyzing text...'):
+         counter = 0
+         counter_rel = 0
+         counter_tot = len(text)
+
+         with st.empty():
+             for sequence_to_classify in text:
+                 # drop line breaks and collapse double spaces before classifying
+                 cleansed_sequence = sequence_to_classify.replace('\n', ' ').replace('  ', ' ')
+
+                 if len(cleansed_sequence) >= min_chars:
+                     classified = classifier(cleansed_sequence, topics, multi_label=True)
+
+                     for idx in range(len(classified['scores'])):
+                         if classified['scores'][idx] >= min_score:
+                             relevant_parts[classified['labels'][idx]].append(sequence_to_classify)
+                             counter_rel += 1
+
+                 counter += 1
+                 st.write('Analyzed {} of {} {}. Found {} relevant {} so far.'.format(counter, counter_tot, mode, counter_rel, mode))
+
+     return relevant_parts
+
+
+ CHOICES = {
+     'facebook/bart-large-mnli': 'bart-large-mnli (very slow, english)',
+     'valhalla/distilbart-mnli-12-1': 'distilbart-mnli-12-1 (slow, english)',
+     'BaptisteDoyen/camembert-base-xnli': 'camembert-base-xnli (fast, french)',
+     'typeform/mobilebert-uncased-mnli': 'mobilebert-uncased-mnli (very fast, english)',
+     'Sahajtomar/German_Zeroshot': 'German_Zeroshot (slow, german)',
+     'MoritzLaurer/mDeBERTa-v3-base-xnli-multilingual-nli-2mil7': 'mDeBERTa-v3-base-xnli-multilingual-nli-2mil7 (fast, multilingual)'}
+
+
+ def format_func(option):
+     return CHOICES[option]
+
+
+ st.header('File and topics')
+ uploaded_file = st.file_uploader('Choose your .pdf file', type='pdf')
+ topics = st.text_input(label='Enter comma-separated sustainability topics of interest.', value='human rights, sustainability')
+
+ st.header('Settings')
+ col1, col2 = st.columns(2)
+
+ with col1:
+     model = st.selectbox('Select the model used to analyze the PDF.', options=list(CHOICES.keys()), format_func=format_func, index=3)
+     mode = st.selectbox(label='Choose whether to detect relevant paragraphs or sentences.', options=['paragraphs', 'sentences'])
+ with col2:
+     min_chars = st.number_input(label='Minimum number of characters a text must have to be analyzed', min_value=0, max_value=500, value=20)
+     probability = st.number_input(label='Minimum probability of being relevant to accept (in percent)', min_value=0, max_value=100, value=90) / 100
+
+ topics = [topic.strip() for topic in topics.split(',')]
+
+ st.header('Analyze PDF')
+
+ if st.button('Analyze PDF'):
+     with st.spinner('Reading PDF...'):
+         text = read_pdf(uploaded_file)
+         page_count = ttt.count_pages(uploaded_file)
+         language = ttt.detect_language(' '.join(text))[0]
+     st.subheader('Overview')
+     st.write('Our PDF reader detected {} pages and {} paragraphs. We assume that the language of this text is "{}".'.format(page_count, len(text), language))
+
+     st.subheader('Analysis')
+     relevant_parts = analyze_text(text, topics, model, mode, min_chars, probability)
+
+     counts = [len(relevant_parts[topic]) for topic in topics]
+     fig = px.bar(x=topics, y=counts, title='Found {} of Relevance'.format(mode))
+     st.plotly_chart(fig)
+
+     st.subheader('Relevant Passages')
+     st.write(relevant_parts)
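For reference, the zero-shot pipeline called in analyze_text returns a dict with parallel 'labels' and 'scores' lists sorted by descending score, which is what the loop above indexes into. A minimal sketch of a single call; the example sentence and score values are illustrative, not real output:

    from transformers import pipeline

    classifier = pipeline('zero-shot-classification', model='typeform/mobilebert-uncased-mnli')
    result = classifier('We reduced our CO2 emissions by 20 percent.',
                        ['sustainability', 'human rights'], multi_label=True)
    # result is roughly:
    # {'sequence': 'We reduced our CO2 emissions by 20 percent.',
    #  'labels': ['sustainability', 'human rights'],
    #  'scores': [0.95, 0.03]}   # illustrative values; with multi_label=True each score is independent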
requirements.txt ADDED
@@ -0,0 +1,62 @@
+ altair==4.2.0
+ attrs==22.1.0
+ blinker==1.5
+ cachetools==5.2.0
+ certifi==2022.9.14
+ cffi==1.15.1
+ charset-normalizer==2.1.1
+ click==8.1.3
+ commonmark==0.9.1
+ cryptography==38.0.1
+ decorator==5.1.1
+ entrypoints==0.4
+ filelock==3.8.0
+ gitdb==4.0.9
+ GitPython==3.1.27
+ huggingface-hub==0.9.1
+ idna==3.4
+ importlib-metadata==4.12.0
+ Jinja2==3.1.2
+ jsonschema==4.16.0
+ langid==1.1.6
+ MarkupSafe==2.1.1
+ numpy==1.23.3
+ packaging==21.3
+ pandas==1.5.0
+ pdfminer.six==20220524
+ Pillow==9.2.0
+ plotly==5.10.0
+ protobuf==3.20.1
+ pyarrow==9.0.0
+ pycparser==2.21
+ pydeck==0.8.0b3
+ Pygments==2.13.0
+ Pympler==1.0.1
+ PyMuPDF==1.20.2
+ pyparsing==3.0.9
+ pyrsistent==0.18.1
+ python-dateutil==2.8.2
+ pytz==2022.2.1
+ pytz-deprecation-shim==0.1.0.post0
+ PyYAML==6.0
+ regex==2022.9.13
+ requests==2.28.1
+ rich==12.5.1
+ semver==2.13.0
+ six==1.16.0
+ smmap==5.0.0
+ streamlit==1.13.0
+ tenacity==8.1.0
+ tokenizers==0.12.1
+ toml==0.10.2
+ toolz==0.12.0
+ torch==1.12.1
+ tornado==6.2
+ tqdm==4.64.1
+ transformers==4.22.1
+ typing_extensions==4.3.0
+ tzlocal==4.2
+ urllib3==1.26.12
+ validators==0.20.0
+ watchdog==2.1.9
+ zipp==3.8.1
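With these pins in place, one plausible way to set up and launch the app locally (assuming a working Python environment) is:

    pip install -r requirements.txt
    streamlit run app.py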
text_transformation_tools.py ADDED
@@ -0,0 +1,55 @@
+ '''
+ This module contains helper functions to load PDFs, extract their text and generate additional metadata.
+
+ It was initially created for the businessresponsibility.ch project of the Prototype Fund. For more
+ information visit https://github.com/bizres
+ '''
+ from pdfminer.high_level import extract_pages, extract_text
+ from pdfminer.layout import LTTextContainer
+
+ import fitz  # PyMuPDF
+
+ import langid
+ langid.set_languages(['en', 'de', 'fr', 'it'])
+
+ import pandas as pd
+
+
+ def pdf_to_text(file):
+     '''
+     This function extracts the text from a PDF and splits it into paragraphs.
+
+     Parameters:
+         file: path to the PDF, or a binary file object
+     '''
+     text = extract_text(file)
+     paragraphs = text.split('\n\n')
+     return paragraphs
+
+
+ def detect_language(text):
+     '''
+     This function detects the language of a text using langid.
+     '''
+     return langid.classify(text)
+
+
+ def count_pages(pdf_file):
+     '''
+     This function counts the pages of a PDF using pdfminer.
+     '''
+     return len(list(extract_pages(pdf_file)))
+
+
+ def pdf_text_to_sections(text):
+     '''
+     This function generates a pandas DataFrame from the extracted text. Each section
+     is annotated with the page it is on and a section_index.
+     '''
+     sections = []
+     page_nr = 0
+     section_index = 0
+     for page in text.split('\n\n'):  # pages are approximated by splitting on blank lines
+         page_nr += 1
+         for section in page.split('\n'):
+             sections.append([page_nr, section_index, section])
+             section_index += 1
+
+     return pd.DataFrame(sections, columns=['page', 'section_index', 'section_text'])
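As a quick usage sketch, these helpers can also be combined outside of Streamlit; 'report.pdf' below is a hypothetical local file:

    import text_transformation_tools as ttt

    paragraphs = ttt.pdf_to_text('report.pdf')               # list of paragraph strings
    pages = ttt.count_pages('report.pdf')                    # page count via pdfminer
    lang, score = ttt.detect_language(' '.join(paragraphs))  # e.g. ('en', <langid score>)
    print(pages, lang)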