IlyaGusev commited on
Commit
6f3d67b
1 Parent(s): 2cb28a0

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +4 -10
README.md CHANGED
@@ -37,21 +37,15 @@ class Conversation:
37
  self,
38
  message_template=DEFAULT_MESSAGE_TEMPLATE,
39
  system_prompt=DEFAULT_SYSTEM_PROMPT,
40
- start_token_id=1,
41
  ):
42
  self.message_template = message_template
43
- self.start_token_id = start_token_id
44
  self.messages = [{
45
  "role": "system",
46
  "content": system_prompt
47
  }]
48
 
49
- def get_start_token_id(self):
50
- return self.start_token_id
51
-
52
- def get_bot_token_id(self):
53
- return self.bot_token_id
54
-
55
  def add_user_message(self, message):
56
  self.messages.append({
57
  "role": "user",
@@ -69,12 +63,12 @@ class Conversation:
69
  for message in self.messages:
70
  message_text = self.message_template.format(**message)
71
  final_text += message_text
72
- final_text += tokenizer.decode([self.start_token_id, self.bot_token_id])
73
  return final_text.strip()
74
 
75
 
76
  def generate(model, tokenizer, prompt, generation_config):
77
- data = tokenizer(prompt, return_tensors="pt")
78
  data = {k: v.to(model.device) for k, v in data.items()}
79
  output_ids = model.generate(
80
  **data,
 
37
  self,
38
  message_template=DEFAULT_MESSAGE_TEMPLATE,
39
  system_prompt=DEFAULT_SYSTEM_PROMPT,
40
+ response_template=DEFAULT_RESPONSE_TEMPLATE
41
  ):
42
  self.message_template = message_template
43
+ self.response_template = response_template
44
  self.messages = [{
45
  "role": "system",
46
  "content": system_prompt
47
  }]
48
 
 
 
 
 
 
 
49
  def add_user_message(self, message):
50
  self.messages.append({
51
  "role": "user",
 
63
  for message in self.messages:
64
  message_text = self.message_template.format(**message)
65
  final_text += message_text
66
+ final_text += DEFAULT_RESPONSE_TEMPLATE
67
  return final_text.strip()
68
 
69
 
70
  def generate(model, tokenizer, prompt, generation_config):
71
+ data = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
72
  data = {k: v.to(model.device) for k, v in data.items()}
73
  output_ids = model.generate(
74
  **data,