freemt committed on
Commit
7381c93
1 Parent(s): b78ad07

First commit

Browse files
Files changed (3) hide show
  1. README.md +6 -1
  2. convbot/__main__.py +21 -0
  3. convbot/convbot.py +30 -29
README.md CHANGED
@@ -20,7 +20,12 @@ pip install convbot -U
20
  ```python
21
  from convbot import convbot
22
 
23
- convertbot("How are you?")
24
  # I am good # or along that line
 
 
 
25
 
 
 
26
  ```
20
  ```python
21
  from convbot import convbot
22
 
23
+ print(convbot("How are you?"))
24
  # I am good # or along that line
25
+ ```
26
+
27
+ Interactive
28
 
29
+ ```bash
30
+ python -m convbot
31
  ```
convbot/__main__.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Run main.
2
+
3
+ python -m convbot
4
+ """
5
+ from convbot import convbot
6
+
7
+
8
+ def main():
9
+ print("Bot: Talk to me (type quit to exit)")
10
+ while 1:
11
+ text = input("You: ")
12
+
13
+ if text.lower().strip() in ["quit", "exit", "stop"]:
14
+ break
15
+
16
+ resp = convbot(text)
17
+ print("Bot: ", resp)
18
+
19
+
20
+ if __name__ == "__main__":
21
+ main()
convbot/convbot.py CHANGED
@@ -13,12 +13,12 @@ model = AutoModelForCausalLM.from_pretrained(model_name)
13
 
14
 
15
  def _convbot(
16
- text: str,
17
- max_length: int = 1000,
18
- do_sample: bool = True,
19
- top_p: float = 0.95,
20
- top_k: int = 0,
21
- temperature: float = 0.75,
22
  ) -> str:
23
  """Generate a reponse.
24
 
@@ -40,7 +40,7 @@ def _convbot(
40
 
41
  input_ids = tokenizer.encode(text + tokenizer.eos_token, return_tensors="pt")
42
  if chat_history_ids:
43
- bot_input_ids = torch.cat([chat_history_ids, input_ids], dim=-1)
44
  else:
45
  bot_input_ids = input_ids
46
 
@@ -52,27 +52,29 @@ def _convbot(
52
  top_p=top_p,
53
  top_k=top_k,
54
  temperature=temperature,
55
- pad_token_id=tokenizer.eos_token_id
56
  )
57
 
58
- output = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
 
 
59
  _convbot.chat_history_ids = chat_history_ids
60
 
61
  return output
62
 
63
 
64
  def convbot(
65
- text: str,
66
- n_reties: int = 3,
67
- max_length: int = 1000,
68
- do_sample: bool = True,
69
- top_p: float = 0.95,
70
- top_k: int = 0,
71
- temperature: float = 0.75,
72
  ) -> str:
73
  """Generate a response."""
74
  try:
75
- n_reties = int(n_reties)
76
  except Exception as e:
77
  logger.error(e)
78
  raise
@@ -81,27 +83,27 @@ def convbot(
81
  except AttributeError:
82
  prev_resp = ""
83
 
84
- resp = _convbot(text, max_length, top_p, top_p, temperature)
85
-
86
- # retry n_retires if resp is empty
87
  if not resp.strip():
88
  idx = 0
89
- while idx < n_retires:
90
  idx += 1
91
  _convbot.chat_history_ids = ""
92
- resp = _convbot(text, max_length, top_p, top_p, temperature)
93
  if resp.strip():
94
  break
95
  else:
96
  logger.warning("bot acting up (empty response), something has gone awry")
97
-
98
  # check repeated responses
99
- if resp.strip() == convbot.prev_resp:
100
  idx = 0
101
- while idx < n_retires:
102
  idx += 1
103
- resp = _convbot(text, max_length, top_p, top_p, temperature)
104
- if resp.strip() != convbot.prev_resp:
105
  break
106
  else:
107
  logger.warning("bot acting up (repeating), something has gone awry")
@@ -112,7 +114,7 @@ def convbot(
112
 
113
 
114
  def main():
115
- print("Talk to me")
116
  while 1:
117
  text = input("You: ")
118
  resp = _convbot(text)
@@ -121,4 +123,3 @@ def main():
121
 
122
  if __name__ == "__main__":
123
  main()
124
-
13
 
14
 
15
  def _convbot(
16
+ text: str,
17
+ max_length: int = 1000,
18
+ do_sample: bool = True,
19
+ top_p: float = 0.95,
20
+ top_k: int = 0,
21
+ temperature: float = 0.75,
22
  ) -> str:
23
  """Generate a reponse.
24
 
40
 
41
  input_ids = tokenizer.encode(text + tokenizer.eos_token, return_tensors="pt")
42
  if chat_history_ids:
43
+ bot_input_ids = torch.cat([chat_history_ids, input_ids], dim=-1)
44
  else:
45
  bot_input_ids = input_ids
46
 
52
  top_p=top_p,
53
  top_k=top_k,
54
  temperature=temperature,
55
+ pad_token_id=tokenizer.eos_token_id,
56
  )
57
 
58
+ output = tokenizer.decode(
59
+ chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True
60
+ )
61
  _convbot.chat_history_ids = chat_history_ids
62
 
63
  return output
64
 
65
 
66
  def convbot(
67
+ text: str,
68
+ n_retries: int = 3,
69
+ max_length: int = 1000,
70
+ do_sample: bool = True,
71
+ top_p: float = 0.95,
72
+ top_k: int = 0,
73
+ temperature: float = 0.75,
74
  ) -> str:
75
  """Generate a response."""
76
  try:
77
+ n_retries = int(n_retries)
78
  except Exception as e:
79
  logger.error(e)
80
  raise
83
  except AttributeError:
84
  prev_resp = ""
85
 
86
+ resp = _convbot(text, max_length, do_sample, top_p, top_k, temperature)
87
+
88
+ # retry n_retries if resp is empty
89
  if not resp.strip():
90
  idx = 0
91
+ while idx < n_retries:
92
  idx += 1
93
  _convbot.chat_history_ids = ""
94
+ resp = _convbot(text, max_length, do_sample, top_p, top_k, temperature)
95
  if resp.strip():
96
  break
97
  else:
98
  logger.warning("bot acting up (empty response), something has gone awry")
99
+
100
  # check repeated responses
101
+ if resp.strip() == prev_resp:
102
  idx = 0
103
+ while idx < n_retries:
104
  idx += 1
105
+ resp = _convbot(text, max_length, do_sample, top_p, top_k, temperature)
106
+ if resp.strip() != prev_resp:
107
  break
108
  else:
109
  logger.warning("bot acting up (repeating), something has gone awry")
114
 
115
 
116
  def main():
117
+ print("Bot: Talk to me")
118
  while 1:
119
  text = input("You: ")
120
  resp = _convbot(text)
123
 
124
  if __name__ == "__main__":
125
  main()