Blaise-g committed
Commit 10d5b39 • 1 Parent(s): 2e0b08a

Update summarize.py

test coming back to orig settings

Files changed (1)
  1. summarize.py +22 -34
summarize.py CHANGED
@@ -27,7 +27,7 @@ def load_model_and_tokenizer(model_name):
     return model, tokenizer
 
 
-def summarize(ids, mask, model, tokenizer, model_arch, **kwargs):
+def summarize_and_score(ids, mask, model, tokenizer, **kwargs):
     """
     summarize_and_score - given a batch of ids and a mask, return a summary and a score for the summary
     Args:
@@ -35,7 +35,6 @@ def summarize(ids, mask, model, tokenizer, model_arch, **kwargs):
         mask (): the attention mask for the batch
         model (): the model to use for summarization
         tokenizer (): the tokenizer to use for summarization
-        model
     Returns:
         str: the summary of the batch
     """
@@ -45,32 +44,27 @@ def summarize(ids, mask, model, tokenizer, model_arch, **kwargs):
 
     input_ids = ids.to("cuda") if torch.cuda.is_available() else ids
     attention_mask = mask.to("cuda") if torch.cuda.is_available() else mask
-
-    if model_arch == 'LED':
-        global_attention_mask = torch.zeros_like(attention_mask)
-        # put global attention on <s> token
-        global_attention_mask[:, 0] = 1
-        summary_pred_ids = model.generate(
-            input_ids,
-            attention_mask=attention_mask,
-            global_attention_mask=global_attention_mask,
-            return_dict_in_generate=True,
-            **kwargs,
-        )
-
-    else:
-        summary_pred_ids = model.generate(
-            input_ids,
-            attention_mask=attention_mask,
-            return_dict_in_generate=True,
-            **kwargs,
-        )
+
+    global_attention_mask = torch.zeros_like(attention_mask)
+    # put global attention on <s> token
+    global_attention_mask[:, 0] = 1
+
+    summary_pred_ids = model.generate(
+        input_ids,
+        attention_mask=attention_mask,
+        global_attention_mask=global_attention_mask,
+        output_scores=True,
+        return_dict_in_generate=True,
+        **kwargs,
+    )
     summary = tokenizer.batch_decode(
         summary_pred_ids.sequences,
         skip_special_tokens=True,
         remove_invalid_values=True,
     )
-    return summary
+    score = round(summary_pred_ids.sequences_scores.cpu().numpy()[0], 4)
+
+    return summary, score
 
 
 def summarize_via_tokenbatches(
@@ -116,28 +110,22 @@ def summarize_via_tokenbatches(
     pbar = tqdm(total=len(in_id_arr))
 
     for _id, _mask in zip(in_id_arr, att_arr):
-
-        if model=='Blaise-g/led_pubmed_sumpubmed_1' or model=='Blaise-g/led_large_sumpbumed_scitldr':
-            model_arch = 'LED'
-        else:
-            model_arch = 'LongT5'
-
-        result = summarize(
+
+        result, score = summarize_and_score(
            ids=_id,
            mask=_mask,
            model=model,
-            model_arch=model_arch,
            tokenizer=tokenizer,
            **kwargs,
        )
-        rate = round(float((len(input_text)-len(result))/len(input_text)), 3)
+        score = round(float(score), 4)
        _sum = {
            "input_tokens": _id,
            "summary": result,
-            "compression_rate": rate,
+            "summary_score": score,
        }
        gen_summaries.append(_sum)
-        print(f"\t{result[0]}\nRate:\t{rate}")
+        print(f"\t{result[0]}\nScore:\t{score}")
        pbar.update()
 
     pbar.close()
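
For readers skimming the diff: the restored summarize_and_score relies on two transformers generation details, LED's global_attention_mask (global attention on the first <s> token only) and output_scores=True together with return_dict_in_generate=True, which is what makes sequences_scores available on the generate output. Below is a minimal, standalone sketch of that pattern; the checkpoint name allenai/led-base-16384, the beam and length settings, and the example text are illustrative assumptions, not values taken from this repo.

```python
# Minimal sketch of the generation pattern restored in this commit.
# Assumptions (not from this repo): the LED checkpoint name, num_beams,
# max_length, and the example input text are placeholders for illustration.
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

model_name = "allenai/led-base-16384"  # placeholder LED checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
if torch.cuda.is_available():
    model = model.to("cuda")

text = "Long document text to summarize ..."
batch = tokenizer(text, return_tensors="pt", truncation=True)
input_ids = batch.input_ids.to(model.device)
attention_mask = batch.attention_mask.to(model.device)

# LED expects a global attention mask; as in summarize_and_score above,
# only the first (<s>) token gets global attention.
global_attention_mask = torch.zeros_like(attention_mask)
global_attention_mask[:, 0] = 1

# sequences_scores is only populated for beam search with output_scores=True,
# so in the real function num_beams > 1 is assumed to arrive via **kwargs.
outputs = model.generate(
    input_ids,
    attention_mask=attention_mask,
    global_attention_mask=global_attention_mask,
    num_beams=4,
    max_length=256,
    output_scores=True,
    return_dict_in_generate=True,
)

summary = tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)
score = round(float(outputs.sequences_scores[0]), 4)
print(summary[0], score)
```

With the default length penalty, sequences_scores is the length-normalized log-likelihood of the winning beam, so the summary_score field that replaces compression_rate is at or below zero, and values closer to zero indicate higher model confidence.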
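
The hunks above do not show where in_id_arr and att_arr come from; summarize_via_tokenbatches presumably splits the tokenized document into fixed-size windows before the loop shown. Below is a hypothetical sketch of that batching step; chunk_tokens and batch_length are illustrative names, not identifiers from summarize.py.

```python
# Hypothetical token-batching helper; names are illustrative only and do not
# appear in summarize.py.
from typing import List, Tuple

import torch
from transformers import PreTrainedTokenizerBase


def chunk_tokens(
    tokenizer: PreTrainedTokenizerBase,
    text: str,
    batch_length: int = 4096,
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    """Tokenize `text` once, then slice the ids and attention mask into
    windows of at most `batch_length` tokens, one (1, n) tensor per window."""
    encoded = tokenizer(text, truncation=False, return_tensors="pt")
    ids, mask = encoded.input_ids[0], encoded.attention_mask[0]
    in_id_arr, att_arr = [], []
    for start in range(0, ids.size(0), batch_length):
        in_id_arr.append(ids[start : start + batch_length].unsqueeze(0))
        att_arr.append(mask[start : start + batch_length].unsqueeze(0))
    return in_id_arr, att_arr
```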