update code
Browse files
app.py
CHANGED
@@ -63,7 +63,7 @@ def preproc_count(filepath, skipfirst, skiplast):
|
|
63 |
|
64 |
|
65 |
# llm pipeline
|
66 |
-
def llm_pipeline(tokenizer, base_model, input_text):
|
67 |
pipe_sum = pipeline(
|
68 |
"summarization",
|
69 |
model=base_model,
|
@@ -72,6 +72,7 @@ def llm_pipeline(tokenizer, base_model, input_text):
|
|
72 |
min_length=300,
|
73 |
truncation=True
|
74 |
)
|
|
|
75 |
print("Summarizing...")
|
76 |
result = pipe_sum(input_text)
|
77 |
summary = result[0]["summary_text"]
|
@@ -105,8 +106,14 @@ def main():
|
|
105 |
uploaded_file = st.file_uploader("Upload your PDF file", type=["pdf"])
|
106 |
if uploaded_file is not None:
|
107 |
st.subheader("Options")
|
108 |
-
col1, col2, col3 = st.columns([1, 1, 2])
|
109 |
with col1:
|
|
|
|
|
|
|
|
|
|
|
|
|
110 |
model_names = [
|
111 |
"T5-Small",
|
112 |
"BART",
|
@@ -121,13 +128,15 @@ def main():
|
|
121 |
model_max_length=1000,
|
122 |
trust_remote_code=True,
|
123 |
)
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
|
|
|
|
131 |
checkpoint = "MBZUAI/LaMini-Flan-T5-77M"
|
132 |
tokenizer = AutoTokenizer.from_pretrained(
|
133 |
checkpoint,
|
@@ -136,16 +145,18 @@ def main():
|
|
136 |
model_max_length=1000,
|
137 |
#cache_dir="model_cache"
|
138 |
)
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
|
|
|
|
145 |
st.write("Skip any pages?")
|
146 |
skipfirst = st.checkbox("Skip first page")
|
147 |
skiplast = st.checkbox("Skip last page")
|
148 |
-
with
|
149 |
st.write("Background information (links open in a new window)")
|
150 |
st.write(
|
151 |
"Model class: [T5-Small](https://huggingface.co/docs/transformers/main/en/model_doc/t5)"
|
@@ -170,7 +181,7 @@ def main():
|
|
170 |
with col2:
|
171 |
start = time.time()
|
172 |
with st.spinner("Summarizing..."):
|
173 |
-
summary = llm_pipeline(tokenizer, base_model, input_text)
|
174 |
postproc_text_length = postproc_count(summary)
|
175 |
end = time.time()
|
176 |
duration = end - start
|
|
|
63 |
|
64 |
|
65 |
# llm pipeline
|
66 |
+
def llm_pipeline(tokenizer, base_model, input_text, model_source):
|
67 |
pipe_sum = pipeline(
|
68 |
"summarization",
|
69 |
model=base_model,
|
|
|
72 |
min_length=300,
|
73 |
truncation=True
|
74 |
)
|
75 |
+
print("Model source: %s" %(model_source))
|
76 |
print("Summarizing...")
|
77 |
result = pipe_sum(input_text)
|
78 |
summary = result[0]["summary_text"]
|
|
|
106 |
uploaded_file = st.file_uploader("Upload your PDF file", type=["pdf"])
|
107 |
if uploaded_file is not None:
|
108 |
st.subheader("Options")
|
109 |
+
col1, col2, col3, col4 = st.columns([1, 1, 1, 2])
|
110 |
with col1:
|
111 |
+
model_source_names = [
|
112 |
+
"Cached model",
|
113 |
+
"Download model"
|
114 |
+
]
|
115 |
+
model_source = st.radio("For development:", model_source_names)
|
116 |
+
with col2:
|
117 |
model_names = [
|
118 |
"T5-Small",
|
119 |
"BART",
|
|
|
128 |
model_max_length=1000,
|
129 |
trust_remote_code=True,
|
130 |
)
|
131 |
+
if model_source == "Download":
|
132 |
+
base_model = AutoModelForSeq2SeqLM.from_pretrained(
|
133 |
+
checkpoint,
|
134 |
+
torch_dtype=torch.float32,
|
135 |
+
trust_remote_code=True,
|
136 |
+
)
|
137 |
+
else:
|
138 |
+
base_model = "model_cache/models--ccdv--lsg-bart-base-16384-pubmed/snapshots/4072bc1a7a94e2b4fd860a5fdf1b71d0487dcf15"
|
139 |
+
else:
|
140 |
checkpoint = "MBZUAI/LaMini-Flan-T5-77M"
|
141 |
tokenizer = AutoTokenizer.from_pretrained(
|
142 |
checkpoint,
|
|
|
145 |
model_max_length=1000,
|
146 |
#cache_dir="model_cache"
|
147 |
)
|
148 |
+
if model_source == "Download":
|
149 |
+
base_model = AutoModelForSeq2SeqLM.from_pretrained(
|
150 |
+
checkpoint,
|
151 |
+
torch_dtype=torch.float32,
|
152 |
+
)
|
153 |
+
else:
|
154 |
+
base_model = "model_cache/models--MBZUAI--LaMini-Flan-T5-77M/snapshots/c5b12d50a2616b9670a57189be20055d1357b474"
|
155 |
+
with col3:
|
156 |
st.write("Skip any pages?")
|
157 |
skipfirst = st.checkbox("Skip first page")
|
158 |
skiplast = st.checkbox("Skip last page")
|
159 |
+
with col4:
|
160 |
st.write("Background information (links open in a new window)")
|
161 |
st.write(
|
162 |
"Model class: [T5-Small](https://huggingface.co/docs/transformers/main/en/model_doc/t5)"
|
|
|
181 |
with col2:
|
182 |
start = time.time()
|
183 |
with st.spinner("Summarizing..."):
|
184 |
+
summary = llm_pipeline(tokenizer, base_model, input_text, model_source)
|
185 |
postproc_text_length = postproc_count(summary)
|
186 |
end = time.time()
|
187 |
duration = end - start
|