File size: 3,948 Bytes
cc83a1d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import streamlit as st
import text_transformation_tools as ttt
from transformers import pipeline
import plotly.express as px


def read_pdf(file):
    """Extract the text of a PDF via ``ttt.pdf_to_text``.

    Args:
        file: The PDF to read (in this app, the object returned by
            ``st.file_uploader``).

    Returns:
        Whatever ``ttt.pdf_to_text`` returns; callers in this file treat it
        as a sequence of paragraph strings (they ``len()`` it, iterate it and
        ``' '.join()`` it).
    """
    # Use the argument rather than the module-level ``uploaded_file`` global
    # (which the previous version read by mistake), so the function works for
    # whatever file it is given.
    return ttt.pdf_to_text(file)

def analyze_text(paragraphs, topics, model, mode, min_chars, prob):
    """Classify chunks of *paragraphs* against *topics* with a zero-shot model.

    Progress is streamed to the Streamlit page while the text is analyzed.

    Args:
        paragraphs: Sequence of paragraph strings (as produced by ``read_pdf``).
        topics: Candidate labels passed to the zero-shot classifier.
        model: Hugging Face model id used to build the pipeline.
        mode: ``'paragraphs'`` to classify each paragraph as-is, or
            ``'sentences'`` to split every paragraph on ``'.'`` first.
        min_chars: Chunks whose whitespace-normalized text is shorter than
            this are not classified.
        prob: Minimum score (0-1) a label needs for the chunk to be kept.

    Returns:
        Dict mapping every topic to the list of original (un-normalized)
        chunks that scored at least ``prob`` for it. Classification is
        multi-label, so one chunk may appear under several topics.

    Raises:
        ValueError: If *mode* is neither ``'paragraphs'`` nor ``'sentences'``.
    """
    # Validate the mode before paying for the (potentially large) model load;
    # previously an unknown mode surfaced later as a NameError on ``text``.
    if mode == 'paragraphs':
        text = paragraphs
    elif mode == 'sentences':
        text = [sentence for paragraph in paragraphs for sentence in paragraph.split('.')]
    else:
        raise ValueError("mode must be 'paragraphs' or 'sentences', got {!r}".format(mode))

    with st.spinner('Loading model'):
        classifier = pipeline('zero-shot-classification', model=model)

    relevant_parts = {topic: [] for topic in topics}

    with st.spinner('Analyzing text...'):
        counter = 0
        counter_rel = 0
        counter_tot = len(text)

        # Writing inside st.empty() replaces the previous progress line
        # instead of appending a new one per chunk.
        with st.empty():
            for sequence_to_classify in text:
                # Collapse every whitespace run (including line breaks) into a
                # single space, so words on adjacent lines are not glued together.
                cleansed_sequence = ' '.join(sequence_to_classify.split())

                # Skip empty chunks even when min_chars is 0: splitting on '.'
                # yields an empty trailing piece for text ending with a period.
                if cleansed_sequence and len(cleansed_sequence) >= min_chars:
                    classified = classifier(cleansed_sequence, topics, multi_label=True)

                    for label, score in zip(classified['labels'], classified['scores']):
                        if score >= prob:
                            relevant_parts[label].append(sequence_to_classify)
                            counter_rel += 1

                counter += 1

                st.write('Analyzed {} of {} {}. Found {} relevant {} so far.'.format(counter, counter_tot, mode, counter_rel, mode))

    return relevant_parts


# Hugging Face model ids offered in the UI, mapped to display labels of the
# form "<model name> (<relative speed>, <language>)".
CHOICES = {
    'facebook/bart-large-mnli': 'bart-large-mnli (very slow, english)',
    'valhalla/distilbart-mnli-12-1': 'distilbart-mnli-12-1 (slow, english)',
    # CamemBERT is a French model, fine-tuned on the French part of XNLI.
    'BaptisteDoyen/camembert-base-xnli': 'camembert-base-xnli (fast, french)',
    'typeform/mobilebert-uncased-mnli': 'mobilebert-uncased-mnli (very fast, english)',
    'Sahajtomar/German_Zeroshot': 'German_Zeroshot (slow, german)',
    'MoritzLaurer/mDeBERTa-v3-base-xnli-multilingual-nli-2mil7': 'mDeBERTa-v3-base-xnli-multilingual-nli-2mil7 (fast, multilingual)'}


def format_func(option):
    """Return the display label for model id *option* (used by ``st.selectbox``).

    Raises:
        KeyError: If *option* is not a key of ``CHOICES``.
    """
    return CHOICES[option]

st.header('File and topics')
uploaded_file = st.file_uploader('Choose your .pdf file', type="pdf")
topics = st.text_input(label='Enter comma-separated sustainability topics of interest.', value = 'human rights, sustainability')


st.header('Settings')
col1, col2 = st.columns(2)

with col1:
    model = st.selectbox("Select model used to analyze pdf.", options=list(CHOICES.keys()), format_func=format_func, index=3)
    mode = st.selectbox(label='Choose if you want to detect relevant paragraphs or sentences.', options=['paragraphs', 'sentences'])
with col2:
    min_chars = st.number_input(label='Minimum number of characters to analyze in a text', min_value=0, max_value=500, value=20)
    # The widget takes a percentage; convert it to the 0-1 score scale.
    probability = st.number_input(label='Minimum probability of being relevant to accept (in percent)', min_value=0, max_value=100, value=90)/100

# Split the comma-separated input and drop blank entries (empty field,
# trailing or doubled commas) so '' is never sent to the classifier as a label.
topics = [topic.strip() for topic in topics.split(',') if topic.strip()]

st.header('Analyze PDF')

if st.button('Analyze PDF'):
    # st.file_uploader returns None until a file has been uploaded.
    if uploaded_file is None:
        st.warning('Please upload a .pdf file first.')
    elif not topics:
        st.warning('Please enter at least one topic.')
    else:
        with st.spinner('Reading PDF...'):
            text = read_pdf(uploaded_file)
            # NOTE(review): the same file object is read twice here; confirm
            # ttt.count_pages does not depend on the stream position left by
            # ttt.pdf_to_text.
            page_count = ttt.count_pages(uploaded_file)
            language = ttt.detect_language(' '.join(text))[0]
        st.subheader('Overview')
        st.write('Our pdf reader detected {} pages and {} paragraphs. We assume that the language of this text is "{}".'.format(page_count, len(text), language))

        st.subheader('Analysis')
        relevant_parts = analyze_text(text, topics, model, mode, min_chars, probability)

        counts = [len(relevant_parts[topic]) for topic in topics]

        # ``mode`` is already plural ('paragraphs' / 'sentences'); the old
        # '{}s' template produced "paragraphss".
        fig = px.bar(x=topics, y=counts, title='Found {} of Relevance'.format(mode))

        st.plotly_chart(fig)

        st.subheader('Relevant Passages')
        st.write(relevant_parts)