Spaces:
Runtime error
Runtime error
Commit
•
783f067
1
Parent(s):
6632536
Updates
Browse files
model.py
CHANGED
@@ -4,7 +4,10 @@ import random
|
|
4 |
from huggingface_hub import hf_hub_download
|
5 |
|
6 |
|
7 |
-
from transformers import
|
|
|
|
|
|
|
8 |
|
9 |
loc = "ydshieh/bert2bert-cnn_dailymail-fp16"
|
10 |
|
@@ -23,14 +26,13 @@ os.makedirs(model_dir, exist_ok=True)
|
|
23 |
# file_path = hf_hub_download("ydshieh/bert2bert-cnn_dailymail-fp16", f"ckpt_epoch_3_step_6900/{fn}")
|
24 |
# shutil.copyfile(file_path, os.path.join(model_dir, fn))
|
25 |
|
26 |
-
model = TFEncoderDecoderModel.from_pretrained(loc)
|
27 |
-
tokenizer = AutoTokenizer.from_pretrained(loc)
|
28 |
|
29 |
|
30 |
def predict(article):
|
31 |
|
32 |
-
input_ids = tokenizer(article, return_tensors="
|
33 |
output_ids = model.generate(input_ids)
|
|
|
34 |
summary = tokenizer.decode(output_ids[0], skip_special_tokens=True)
|
35 |
return summary
|
36 |
|
|
|
4 |
from huggingface_hub import hf_hub_download
|
5 |
|
6 |
|
7 |
+
from transformers import BertTokenizer, EncoderDecoderModel
|
8 |
+
|
9 |
+
model = EncoderDecoderModel.from_pretrained("patrickvonplaten/bert2bert-cnn_dailymail-fp16")
|
10 |
+
tokenizer = BertTokenizer.from_pretrained("patrickvonplaten/bert2bert-cnn_dailymail-fp16")
|
11 |
|
12 |
loc = "ydshieh/bert2bert-cnn_dailymail-fp16"
|
13 |
|
|
|
26 |
# file_path = hf_hub_download("ydshieh/bert2bert-cnn_dailymail-fp16", f"ckpt_epoch_3_step_6900/{fn}")
|
27 |
# shutil.copyfile(file_path, os.path.join(model_dir, fn))
|
28 |
|
|
|
|
|
29 |
|
30 |
|
31 |
def predict(article):
|
32 |
|
33 |
+
input_ids = tokenizer(article, return_tensors="pt").input_ids
|
34 |
output_ids = model.generate(input_ids)
|
35 |
+
|
36 |
summary = tokenizer.decode(output_ids[0], skip_special_tokens=True)
|
37 |
return summary
|
38 |
|