chiliu committed on
Commit
faed5af
1 Parent(s): e54792e
Files changed (2)
  1. README.md +1 -1
  2. mamba_gpt_pipeline.py +42 -0
README.md CHANGED
@@ -83,7 +83,7 @@ Alternatively, you can download the mamba_gpt_pipeline.py, store it alongside yo
 
 ```python
 import torch
-from mamba_gpt_pipeline.py import MambaGPTTextGenerationPipeline
+from mamba_gpt_pipeline import MambaGPTTextGenerationPipeline
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 tokenizer = AutoTokenizer.from_pretrained(
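
For reference, a minimal sketch of how the corrected import is meant to be used; the model id below is a placeholder, since the actual `from_pretrained` argument is truncated in the hunk above:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from mamba_gpt_pipeline import MambaGPTTextGenerationPipeline

# Placeholder repo id; the real one is truncated in the README hunk above.
model_id = "your-org/your-mamba-gpt-model"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# Construct the custom pipeline directly with the loaded model and tokenizer.
pipe = MambaGPTTextGenerationPipeline(model=model, tokenizer=tokenizer)
result = pipe("Why is drinking water so healthy?", max_new_tokens=128)
print(result[0]["generated_text"])
```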
mamba_gpt_pipeline.py ADDED
@@ -0,0 +1,42 @@
+from transformers import TextGenerationPipeline
+from transformers.pipelines.text_generation import ReturnType
+
+STYLE = "<|prompt|>{instruction}</s><|answer|>"
+
+
+class MambaGPTTextGenerationPipeline(TextGenerationPipeline):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.prompt = STYLE
+
+    def preprocess(
+        self, prompt_text, prefix="", handle_long_generation=None, **generate_kwargs
+    ):
+        prompt_text = self.prompt.format(instruction=prompt_text)
+        return super().preprocess(
+            prompt_text,
+            prefix=prefix,
+            handle_long_generation=handle_long_generation,
+            **generate_kwargs,
+        )
+
+    def postprocess(
+        self,
+        model_outputs,
+        return_type=ReturnType.FULL_TEXT,
+        clean_up_tokenization_spaces=True,
+    ):
+        records = super().postprocess(
+            model_outputs,
+            return_type=return_type,
+            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+        )
+        for rec in records:
+            rec["generated_text"] = (
+                rec["generated_text"]
+                .split("<|answer|>")[1]
+                .strip()
+                .split("<|prompt|>")[0]
+                .strip()
+            )
+        return records
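
Both overrides are pure string handling around the base `TextGenerationPipeline`: `preprocess` wraps the user instruction in the `STYLE` template before tokenization, and `postprocess` keeps only the text between the `<|answer|>` marker and any following `<|prompt|>`. A minimal, model-free sketch of that string logic (the generated answer here is invented for illustration):

```python
STYLE = "<|prompt|>{instruction}</s><|answer|>"

# preprocess: wrap the raw instruction in the prompt template
full_prompt = STYLE.format(instruction="What does this pipeline do?")
# -> "<|prompt|>What does this pipeline do?</s><|answer|>"

# postprocess: keep only what falls between <|answer|> and any later <|prompt|>
generated = full_prompt + " It wraps prompts and trims answers. <|prompt|>next turn"
answer = generated.split("<|answer|>")[1].strip().split("<|prompt|>")[0].strip()
print(answer)  # It wraps prompts and trims answers.
```

Note that the `split("<|answer|>")[1]` indexing is safe only because the default `return_type` is `ReturnType.FULL_TEXT`, so the wrapped prompt (which contains `<|answer|>`) is always part of `generated_text`; with `ReturnType.NEW_TEXT` the marker could be absent and the split would raise an `IndexError`.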