ferguch9 committed
Commit fef1930 · 1 Parent(s): 91bd3bd

feat: support pdf docs

Files changed (2):
  1. app.py +120 -49
  2. requirements.txt +1 -0
app.py CHANGED
@@ -1,64 +1,135 @@
  import streamlit as st
- from transformers import pipeline
+ import os
+ import PyPDF2
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

  checkpoint = "facebook/bart-large-cnn"

+
  @st.cache_resource
  def load_model():
      model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
      return model

+
  @st.cache_resource
  def load_tokenizer():
      tokenizer = AutoTokenizer.from_pretrained(checkpoint)
      return tokenizer

- model = load_model()
- tokenizer = load_tokenizer()
-
- st.title('Summarisation Tool')
- st.write(f"Performs basic summarisation of text and audio using the '{checkpoint}' model.")
-
- st.sidebar.title('Options')
- summary_balance = st.sidebar.select_slider(
-     'Output Summarisation Detail:',
-     options=['concise', 'balanced', 'full'],
-     value='balanced')
-
- textTab, docTab, audioTab = st.tabs(["Plain Text", "Text Document", "Audio File"])
-
- with textTab:
-     sentence = st.text_area('Paste text to be summarised:', help='Paste text into text area and hit Summarise button', height=300)
-     st.write(f"{len(sentence)} characters and {len(sentence.split())} words")
-
- with docTab:
-     st.text("Yet to be implemented...")
-
- with audioTab:
-     st.text("Yet to be implemented...")
-
- button = st.button("Summarise")
- st.divider()
-
- with st.spinner("Generating Summary..."):
-     if button and sentence:
-         chunks = [sentence]
-
-         text_words = len(sentence.split())
-         if summary_balance == 'concise':
-             min_multiplier = text_words * 0.1
-             max_multiplier = text_words * 0.3
-         elif summary_balance == 'full':
-             min_multiplier = text_words * 0.5
-             max_multiplier = text_words * 0.8
-         elif summary_balance == 'balanced':
-             min_multiplier = text_words * 0.2
-             max_multiplier = text_words * 0.5
-
-         print(f"min tokens {int(min_multiplier)}, max tokens {int(max_multiplier)}")
-         inputs = tokenizer([sentence], max_length=2048, return_tensors='pt', truncation=True)
-         summary_ids = model.generate(inputs['input_ids'], min_new_tokens=int(min_multiplier), max_new_tokens=int(max_multiplier), do_sample=False)
-         summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
-         st.write(summary)
-         st.write(f"{len(summary)} characters and {len(summary.split())} words")
+
+ def load_text_file(file):
+     bytes_data = file.getvalue()
+     text = bytes_data.decode("utf-8")
+     return text
+
+
+ def load_pdf_file(file):
+     pdf_reader = PyPDF2.PdfReader(file)
+     pdf_text = ""
+     for page_num in range(len(pdf_reader.pages)):
+         pdf_text += pdf_reader.pages[page_num].extract_text() or ""
+     return pdf_text
+
+
+ def split_text_into_chunks(text, max_chunk_length):
+     chunks = []
+     current_chunk = ""
+
+     for word in text.split():
+         if len(current_chunk) + len(word) + 1 <= max_chunk_length:
+             current_chunk += word + " "
+         else:
+             chunks.append(current_chunk.strip())
+             current_chunk = word + " "
+
+     if current_chunk:
+         chunks.append(current_chunk.strip())
+
+     return chunks
+
+
+ def main():
+     model = load_model()
+     print("Model's maximum sequence length:", model.config.max_position_embeddings)
+
+     tokenizer = load_tokenizer()
+     print("Tokenizer's maximum sequence length:", tokenizer.model_max_length)
+
+     st.title("Summarisation Tool")
+     st.write(
+         f"Performs basic summarisation of text and audio using the '{checkpoint}' model."
+     )
+
+     st.sidebar.title("Options")
+     summary_balance = st.sidebar.select_slider(
+         "Output Summarisation Detail:",
+         options=["concise", "balanced", "full"],
+         value="balanced",
+     )
+
+     textTab, docTab, audioTab = st.tabs(["Plain Text", "Text Document", "Audio File"])
+
+     with textTab:
+         sentence = st.text_area(
+             "Paste text to be summarised:",
+             help="Paste text into text area and hit Summarise button",
+             height=300,
+         )
+         st.write(f"{len(sentence)} characters and {len(sentence.split())} words")
+
+     with docTab:
+         uploaded_file = st.file_uploader("Select a file to be summarised:")
+         if uploaded_file is not None:
+             file_name = os.path.basename(uploaded_file.name)
+             _, file_ext = os.path.splitext(file_name)
+             if "pdf" in file_ext:
+                 sentence = load_pdf_file(uploaded_file)
+             else:
+                 sentence = load_text_file(uploaded_file)
+             st.write(f"{len(sentence)} characters and {len(sentence.split())} words")
+             st.write(sentence)
+
+     with audioTab:
+         st.text("Yet to be implemented...")
+
+     button = st.button("Summarise")
+     st.divider()
+
+     with st.spinner("Generating Summary..."):
+         if button and sentence:
+             chunks = split_text_into_chunks(sentence, 500)
+             print(chunks)
+
+             text_words = len(sentence.split())
+             if summary_balance == "concise":
+                 min_multiplier = text_words * 0.1
+                 max_multiplier = text_words * 0.3
+             elif summary_balance == "full":
+                 min_multiplier = text_words * 0.5
+                 max_multiplier = text_words * 0.8
+             elif summary_balance == "balanced":
+                 min_multiplier = text_words * 0.2
+                 max_multiplier = text_words * 0.5
+
+             print(f"min tokens {int(min_multiplier)}, max tokens {int(max_multiplier)}")
+             inputs = tokenizer(
+                 chunks,
+                 max_length=model.config.max_position_embeddings,
+                 return_tensors="pt",
+                 truncation=True,
+                 padding=True,
+             )
+             summary_ids = model.generate(
+                 inputs["input_ids"],
+                 min_new_tokens=int(min_multiplier),
+                 max_new_tokens=int(max_multiplier),
+                 do_sample=False,
+             )
+             summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
+             st.write(summary)
+             st.write(f"{len(summary)} characters and {len(summary.split())} words")
+
+
+ if __name__ == "__main__":
+     main()
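
A minimal standalone check of the new split_text_into_chunks helper (the function is copied verbatim from app.py; the driver underneath is illustrative). Note that max_chunk_length caps chunks by character count, so the 500 passed in main() means roughly 500 characters per chunk, not 500 tokens:

# Character-based chunker from app.py, copied verbatim.
def split_text_into_chunks(text, max_chunk_length):
    chunks = []
    current_chunk = ""

    for word in text.split():
        if len(current_chunk) + len(word) + 1 <= max_chunk_length:
            current_chunk += word + " "
        else:
            chunks.append(current_chunk.strip())
            current_chunk = word + " "

    if current_chunk:
        chunks.append(current_chunk.strip())

    return chunks


# Illustrative driver: ~1,800 characters of dummy input.
text = " ".join(["token"] * 300)
chunks = split_text_into_chunks(text, 500)
print(len(chunks), [len(c) for c in chunks])  # 4 chunks, each under 500 characters

One edge case worth knowing: a single word longer than max_chunk_length still ends up in its own oversized chunk.
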
requirements.txt CHANGED
@@ -5,3 +5,4 @@ torch
  torchvision
  torchaudio
  transformers
+ PyPDF2
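
The new dependency can be smoke-tested outside Streamlit with the same PyPDF2.PdfReader loop that load_pdf_file uses; sample.pdf below is a placeholder for any local file:

import PyPDF2

# Placeholder path; substitute any local PDF.
with open("sample.pdf", "rb") as f:
    reader = PyPDF2.PdfReader(f)
    # extract_text() can return None for image-only pages, hence the
    # `or ""` guard, mirroring load_pdf_file() in app.py.
    text = "".join(page.extract_text() or "" for page in reader.pages)

print(f"{len(reader.pages)} pages, {len(text)} characters extracted")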