deepanshudey committed on
Commit
783f067
1 Parent(s): 6632536
Files changed (1) hide show
  1. model.py +6 -4
model.py CHANGED
@@ -4,7 +4,10 @@ import random
4
  from huggingface_hub import hf_hub_download
5
 
6
 
7
- from transformers import AutoTokenizer, TFEncoderDecoderModel
 
 
 
8
 
9
  loc = "ydshieh/bert2bert-cnn_dailymail-fp16"
10
 
@@ -23,14 +26,13 @@ os.makedirs(model_dir, exist_ok=True)
23
  # file_path = hf_hub_download("ydshieh/bert2bert-cnn_dailymail-fp16", f"ckpt_epoch_3_step_6900/{fn}")
24
  # shutil.copyfile(file_path, os.path.join(model_dir, fn))
25
 
26
- model = TFEncoderDecoderModel.from_pretrained(loc)
27
- tokenizer = AutoTokenizer.from_pretrained(loc)
28
 
29
 
30
  def predict(article):
31
 
32
- input_ids = tokenizer(article, return_tensors="tf").input_ids
33
  output_ids = model.generate(input_ids)
 
34
  summary = tokenizer.decode(output_ids[0], skip_special_tokens=True)
35
  return summary
36
 
 
4
  from huggingface_hub import hf_hub_download
5
 
6
 
7
+ from transformers import BertTokenizer, EncoderDecoderModel
8
+
9
+ model = EncoderDecoderModel.from_pretrained("patrickvonplaten/bert2bert-cnn_dailymail-fp16")
10
+ tokenizer = BertTokenizer.from_pretrained("patrickvonplaten/bert2bert-cnn_dailymail-fp16")
11
 
12
  loc = "ydshieh/bert2bert-cnn_dailymail-fp16"
13
 
 
26
  # file_path = hf_hub_download("ydshieh/bert2bert-cnn_dailymail-fp16", f"ckpt_epoch_3_step_6900/{fn}")
27
  # shutil.copyfile(file_path, os.path.join(model_dir, fn))
28
 
 
 
29
 
30
 
31
  def predict(article):
32
 
33
+ input_ids = tokenizer(article, return_tensors="pt").input_ids
34
  output_ids = model.generate(input_ids)
35
+
36
  summary = tokenizer.decode(output_ids[0], skip_special_tokens=True)
37
  return summary
38