Spaces:
Sleeping
Sleeping
kz209
committed on
Commit
•
203771e
1
Parent(s):
68c64e4
update
Browse files- utils/model.py +5 -4
utils/model.py
CHANGED
@@ -23,16 +23,17 @@ class Model(torch.nn.Module):
|
|
23 |
|
24 |
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
25 |
self.name = model_name
|
|
|
26 |
logging.info(f'start loading model {self.name}')
|
27 |
self.pipeline = transformers.pipeline(
|
28 |
-
"summarization",
|
29 |
model=model_name,
|
30 |
tokenizer=self.tokenizer,
|
31 |
torch_dtype=torch.bfloat16,
|
32 |
device_map="auto",
|
33 |
)
|
34 |
-
|
35 |
logging.info(f'Loaded model {self.name}')
|
|
|
36 |
self.update()
|
37 |
|
38 |
@classmethod
|
@@ -58,6 +59,7 @@ class Model(torch.nn.Module):
|
|
58 |
num_return_sequences=1,
|
59 |
eos_token_id=self.tokenizer.eos_token_id,
|
60 |
)
|
|
|
61 |
else:
|
62 |
sequences = self.pipeline(
|
63 |
content,
|
@@ -68,5 +70,4 @@ class Model(torch.nn.Module):
|
|
68 |
eos_token_id=self.tokenizer.eos_token_id,
|
69 |
return_full_text=False
|
70 |
)
|
71 |
-
|
72 |
-
return sequences[-1]['summary_text']
|
|
|
23 |
|
24 |
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
25 |
self.name = model_name
|
26 |
+
|
27 |
logging.info(f'start loading model {self.name}')
|
28 |
self.pipeline = transformers.pipeline(
|
29 |
+
"summarization" if model_name=="google-t5/t5-large" else "text-generation",
|
30 |
model=model_name,
|
31 |
tokenizer=self.tokenizer,
|
32 |
torch_dtype=torch.bfloat16,
|
33 |
device_map="auto",
|
34 |
)
|
|
|
35 |
logging.info(f'Loaded model {self.name}')
|
36 |
+
|
37 |
self.update()
|
38 |
|
39 |
@classmethod
|
|
|
59 |
num_return_sequences=1,
|
60 |
eos_token_id=self.tokenizer.eos_token_id,
|
61 |
)
|
62 |
+
return sequences[-1]['summary_text']
|
63 |
else:
|
64 |
sequences = self.pipeline(
|
65 |
content,
|
|
|
70 |
eos_token_id=self.tokenizer.eos_token_id,
|
71 |
return_full_text=False
|
72 |
)
|
73 |
+
return sequences[-1]['generated_text']
|
|