winglian committed on
Commit
d353b2f
1 Parent(s): 6398310

improved stop token for supercot

Browse files
Files changed (1) hide show
  1. app.py +8 -4
app.py CHANGED
@@ -38,10 +38,11 @@ def prompt_chat(system_msg, history):
38
  class Pipeline:
39
  prefer_async = True
40
 
41
- def __init__(self, endpoint_id, name, prompt_fn):
42
  self.endpoint_id = endpoint_id
43
  self.name = name
44
  self.prompt_fn = prompt_fn
 
45
  self.generation_config = {
46
  "max_new_tokens": 1024,
47
  "top_k": 40,
@@ -52,7 +53,7 @@ class Pipeline:
52
  "seed": -1,
53
  "batch_size": 8,
54
  "threads": -1,
55
- "stop": ["</s>", "USER:", "### Instruction:"],
56
  }
57
 
58
  def __call__(self, prompt):
@@ -102,7 +103,7 @@ AVAILABLE_MODELS = {
102
  "hermes-13b": ("p0zqb2gkcwp0ww", prompt_instruct),
103
  "manticore-13b-chat": ("u6tv84bpomhfei", prompt_chat),
104
  "airoboros-13b": ("rglzxnk80660ja", prompt_chat),
105
- "supercot-13b": ("0be7865dwxpwqk", prompt_instruct),
106
  "mpt-7b-instruct": ("jpqbvnyluj18b0", prompt_instruct),
107
  }
108
 
@@ -111,7 +112,10 @@ _memoized_models = defaultdict()
111
 
112
  def get_model_pipeline(model_name):
113
  if not _memoized_models.get(model_name):
114
- _memoized_models[model_name] = Pipeline(AVAILABLE_MODELS[model_name][0], model_name, AVAILABLE_MODELS[model_name][1])
 
 
 
115
  return _memoized_models.get(model_name)
116
 
117
  start_message = """- The Assistant is helpful and transparent.
 
38
  class Pipeline:
39
  prefer_async = True
40
 
41
+ def __init__(self, endpoint_id, name, prompt_fn, stop_tokens=None):
42
  self.endpoint_id = endpoint_id
43
  self.name = name
44
  self.prompt_fn = prompt_fn
45
+ stop_tokens = stop_tokens or []
46
  self.generation_config = {
47
  "max_new_tokens": 1024,
48
  "top_k": 40,
 
53
  "seed": -1,
54
  "batch_size": 8,
55
  "threads": -1,
56
+ "stop": ["</s>", "USER:", "### Instruction:"] + stop_tokens,
57
  }
58
 
59
  def __call__(self, prompt):
 
103
  "hermes-13b": ("p0zqb2gkcwp0ww", prompt_instruct),
104
  "manticore-13b-chat": ("u6tv84bpomhfei", prompt_chat),
105
  "airoboros-13b": ("rglzxnk80660ja", prompt_chat),
106
+ "supercot-13b": ("0be7865dwxpwqk", prompt_instruct, ["Instruction:"]),
107
  "mpt-7b-instruct": ("jpqbvnyluj18b0", prompt_instruct),
108
  }
109
 
 
112
 
113
  def get_model_pipeline(model_name):
114
  if not _memoized_models.get(model_name):
115
+ kwargs = {}
116
+ if len(AVAILABLE_MODELS[model_name]) >= 3:
117
+ kwargs["stop_tokens"] = AVAILABLE_MODELS[model_name][2]
118
+ _memoized_models[model_name] = Pipeline(AVAILABLE_MODELS[model_name][0], model_name, AVAILABLE_MODELS[model_name][1], **kwargs)
119
  return _memoized_models.get(model_name)
120
 
121
  start_message = """- The Assistant is helpful and transparent.