freemt commited on
Commit
c471598
1 Parent(s): 0af688a

Update async aconvbot

Browse files
README.md CHANGED
@@ -24,6 +24,15 @@ prin(convertbot("How are you?"))
24
  # I am good # or along that line
25
  ```
26
 
 
 
 
 
 
 
 
 
 
27
  Interactive
28
 
29
  ```bash
@@ -31,4 +40,4 @@ python -m convbot
31
  ```
32
  ## Not tested in Windows 10 and Mac
33
 
34
- The module uses pytorch that is installed differently in Windows than in Linux. To run in Windows or Mac, you can probably just try to install pytorch manually.
24
  # I am good # or along that line
25
  ```
26
 
27
+ The async version `aconvbot`, potentially for `fastapi` or `Nonebot` plugins and such, is rather artificial since it is based on `ThreadPoolExecutor`. Hence it is not intended for production. You probably should not spawn too many instances.
28
+ ```python
29
+ from convbot import aconvbot
30
+
31
+ async def afunc(text):
32
+ resp = await aconvbot(text)
33
+ ...
34
+ ```
35
+
36
  Interactive
37
 
38
  ```bash
40
  ```
41
  ## Not tested in Windows 10 and Mac
42
 
43
+ The module uses pytorch, which is installed differently on Windows than on Linux. To run `convbot` on Windows or Mac, you may give it a spin by cloning the repo (git clone [https://github.com/ffreemt/convbot](https://github.com/ffreemt/convbot)) and installing pytorch manually.
convbot/__init__.py CHANGED
@@ -1,5 +1,8 @@
1
  """Init."""
2
  __version__ = "0.1.0"
3
- from .convbot import convbot
4
 
5
- __all__ = ("convbot",)
 
 
 
1
  """Init."""
2
  __version__ = "0.1.0"
3
+ from .convbot import convbot, aconvbot
4
 
5
+ __all__ = (
6
+ "convbot",
7
+ "aconvbot",
8
+ )
convbot/convbot.py CHANGED
@@ -4,6 +4,8 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
4
  import torch
5
  from logzero import logger
6
 
 
 
7
  # model_name = "microsoft/DialoGPT-large"
8
  # model_name = "microsoft/DialoGPT-small"
9
  # pylint: disable=invalid-name
@@ -39,7 +41,7 @@ def _convbot(
39
  chat_history_ids = ""
40
 
41
  input_ids = tokenizer.encode(text + tokenizer.eos_token, return_tensors="pt")
42
- if chat_history_ids:
43
  bot_input_ids = torch.cat([chat_history_ids, input_ids], dim=-1)
44
  else:
45
  bot_input_ids = input_ids
@@ -113,6 +115,24 @@ def convbot(
113
  return resp
114
 
115
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
  def main():
117
  print("Bot: Talk to me")
118
  while 1:
4
  import torch
5
  from logzero import logger
6
 
7
+ from .force_async import force_async
8
+
9
  # model_name = "microsoft/DialoGPT-large"
10
  # model_name = "microsoft/DialoGPT-small"
11
  # pylint: disable=invalid-name
41
  chat_history_ids = ""
42
 
43
  input_ids = tokenizer.encode(text + tokenizer.eos_token, return_tensors="pt")
44
+ if isinstance(chat_history_ids, torch.Tensor):
45
  bot_input_ids = torch.cat([chat_history_ids, input_ids], dim=-1)
46
  else:
47
  bot_input_ids = input_ids
115
  return resp
116
 
117
 
118
@force_async
def aconvbot(
    text: str,
    n_retries: int = 3,
    max_length: int = 1000,
    do_sample: bool = True,
    top_p: float = 0.95,
    top_k: int = 0,
    temperature: float = 0.75,
) -> str:
    """Awaitable variant of convbot, executed in a thread pool via force_async.

    Parameters and the returned reply string mirror convbot exactly.
    Any exception raised by convbot is logged and re-raised unchanged.
    """
    try:
        reply = convbot(
            text, n_retries, max_length, do_sample, top_p, top_k, temperature
        )
    except Exception as exc:
        logger.error(exc)
        raise
    return reply
134
+
135
+
136
  def main():
137
  print("Bot: Talk to me")
138
  while 1:
convbot/force_async.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Turn a sync func to async."""
2
+ from concurrent.futures import ThreadPoolExecutor
3
+ import asyncio
4
+ import functools
5
+
6
def force_async(func):
    """Decorator that makes a synchronous function awaitable.

    The decorated function is submitted to a ThreadPoolExecutor and the
    resulting concurrent.futures.Future is wrapped with
    asyncio.wrap_future so it can be awaited from a running event loop.

    Args:
        func: a synchronous callable.

    Returns:
        A callable with the same signature; calling it returns an
        awaitable that resolves to func's result (or re-raises its
        exception).

    Usage:
        @force_async
        def slow_func(x):
            ...

        async def caller():
            result = await slow_func(1)

    NOTE(review): every decorated function gets its own 10-worker pool
    that is never shut down — fine for a handful of module-level
    functions, but do not decorate functions dynamically in a loop.
    """
    # One pool per decorated function; worker threads are created lazily.
    executor = ThreadPoolExecutor(max_workers=10)

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        """Run func in the pool and hand back an asyncio-compatible future."""
        future = executor.submit(func, *args, **kwargs)
        # wrap_future bridges concurrent.futures -> asyncio (awaitable).
        return asyncio.wrap_future(future)

    return wrapper
dist/convbot-0.1.0-py3-none-any.whl ADDED
Binary file (4.83 kB). View file
dist/convbot-0.1.0.tar.gz ADDED
Binary file (3.97 kB). View file
tests/test_convbot.py CHANGED
@@ -1,6 +1,11 @@
1
  """Test convbot."""
 
 
 
2
  from convbot import __version__
3
- from convbot import convbot
 
 
4
 
5
 
6
  def test_version():
@@ -11,12 +16,26 @@ def test_version():
11
  def test_sanity():
12
  """Sanity check."""
13
  try:
14
- assert not convbot()
15
  except Exception:
16
  assert True
17
 
18
 
19
  def test_convbot():
 
20
  resp = convbot("How are you?")
21
  assert len(resp) > 3
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """Test convbot."""
2
+ # import asyncio
3
+ import pytest
4
+
5
  from convbot import __version__
6
+ from convbot import convbot, aconvbot
7
+
8
+ pytestmark = pytest.mark.asyncio
9
 
10
 
11
  def test_version():
16
def test_sanity():
    """Sanity check: empty input either returns something falsy or raises."""
    try:
        result = convbot("")
        assert not result
    except Exception:
        assert True
22
 
23
 
24
def test_convbot():
    """Smoke-test convbot: two consecutive prompts (second reuses chat history)."""
    for prompt in ("How are you?", "How old are you?"):
        reply = convbot(prompt)
        assert len(reply) > 3
32
+
33
+
34
async def test_aconvbot():
    """Test aconvbot (async).

    Renamed from `tests_aconvbot`: pytest only collects functions whose
    names start with `test_`, so the original was silently never run.
    """
    resp = await aconvbot("How are you?")
    assert len(resp) > 3

    # 2nd call uses chat_history_ids
    resp = await aconvbot("How old are you?")
    assert len(resp) > 3