kz209 commited on
Commit
203771e
1 Parent(s): 68c64e4
Files changed (1) hide show
  1. utils/model.py +5 -4
utils/model.py CHANGED
@@ -23,16 +23,17 @@ class Model(torch.nn.Module):
23
 
24
  self.tokenizer = AutoTokenizer.from_pretrained(model_name)
25
  self.name = model_name
 
26
  logging.info(f'start loading model {self.name}')
27
  self.pipeline = transformers.pipeline(
28
- "summarization",
29
  model=model_name,
30
  tokenizer=self.tokenizer,
31
  torch_dtype=torch.bfloat16,
32
  device_map="auto",
33
  )
34
-
35
  logging.info(f'Loaded model {self.name}')
 
36
  self.update()
37
 
38
  @classmethod
@@ -58,6 +59,7 @@ class Model(torch.nn.Module):
58
  num_return_sequences=1,
59
  eos_token_id=self.tokenizer.eos_token_id,
60
  )
 
61
  else:
62
  sequences = self.pipeline(
63
  content,
@@ -68,5 +70,4 @@ class Model(torch.nn.Module):
68
  eos_token_id=self.tokenizer.eos_token_id,
69
  return_full_text=False
70
  )
71
-
72
- return sequences[-1]['summary_text']
 
23
 
24
  self.tokenizer = AutoTokenizer.from_pretrained(model_name)
25
  self.name = model_name
26
+
27
  logging.info(f'start loading model {self.name}')
28
  self.pipeline = transformers.pipeline(
29
+ "summarization" if model_name=="google-t5/t5-large" else "text-generation",
30
  model=model_name,
31
  tokenizer=self.tokenizer,
32
  torch_dtype=torch.bfloat16,
33
  device_map="auto",
34
  )
 
35
  logging.info(f'Loaded model {self.name}')
36
+
37
  self.update()
38
 
39
  @classmethod
 
59
  num_return_sequences=1,
60
  eos_token_id=self.tokenizer.eos_token_id,
61
  )
62
+ return sequences[-1]['summary_text']
63
  else:
64
  sequences = self.pipeline(
65
  content,
 
70
  eos_token_id=self.tokenizer.eos_token_id,
71
  return_full_text=False
72
  )
73
+ return sequences[-1]['generated_text']