| import streamlit as st | |
| import text_transformation_tools as ttt | |
| from transformers import pipeline | |
| import plotly.express as px | |
def read_pdf(file):
    """Extract the text of an uploaded PDF.

    Delegates to ``ttt.pdf_to_text``. The rest of this script treats the
    result as a sequence of paragraph strings (it is ``' '.join``-ed and its
    ``len`` is reported as the paragraph count).

    Args:
        file: The uploaded PDF object to read (e.g. the value returned by
            ``st.file_uploader``).

    Returns:
        Whatever ``ttt.pdf_to_text`` returns for *file*.
    """
    # Use the argument rather than the module-level ``uploaded_file`` global
    # so the function works for any file it is given.
    return ttt.pdf_to_text(file)
def _split_into_units(paragraphs, mode):
    """Return the list of text units to classify for *mode*.

    'paragraphs' returns *paragraphs* unchanged. 'sentences' splits every
    paragraph on '.', a deliberately simple heuristic that can yield empty or
    very short fragments; the ``min_chars`` filter in ``analyze_text`` skips
    those.

    Raises:
        ValueError: if *mode* is neither 'paragraphs' nor 'sentences'.
    """
    if mode == 'paragraphs':
        return paragraphs
    if mode == 'sentences':
        return [sentence for paragraph in paragraphs for sentence in paragraph.split('.')]
    raise ValueError("mode must be 'paragraphs' or 'sentences', got {!r}".format(mode))


def _normalize_whitespace(sequence):
    """Collapse every whitespace run (including newlines) to one space and trim the ends."""
    return ' '.join(sequence.split())


def analyze_text(paragraphs, topics, model, mode, min_chars, prob):
    """Classify paragraphs or sentences against *topics* with a zero-shot model.

    Progress is shown in the Streamlit UI while the text is analyzed.

    Args:
        paragraphs: Sequence of paragraph strings to analyze.
        topics: Candidate labels passed to the zero-shot classifier.
        model: Hugging Face model id used to build the pipeline.
        mode: 'paragraphs' to classify each paragraph, or 'sentences' to
            classify the '.'-separated pieces of every paragraph.
        min_chars: Units shorter than this, after whitespace normalization,
            are skipped.
        prob: Minimum score (0..1) a topic must reach for the unit to be
            recorded under that topic.

    Returns:
        dict mapping each topic to the list of original (un-normalized) units
        whose score for that topic was at least *prob*.

    Raises:
        ValueError: if *mode* is not 'paragraphs' or 'sentences'.
    """
    # Validate the mode before the comparatively slow model load.
    units = _split_into_units(paragraphs, mode)

    with st.spinner('Loading model'):
        classifier = pipeline('zero-shot-classification', model=model)

    relevant_parts = {topic: [] for topic in topics}

    with st.spinner('Analyzing text...'):
        analyzed = 0
        relevant = 0  # number of units matching at least one topic
        total = len(units)
        # Each st.write inside st.empty() replaces the previous progress line.
        with st.empty():
            for sequence in units:
                cleansed = _normalize_whitespace(sequence)
                if len(cleansed) >= min_chars:
                    # multi_label=True scores each topic independently, so
                    # one unit can be relevant to several topics.
                    classified = classifier(cleansed, topics, multi_label=True)
                    matched = False
                    for label, score in zip(classified['labels'], classified['scores']):
                        if score >= prob:
                            relevant_parts[label].append(sequence)
                            matched = True
                    # Count units, not (unit, topic) hits, so the message
                    # below reports the number of relevant paragraphs/sentences.
                    if matched:
                        relevant += 1
                analyzed += 1
                st.write('Analyzed {} of {} {}. Found {} relevant {} so far.'.format(analyzed, total, mode, relevant, mode))
    return relevant_parts
# Hugging Face model id -> label shown in the model selectbox.
# The language tag in each label should match the data the model was
# fine-tuned on. Order matters: the selectbox picks its default by position
# (index=3 -> mobilebert-uncased-mnli).
CHOICES = {
    'facebook/bart-large-mnli': 'bart-large-mnli (very slow, english)',
    'valhalla/distilbart-mnli-12-1': 'distilbart-mnli-12-1 (slow, english)',
    # CamemBERT is a French model; this checkpoint is tuned on French XNLI.
    'BaptisteDoyen/camembert-base-xnli': 'camembert-base-xnli (fast, french)',
    'typeform/mobilebert-uncased-mnli': 'mobilebert-uncased-mnli (very fast, english)',
    'Sahajtomar/German_Zeroshot': 'German_Zeroshot (slow, german)',
    'MoritzLaurer/mDeBERTa-v3-base-xnli-multilingual-nli-2mil7': 'mDeBERTa-v3-base-xnli-multilingual-nli-2mil7 (fast, multilingual)',
}
def format_func(option):
    """Translate a model id into its human-readable selectbox label.

    Looks *option* up in ``CHOICES``; an unknown id raises ``KeyError``.
    """
    label = CHOICES[option]
    return label
# --- Inputs: PDF file, topics of interest, and analysis settings ---
st.header('File and topics')
uploaded_file = st.file_uploader('Choose your .pdf file', type="pdf")
topics = st.text_input(label='Enter comma separated sustainability topics of interest.', value='human rights, sustainability')

st.header('Settings')
col1, col2 = st.columns(2)
with col1:
    model = st.selectbox("Select model used to analyze pdf.", options=list(CHOICES.keys()), format_func=format_func, index=3)
    mode = st.selectbox(label='Choose if you want to detect relevant paragraphs or sentences.', options=['paragraphs', 'sentences'])
with col2:
    min_chars = st.number_input(label='Minimum number of characters to analyze in a text', min_value=0, max_value=500, value=20)
    # The widget takes a percentage; the classifier compares against 0..1 scores.
    probability = st.number_input(label='Minimum probability of being relevant to accept (in percent)', min_value=0, max_value=100, value=90) / 100

# Split on commas, trim whitespace, and drop empty entries (e.g. from a
# trailing comma) so no blank label is passed to the classifier.
topics = [topic.strip() for topic in topics.split(',') if topic.strip()]
# --- Analysis: read the PDF, classify its text, and show the results ---
st.header('Analyze PDF')
if st.button('Analyze PDF'):
    # st.file_uploader yields None until a file is chosen; stop early
    # instead of passing None into the PDF helpers.
    if uploaded_file is None:
        st.warning('Please upload a .pdf file before starting the analysis.')
    elif not topics:
        st.warning('Please enter at least one topic.')
    else:
        with st.spinner('Reading PDF...'):
            text = read_pdf(uploaded_file)
            # NOTE(review): the same upload object is read twice here; confirm
            # the ttt helpers rewind the stream if they consume it.
            page_count = ttt.count_pages(uploaded_file)
            language = ttt.detect_language(' '.join(text))[0]

        st.subheader('Overview')
        st.write('Our pdf reader detected {} pages and {} paragraphs. We assume that the language of this text is "{}".'.format(page_count, len(text), language))

        st.subheader('Analysis')
        relevant_parts = analyze_text(text, topics, model, mode, min_chars, probability)
        counts = [len(relevant_parts[topic]) for topic in topics]
        # ``mode`` is already plural ('paragraphs' / 'sentences'), so no
        # extra 's' is appended to the title.
        fig = px.bar(x=topics, y=counts, title='Found {} of Relevance'.format(mode))
        st.plotly_chart(fig)

        st.subheader('Relevant Passages')
        st.write(relevant_parts)