arnocandel commited on
Commit
d1aea17
1 Parent(s): 3dadeda

https://github.com/h2oai/h2ogpt/issues/125#issuecomment-1548239108

Browse files
config.json CHANGED
@@ -1,5 +1,5 @@
1
  {
2
- "_name_or_path": "h2oai/h2ogpt-oasst1-512-12b",
3
  "architectures": [
4
  "GPTNeoXForCausalLM"
5
  ],
@@ -10,12 +10,6 @@
10
  "pt": "AutoModelForCausalLM"
11
  }
12
  },
13
- "custom_pipelines": {
14
- "text-generation": {
15
- "impl": "h2oai_pipeline.H2OTextGenerationPipeline",
16
- "pt": "AutoModelForCausalLM"
17
- }
18
- },
19
  "eos_token_id": 0,
20
  "hidden_act": "gelu",
21
  "hidden_size": 5120,
 
1
  {
2
+ "_name_or_path": "EleutherAI/pythia-12b-deduped",
3
  "architectures": [
4
  "GPTNeoXForCausalLM"
5
  ],
 
10
  "pt": "AutoModelForCausalLM"
11
  }
12
  },
 
 
 
 
 
 
13
  "eos_token_id": 0,
14
  "hidden_act": "gelu",
15
  "hidden_size": 5120,
h2oai_pipeline.py CHANGED
@@ -1,6 +1,9 @@
1
  from transformers import TextGenerationPipeline
2
  from transformers.pipelines.text_generation import ReturnType
3
 
 
 
 
4
  human = "<human>:"
5
  bot = "<bot>:"
6
 
@@ -28,3 +31,8 @@ class H2OTextGenerationPipeline(TextGenerationPipeline):
28
  for rec in records:
29
  rec['generated_text'] = rec['generated_text'].split(bot)[1].strip().split(human)[0].strip()
30
  return records
 
 
 
 
 
 
1
  from transformers import TextGenerationPipeline
2
  from transformers.pipelines.text_generation import ReturnType
3
 
4
+ from stopping import get_stopping
5
+
6
+ prompt_type = "human_bot"
7
  human = "<human>:"
8
  bot = "<bot>:"
9
 
 
31
  for rec in records:
32
  rec['generated_text'] = rec['generated_text'].split(bot)[1].strip().split(human)[0].strip()
33
  return records
34
+
35
+ def _forward(self, model_inputs, **generate_kwargs):
36
+ stopping_criteria = get_stopping(prompt_type, self.tokenizer, self.device, human=human, bot=bot)
37
+ generate_kwargs['stopping_criteria'] = stopping_criteria
38
+ return super()._forward(model_inputs, **generate_kwargs)
pytorch_model-00001-of-00005.bin CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:7d21435593086b58709fd1598a039524ff6dacca18f996edd120aa05f6d1cbce
3
  size 4957630318
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:64691fa6fa33a63aa2fad165e6215a17e79dac4a203b9f8c887907a72278660b
3
  size 4957630318
pytorch_model-00002-of-00005.bin CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:aca3ba1593cc4dd0cfd529ec24c4fbf53481ad8e6d5ff9b81ca9d208f2fbedf8
3
  size 4853861544
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:75dd532c4cb4c3649e80191dac7f0120ce0d3a0f573f66da11f61290936eeb46
3
  size 4853861544
pytorch_model-00003-of-00005.bin CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:7ac87c950ae759c15bb50c40263bc26f10a9f106e07e127361c8c7635273f0d1
3
  size 4858068625
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7ace0311fd3b140629c0bda15e5d6ebf23987d4905124da10fac7c0ff11e583e
3
  size 4858068625
pytorch_model-00004-of-00005.bin CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:29ad9d68225ceb2fe58373c64a38a8670fffc9e6794d51e2794c5c113e129e89
3
  size 5015385889
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ea6b0c3b72599fc88f0328bb9c1f5058cd33eff5a5897c0863cd09d713ffbea1
3
  size 5015385889
pytorch_model-00005-of-00005.bin CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:9719e633aa21824604cc52e0cc4b586c173bb49efd3e2ac75fda0201be0cd66c
3
  size 4158379959
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:17f12d59a0255f9b07e531081be6666b3e9507baa82d25ac412536f6badaffdd
3
  size 4158379959