theodotus committed on
Commit
651c45f
1 Parent(s): 5bd906f

Added dialogpt pipeline

Browse files
Files changed (2) hide show
  1. pipeline.py +26 -0
  2. requirements.txt +1 -0
pipeline.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import transformers
2
+ import ctranslate2
3
+
4
+ from typing import List, Dict
5
+ import os
6
+
7
+
8
+
9
class PreTrainedPipeline():
    """Callable text-generation pipeline backed by a CTranslate2 generator.

    The instance holds a ``ctranslate2.Generator`` for running the model and
    a ``transformers`` tokenizer for converting between text and tokens.
    """

    def __init__(self, path: str):
        """Load the generator and the tokenizer.

        Args:
            path: Directory containing a ``dialogpt`` sub-directory, which is
                passed to ``ctranslate2.Generator`` as the model location.
        """
        # Init DialoGPT
        model_dir = os.path.join(path, "dialogpt")
        self.generator = ctranslate2.Generator(
            model_dir,
            device="cpu",
            compute_type="int8",
        )
        # The tokenizer is resolved by its model name, not from ``path``.
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
            "microsoft/DialoGPT-medium"
        )

    def __call__(self, inputs: str) -> List[Dict]:
        """Generate text for a single input string.

        Args:
            inputs: The text to feed to the generator; the tokenizer's EOS
                token is appended before encoding.

        Returns:
            A one-element list holding a dict whose ``"generated_text"`` key
            maps to the decoded first sequence of the first batch result.
        """
        tokenizer = self.tokenizer

        # Text -> ids -> token strings (the generator is fed token strings).
        prompt = inputs + tokenizer.eos_token
        prompt_ids = tokenizer.encode(prompt)
        prompt_tokens = tokenizer.convert_ids_to_tokens(prompt_ids)

        # Single-item batch; keep only the first returned sequence.
        batch_results = self.generator.generate_batch([prompt_tokens])
        first_sequence = batch_results[0].sequences[0]

        # Token strings -> ids -> text.
        sequence_ids = tokenizer.convert_tokens_to_ids(first_sequence)
        return [{"generated_text": tokenizer.decode(sequence_ids)}]
requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ ctranslate2