hungdungn47
commited on
Commit
·
74d656b
1
Parent(s):
d91835f
change infer function
Browse files- infer_concat.py +6 -6
infer_concat.py
CHANGED
@@ -63,7 +63,7 @@ def processing_data_infer(input_file):
|
|
63 |
tokenizer = AutoTokenizer.from_pretrained("VietAI/vit5-base-vietnews-summarization")
|
64 |
model = AutoModelForSeq2SeqLM.from_pretrained("VietAI/vit5-base-vietnews-summarization")
|
65 |
|
66 |
-
device = torch.device('
|
67 |
model.to(device)
|
68 |
|
69 |
model.load_state_dict(torch.load("./weight_cp19_model.pth", map_location=torch.device('cpu')))
|
@@ -90,12 +90,12 @@ def infer_2_hier(model, data_loader, device, tokenizer):
|
|
90 |
summary = model.generate(inputs[i].to(device),
|
91 |
attention_mask=att_mask[i].to(device),
|
92 |
max_length=128,
|
93 |
-
num_beams=
|
94 |
-
num_return_sequences=1)
|
95 |
summaries.append(summary)
|
96 |
summaries = torch.cat(summaries, dim = 1)
|
97 |
-
|
98 |
-
|
99 |
|
100 |
|
101 |
end = time.time()
|
@@ -104,6 +104,6 @@ def infer_2_hier(model, data_loader, device, tokenizer):
|
|
104 |
|
105 |
def vit5_infer(data):
|
106 |
dataset = Dataset4Summarization(data, tokenizer)
|
107 |
-
data_loader = torch.utils.data.DataLoader(dataset, batch_size=1
|
108 |
result = infer_2_hier(model, data_loader, device, tokenizer)
|
109 |
return result
|
|
|
63 |
tokenizer = AutoTokenizer.from_pretrained("VietAI/vit5-base-vietnews-summarization")
|
64 |
model = AutoModelForSeq2SeqLM.from_pretrained("VietAI/vit5-base-vietnews-summarization")
|
65 |
|
66 |
+
device = torch.device('cpu')
|
67 |
model.to(device)
|
68 |
|
69 |
model.load_state_dict(torch.load("./weight_cp19_model.pth", map_location=torch.device('cpu')))
|
|
|
90 |
summary = model.generate(inputs[i].to(device),
|
91 |
attention_mask=att_mask[i].to(device),
|
92 |
max_length=128,
|
93 |
+
num_beams=4,
|
94 |
+
num_return_sequences=1, no_repeat_ngram_size=3)
|
95 |
summaries.append(summary)
|
96 |
summaries = torch.cat(summaries, dim = 1)
|
97 |
+
|
98 |
+
all_summaries.append(tokenizer.decode(summaries, skip_special_tokens=True))
|
99 |
|
100 |
|
101 |
end = time.time()
|
|
|
104 |
|
105 |
def vit5_infer(data):
|
106 |
dataset = Dataset4Summarization(data, tokenizer)
|
107 |
+
data_loader = torch.utils.data.DataLoader(dataset, batch_size=1)
|
108 |
result = infer_2_hier(model, data_loader, device, tokenizer)
|
109 |
return result
|