wjjessen committed
Commit 3db79cf
Parent: f7a2883

update code

Files changed (1)
app.py +13 -13
app.py CHANGED
@@ -120,12 +120,13 @@ def main():
             trust_remote_code=True,
             cache_dir="model_cache"
         )
-        base_model = AutoModelForSeq2SeqLM.from_pretrained(
-            checkpoint,
-            torch_dtype=torch.float32,
-            trust_remote_code=True,
-            cache_dir="model_cache"
-        )
+        base_model = "model_cache/models--ccdv--lsg-bart-base-16384-pubmed/snapshots/4072bc1a7a94e2b4fd860a5fdf1b71d0487dcf15"
+        #base_model = AutoModelForSeq2SeqLM.from_pretrained(
+        #    checkpoint,
+        #    torch_dtype=torch.float32,
+        #    trust_remote_code=True,
+        #    cache_dir="model_cache"
+        #)
     else: # default Flan T5 small
         checkpoint = "MBZUAI/LaMini-Flan-T5-77M"
         tokenizer = AutoTokenizer.from_pretrained(
@@ -135,11 +136,12 @@ def main():
             model_max_length=1000,
             cache_dir="model_cache"
         )
-        base_model = AutoModelForSeq2SeqLM.from_pretrained(
-            checkpoint,
-            torch_dtype=torch.float32,
-            cache_dir="model_cache"
-        )
+        base_model = "model_cache/models--MBZUAI--LaMini-Flan-T5-77M/snapshots/c5b12d50a2616b9670a57189be20055d1357b474"
+        #base_model = AutoModelForSeq2SeqLM.from_pretrained(
+        #    checkpoint,
+        #    torch_dtype=torch.float32,
+        #    cache_dir="model_cache"
+        #)
     with col2:
         st.write("Skip any pages?")
         skipfirst = st.checkbox("Skip first page")
@@ -167,8 +169,6 @@ def main():
         )
         pdf_viewer = displayPDF(filepath)
     with col2:
-        with st.spinner("Downloading LLM..."):
-            sleep(5)
         with st.spinner("Summarizing..."):
             summary = llm_pipeline(tokenizer, base_model, input_text)
             postproc_text_length = postproc_count(summary)
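
Net effect of the commit: both model branches now point base_model at a snapshot directory that already exists under model_cache instead of calling AutoModelForSeq2SeqLM.from_pretrained() at runtime (the old calls are kept as comments), and the placeholder "Downloading LLM..." spinner, which only slept for five seconds, is dropped. For reference, a minimal sketch of the pattern the hard-coded paths rely on: transformers' from_pretrained() accepts a local directory as well as a Hub id, so a pre-populated cache avoids any network download. The snapshot path below is the one from this commit; loading it eagerly here, rather than inside llm_pipeline (whose internals are outside this diff), is an assumption for illustration.

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Snapshot directory hard-coded in this commit; from_pretrained() treats a
# local directory like a Hub model id, so nothing is fetched from the network.
snapshot = (
    "model_cache/models--MBZUAI--LaMini-Flan-T5-77M/"
    "snapshots/c5b12d50a2616b9670a57189be20055d1357b474"
)

tokenizer = AutoTokenizer.from_pretrained(snapshot, model_max_length=1000)
base_model = AutoModelForSeq2SeqLM.from_pretrained(snapshot)  # loads from local disk

Note that after this change llm_pipeline(tokenizer, base_model, input_text) receives a path string rather than a model object, so it presumably resolves the string itself (transformers.pipeline, for instance, accepts a local model path for its model argument).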