wjjessen committed on
Commit
d124ecd
1 Parent(s): 621da38

update code

Browse files
Files changed (1) hide show
  1. app.py +28 -17
app.py CHANGED
@@ -63,7 +63,7 @@ def preproc_count(filepath, skipfirst, skiplast):
63
 
64
 
65
  # llm pipeline
66
- def llm_pipeline(tokenizer, base_model, input_text):
67
  pipe_sum = pipeline(
68
  "summarization",
69
  model=base_model,
@@ -72,6 +72,7 @@ def llm_pipeline(tokenizer, base_model, input_text):
72
  min_length=300,
73
  truncation=True
74
  )
 
75
  print("Summarizing...")
76
  result = pipe_sum(input_text)
77
  summary = result[0]["summary_text"]
@@ -105,8 +106,14 @@ def main():
105
  uploaded_file = st.file_uploader("Upload your PDF file", type=["pdf"])
106
  if uploaded_file is not None:
107
  st.subheader("Options")
108
- col1, col2, col3 = st.columns([1, 1, 2])
109
  with col1:
 
 
 
 
 
 
110
  model_names = [
111
  "T5-Small",
112
  "BART",
@@ -121,13 +128,15 @@ def main():
121
  model_max_length=1000,
122
  trust_remote_code=True,
123
  )
124
- #base_model = AutoModelForSeq2SeqLM.from_pretrained(
125
- # checkpoint,
126
- # torch_dtype=torch.float32,
127
- # trust_remote_code=True,
128
- #)
129
- base_model = "model_cache/models--ccdv--lsg-bart-base-16384-pubmed/snapshots/4072bc1a7a94e2b4fd860a5fdf1b71d0487dcf15"
130
- else: # default Flan T5 small
 
 
131
  checkpoint = "MBZUAI/LaMini-Flan-T5-77M"
132
  tokenizer = AutoTokenizer.from_pretrained(
133
  checkpoint,
@@ -136,16 +145,18 @@ def main():
136
  model_max_length=1000,
137
  #cache_dir="model_cache"
138
  )
139
- #base_model = AutoModelForSeq2SeqLM.from_pretrained(
140
- # checkpoint,
141
- # torch_dtype=torch.float32,
142
- #)
143
- base_model = "model_cache/models--MBZUAI--LaMini-Flan-T5-77M/snapshots/c5b12d50a2616b9670a57189be20055d1357b474"
144
- with col2:
 
 
145
  st.write("Skip any pages?")
146
  skipfirst = st.checkbox("Skip first page")
147
  skiplast = st.checkbox("Skip last page")
148
- with col3:
149
  st.write("Background information (links open in a new window)")
150
  st.write(
151
  "Model class: [T5-Small](https://huggingface.co/docs/transformers/main/en/model_doc/t5)"
@@ -170,7 +181,7 @@ def main():
170
  with col2:
171
  start = time.time()
172
  with st.spinner("Summarizing..."):
173
- summary = llm_pipeline(tokenizer, base_model, input_text)
174
  postproc_text_length = postproc_count(summary)
175
  end = time.time()
176
  duration = end - start
 
63
 
64
 
65
  # llm pipeline
66
+ def llm_pipeline(tokenizer, base_model, input_text, model_source):
67
  pipe_sum = pipeline(
68
  "summarization",
69
  model=base_model,
 
72
  min_length=300,
73
  truncation=True
74
  )
75
+ print("Model source: %s" %(model_source))
76
  print("Summarizing...")
77
  result = pipe_sum(input_text)
78
  summary = result[0]["summary_text"]
 
106
  uploaded_file = st.file_uploader("Upload your PDF file", type=["pdf"])
107
  if uploaded_file is not None:
108
  st.subheader("Options")
109
+ col1, col2, col3, col4 = st.columns([1, 1, 1, 2])
110
  with col1:
111
+ model_source_names = [
112
+ "Cached model",
113
+ "Download model"
114
+ ]
115
+ model_source = st.radio("For development:", model_source_names)
116
+ with col2:
117
  model_names = [
118
  "T5-Small",
119
  "BART",
 
128
  model_max_length=1000,
129
  trust_remote_code=True,
130
  )
131
+ if model_source == "Download":
132
+ base_model = AutoModelForSeq2SeqLM.from_pretrained(
133
+ checkpoint,
134
+ torch_dtype=torch.float32,
135
+ trust_remote_code=True,
136
+ )
137
+ else:
138
+ base_model = "model_cache/models--ccdv--lsg-bart-base-16384-pubmed/snapshots/4072bc1a7a94e2b4fd860a5fdf1b71d0487dcf15"
139
+ else:
140
  checkpoint = "MBZUAI/LaMini-Flan-T5-77M"
141
  tokenizer = AutoTokenizer.from_pretrained(
142
  checkpoint,
 
145
  model_max_length=1000,
146
  #cache_dir="model_cache"
147
  )
148
+ if model_source == "Download":
149
+ base_model = AutoModelForSeq2SeqLM.from_pretrained(
150
+ checkpoint,
151
+ torch_dtype=torch.float32,
152
+ )
153
+ else:
154
+ base_model = "model_cache/models--MBZUAI--LaMini-Flan-T5-77M/snapshots/c5b12d50a2616b9670a57189be20055d1357b474"
155
+ with col3:
156
  st.write("Skip any pages?")
157
  skipfirst = st.checkbox("Skip first page")
158
  skiplast = st.checkbox("Skip last page")
159
+ with col4:
160
  st.write("Background information (links open in a new window)")
161
  st.write(
162
  "Model class: [T5-Small](https://huggingface.co/docs/transformers/main/en/model_doc/t5)"
 
181
  with col2:
182
  start = time.time()
183
  with st.spinner("Summarizing..."):
184
+ summary = llm_pipeline(tokenizer, base_model, input_text, model_source)
185
  postproc_text_length = postproc_count(summary)
186
  end = time.time()
187
  duration = end - start