hungdungn47 committed on
Commit
74d656b
·
1 Parent(s): d91835f

change infer function

Browse files
Files changed (1) hide show
  1. infer_concat.py +6 -6
infer_concat.py CHANGED
@@ -63,7 +63,7 @@ def processing_data_infer(input_file):
63
  tokenizer = AutoTokenizer.from_pretrained("VietAI/vit5-base-vietnews-summarization")
64
  model = AutoModelForSeq2SeqLM.from_pretrained("VietAI/vit5-base-vietnews-summarization")
65
 
66
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
67
  model.to(device)
68
 
69
  model.load_state_dict(torch.load("./weight_cp19_model.pth", map_location=torch.device('cpu')))
@@ -90,12 +90,12 @@ def infer_2_hier(model, data_loader, device, tokenizer):
90
  summary = model.generate(inputs[i].to(device),
91
  attention_mask=att_mask[i].to(device),
92
  max_length=128,
93
- num_beams=12,
94
- num_return_sequences=1)
95
  summaries.append(summary)
96
  summaries = torch.cat(summaries, dim = 1)
97
- for k in summaries:
98
- all_summaries.append(tokenizer.decode(k, skip_special_tokens=True))
99
 
100
 
101
  end = time.time()
@@ -104,6 +104,6 @@ def infer_2_hier(model, data_loader, device, tokenizer):
104
 
105
  def vit5_infer(data):
106
  dataset = Dataset4Summarization(data, tokenizer)
107
- data_loader = torch.utils.data.DataLoader(dataset, batch_size=1, num_workers=1)
108
  result = infer_2_hier(model, data_loader, device, tokenizer)
109
  return result
 
63
  tokenizer = AutoTokenizer.from_pretrained("VietAI/vit5-base-vietnews-summarization")
64
  model = AutoModelForSeq2SeqLM.from_pretrained("VietAI/vit5-base-vietnews-summarization")
65
 
66
+ device = torch.device('cpu')
67
  model.to(device)
68
 
69
  model.load_state_dict(torch.load("./weight_cp19_model.pth", map_location=torch.device('cpu')))
 
90
  summary = model.generate(inputs[i].to(device),
91
  attention_mask=att_mask[i].to(device),
92
  max_length=128,
93
+ num_beams=4,
94
+ num_return_sequences=1, no_repeat_ngram_size=3)
95
  summaries.append(summary)
96
  summaries = torch.cat(summaries, dim = 1)
97
+
98
+ all_summaries.append(tokenizer.decode(summaries, skip_special_tokens=True))
99
 
100
 
101
  end = time.time()
 
104
 
105
  def vit5_infer(data):
106
  dataset = Dataset4Summarization(data, tokenizer)
107
+ data_loader = torch.utils.data.DataLoader(dataset, batch_size=1)
108
  result = infer_2_hier(model, data_loader, device, tokenizer)
109
  return result