Benjamin Bossan commited on
Commit
01ae0bb
1 Parent(s): 64d4f97

Use transformers agents where applicable

Browse files
.gitignore CHANGED
@@ -10,3 +10,6 @@ build
10
  htmlcov
11
 
12
  *.db
 
 
 
 
10
  htmlcov
11
 
12
  *.db
13
+ notebooks/
14
+ *.ipynb
15
+ .env
README.md CHANGED
@@ -19,18 +19,21 @@ python -m pip install -e .
19
 
20
  ## Starting
21
 
 
 
 
 
 
22
  In one terminal, start the background worker:
23
 
24
  ```sh
25
- cd src
26
- python worker.py
27
  ```
28
 
29
  In another terminal, start the web server:
30
 
31
  ```sh
32
- cd src
33
- uvicorn webservice:app --reload --port 8080
34
  ```
35
 
36
  For example requests, check `requests.org`.
 
19
 
20
  ## Starting
21
 
22
+ ### Preparing environemnt
23
+
24
+ Set an environemnt variable called "HF_HUB_TOKEN" with your Hugging Face token
25
+ or create a `.env` file with that env var.
26
+
27
  In one terminal, start the background worker:
28
 
29
  ```sh
30
+ python src/gistillery/worker.py
 
31
  ```
32
 
33
  In another terminal, start the web server:
34
 
35
  ```sh
36
+ uvicorn src.gistillery.webservice:app --reload --port 8080
 
37
  ```
38
 
39
  For example requests, check `requests.org`.
pyproject.toml CHANGED
@@ -18,5 +18,5 @@ no_implicit_optional = true
18
  strict = true
19
 
20
  [[tool.mypy.overrides]]
21
- module = "transformers,trafilatura"
22
  ignore_missing_imports = true
 
18
  strict = true
19
 
20
  [[tool.mypy.overrides]]
21
+ module = "huggingface_hub,trafilatura,transformers.*"
22
  ignore_missing_imports = true
requests.org CHANGED
@@ -10,19 +10,18 @@ curl -X 'GET' \
10
  : OK
11
 
12
  #+begin_src bash
13
- # curl command to localhost and post the message "hi there"
14
  curl -X 'POST' \
15
  'http://localhost:8080/submit/' \
16
  -H 'accept: application/json' \
17
  -H 'Content-Type: application/json' \
18
  -d '{
19
  "author": "ben",
20
- "content": "SAN FRANCISCO, May 2, 2023 PRNewswire -- GitLab Inc., the most comprehensive, scalable enterprise DevSecOps platform for software innovation, and Google Cloud today announced an extension of its strategic partnership to deliver secure AI offerings to the enterprise. GitLab is trusted by more than 50% of the Fortune 100 to secure and protect their most valuable assets, and leads with a privacy-first approach to AI. By leveraging Google Cloud'\''s customizable foundation models and open generative AI infrastructure, GitLab will provide customers with AI-assisted features directly within the enterprise DevSecOps platform."
21
  }'
22
  #+end_src
23
 
24
  #+RESULTS:
25
- : Submitted job 04deee1a2a9b4d6ea986ffe0fa4017d9
26
 
27
  #+begin_src bash
28
  curl -X 'POST' \
@@ -31,12 +30,12 @@ curl -X 'POST' \
31
  -H 'Content-Type: application/json' \
32
  -d '{
33
  "author": "ben",
34
- "content": "In literature discussing why ChatGPT is able to capture so much of our imagination, I often come across two narratives: Scale: throwing more data and compute at it. UX: moving from a prompt interface to a more natural chat interface. A narrative that is often glossed over in the demo frenzy is the incredible technical creativity that went into making models like ChatGPT work. One such cool idea is RLHF (Reinforcement Learning from Human Feedback): incorporating reinforcement learning and human feedback into NLP. RL has been notoriously difficult to work with, and therefore, mostly confined to gaming and simulated environments like Atari or MuJoCo. Just five years ago, both RL and NLP were progressing pretty much orthogonally – different stacks, different techniques, and different experimentation setups. It’s impressive to see it work in a new domain at a massive scale. So, how exactly does RLHF work? Why does it work? This post will discuss the answers to those questions."
35
  }'
36
  #+end_src
37
 
38
  #+RESULTS:
39
- : Submitted job 730352e00e8145b39971fdc386c28a8f
40
 
41
  #+begin_src bash
42
  curl -X 'POST' \
@@ -45,21 +44,21 @@ curl -X 'POST' \
45
  -H 'Content-Type: application/json' \
46
  -d '{
47
  "author": "ben",
48
- "content": "https://en.wikipedia.org/wiki/Goulburn_Street"
49
  }'
50
  #+end_src
51
 
52
  #+RESULTS:
53
- : Submitted job 1738d7daa96147198d80b93ea040863d
54
 
55
  #+begin_src bash
56
  curl -X 'GET' \
57
- 'http://localhost:8080/check_job_status/1738d7daa96147198d80b93ea040863d' \
58
  -H 'accept: application/json'
59
  #+end_src
60
 
61
  #+RESULTS:
62
- | {"id":"1738d7daa96147198d80b93ea040863d" | status:"pending" | last_updated:"2023-05-09T13:24:42"} |
63
 
64
  #+begin_src bash
65
  curl -X 'GET' \
@@ -68,4 +67,13 @@ curl -X 'GET' \
68
  #+end_src
69
 
70
  #+RESULTS:
71
- | [{"id":"1738d7daa96147198d80b93ea040863d" | author:"ben" | summary:"Goulburn Street is a street in the central business district of Sydney | New South Wales | Australia. It runs from Darling Harbour and Chinatown in the west to Crown Street in the east at Darlinghurst and Surry Hills. The only car park operated by Sydney City Council within the CBD is at the corner of Goulburn and Elizabeth Streets. It was the first air rights car park in Australia | opening in 1963 over six tracks of the City Circle line.[3][4]" | tags:["#centralbusinessdistrict" | #darlinghurst | #general | #goulburnstreet | #surryhills | #sydney | #sydneymasoniccentre] | date:"2023-05-09T13:24:42"} | {"id":"730352e00e8145b39971fdc386c28a8f" | author:"ben" | summary:"A new approach to NLP that incorporates reinforcement learning and human feedback. How does it work? Why does it work? In this post | I’ll explain how it works. RLHF is a new approach to NLP that incorporates reinforcement learning and human feedback. It’s a new approach to NLP that incorporates reinforcement learning and human feedback. It’s a new approach to NLP that incorporates reinforcement learning and human feedback. It’s a new approach to NLP that incorporates reinforcement learning and human feedback. It’s a new approach to NLP that incorporates reinforcement learning and human feedback." | tags:["#" | #general | #rlhf] | date:"2023-05-09T13:24:38"} | {"id":"04deee1a2a9b4d6ea986ffe0fa4017d9" | author:"ben" | summary:"GitLab | the most comprehensive | scalable enterprise DevSecOps platform for software innovation | and Google Cloud today announced an extension of their strategic partnership to deliver secure AI offerings to the enterprise. By leveraging Google Cloud's customizable foundation models and open generative AI infrastructure | GitLab will provide customers with AI-assisted features directly within the enterprise DevSecOps platform. The company's AI capabilities are designed to help enterprises improve productivity and reduce costs." | tags:["#ai-assistedfeatures" | #enterprisedevsecopsplatform | #general | #gitlab | #googlecloud] | date:"2023-05-09T13:24:36"}] |
 
 
 
 
 
 
 
 
 
 
10
  : OK
11
 
12
  #+begin_src bash
 
13
  curl -X 'POST' \
14
  'http://localhost:8080/submit/' \
15
  -H 'accept: application/json' \
16
  -H 'Content-Type: application/json' \
17
  -d '{
18
  "author": "ben",
19
+ "content": "In literature discussing why ChatGPT is able to capture so much of our imagination, I often come across two narratives: Scale: throwing more data and compute at it. UX: moving from a prompt interface to a more natural chat interface. A narrative that is often glossed over in the demo frenzy is the incredible technical creativity that went into making models like ChatGPT work. One such cool idea is RLHF (Reinforcement Learning from Human Feedback): incorporating reinforcement learning and human feedback into NLP. RL has been notoriously difficult to work with, and therefore, mostly confined to gaming and simulated environments like Atari or MuJoCo. Just five years ago, both RL and NLP were progressing pretty much orthogonally different stacks, different techniques, and different experimentation setups. It’s impressive to see it work in a new domain at a massive scale. So, how exactly does RLHF work? Why does it work? This post will discuss the answers to those questions."
20
  }'
21
  #+end_src
22
 
23
  #+RESULTS:
24
+ : Submitted job fef72c3aa4394bc7a299291c80a5c06b
25
 
26
  #+begin_src bash
27
  curl -X 'POST' \
 
30
  -H 'Content-Type: application/json' \
31
  -d '{
32
  "author": "ben",
33
+ "content": "https://en.wikipedia.org/wiki/Goulburn_Street"
34
  }'
35
  #+end_src
36
 
37
  #+RESULTS:
38
+ : Submitted job f37729bb36104ab4a23cefd0480e4862
39
 
40
  #+begin_src bash
41
  curl -X 'POST' \
 
44
  -H 'Content-Type: application/json' \
45
  -d '{
46
  "author": "ben",
47
+ "content": "https://upload.wikimedia.org/wikipedia/commons/thumb/e/e1/Cattle_tyrant_%28Machetornis_rixosa%29_on_Capybara.jpg/1920px-Cattle_tyrant_%28Machetornis_rixosa%29_on_Capybara.jpg"
48
  }'
49
  #+end_src
50
 
51
  #+RESULTS:
52
+ : Submitted job dc3da7b1d5aa47c38dc6713952104f5f
53
 
54
  #+begin_src bash
55
  curl -X 'GET' \
56
+ 'http://localhost:8080/check_job_status/' \
57
  -H 'accept: application/json'
58
  #+end_src
59
 
60
  #+RESULTS:
61
+ : Found 3 pending job(s): fef72c3aa4394bc7a299291c80a5c06b, f37729bb36104ab4a23cefd0480e4862, dc3da7b1d5aa47c38dc6713952104f5f
62
 
63
  #+begin_src bash
64
  curl -X 'GET' \
 
67
  #+end_src
68
 
69
  #+RESULTS:
70
+ | [{"id":"dc3da7b1d5aa47c38dc6713952104f5f" | author:"ben" | summary:"A small bird is perched on the back of a capy capy. It's looking for a place to nestle. It doesn't seem to be finding a suitable place for it | though | because it's not very big. The place is not very flat. " | tags:["#back" | #bird | #capy | #general | #perch | #perched] | date:"2023-05-11T13:16:48"} | {"id":"f37729bb36104ab4a23cefd0480e4862" | author:"ben" | summary:"Goulburn Street is a street in the central business district of Sydney in New South Wales | Australia. It runs from Darling Harbour and Chinatown in the west to Crown Street in the east at Darlinghurst and Surry Hills. It is the only car park operated by Sydney City Council within the CBD and was the first air rights car park in Australia." | tags:["#centralbusinessdistrict" | #darlinghurst | #general | #goulburnstreet | #surryhills | #sydney | #sydneymasoniccentre] | date:"2023-05-11T13:16:47"} | {"id":"fef72c3aa4394bc7a299291c80a5c06b" | author:"ben" | summary:"ChatGPT is able to capture our imagination because of its scale. RLHF (Reinforcement Learning from Human Feedback) is a new approach to NLP that incorporates reinforcement learning and human feedback into NLP. It's impressive to see it work in a new domain at a massive scale." | tags:["#" | #general | #rlhf] | date:"2023-05-11T13:16:45"}] |
71
+
72
+ #+begin_src bash
73
+ curl -X 'GET' \
74
+ 'http://localhost:8080/recent/rlhf' \
75
+ -H 'accept: application/json'
76
+ #+end_src
77
+
78
+ #+RESULTS:
79
+ | [{"id":"fef72c3aa4394bc7a299291c80a5c06b" | author:"ben" | summary:"ChatGPT is able to capture our imagination because of its scale. RLHF (Reinforcement Learning from Human Feedback) is a new approach to NLP that incorporates reinforcement learning and human feedback into NLP. It's impressive to see it work in a new domain at a massive scale." | tags:["#" | #general | #rlhf] | date:"2023-05-11T13:16:45"}] |
requirements-dev.txt CHANGED
@@ -4,3 +4,4 @@ mypy
4
  ruff
5
  pytest
6
  pytest-cov
 
 
4
  ruff
5
  pytest
6
  pytest-cov
7
+ types-Pillow
requirements.txt CHANGED
@@ -2,6 +2,8 @@ fastapi
2
  httpx
3
  uvicorn[standard]
4
  torch
5
- transformers
 
6
  charset-normalizer
7
  trafilatura
 
 
2
  httpx
3
  uvicorn[standard]
4
  torch
5
+ transformers>=4.29.0
6
+ accelerate
7
  charset-normalizer
8
  trafilatura
9
+ pillow
src/gistillery/config.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+
4
+ from pydantic import BaseSettings
5
+
6
+
7
+ class Config(BaseSettings):
8
+ hf_hub_token: str = "missing"
9
+ hf_agent: str = "https://api-inference.huggingface.co/models/bigcode/starcoder"
10
+ db_file_name: Path = Path("sqlite-data.db")
11
+
12
+ class Config:
13
+ # load .env file by default, with provisio to use other .env files if set
14
+ env_file = os.getenv('ENV_FILE', '.env')
15
+
16
+
17
+ _config = None
18
+
19
+
20
+ def get_config() -> Config:
21
+ global _config
22
+ if _config is None:
23
+ _config = Config()
24
+ return _config
src/gistillery/db.py CHANGED
@@ -1,15 +1,14 @@
1
  import logging
2
- import os
3
  import sqlite3
4
  from collections import namedtuple
5
  from contextlib import contextmanager
6
  from typing import Generator
7
 
 
 
8
  logger = logging.getLogger(__name__)
9
  logger.setLevel(logging.DEBUG)
10
 
11
- db_file = os.getenv("DB_FILE_NAME", "sqlite-data.db")
12
-
13
 
14
  schema_entries = """
15
  CREATE TABLE entries
@@ -91,7 +90,7 @@ def _get_db_connection() -> sqlite3.Connection:
91
  global TABLES_CREATED
92
 
93
  # sqlite cannot deal with concurrent access, so we set a big timeout
94
- conn = sqlite3.connect(db_file, timeout=30)
95
  conn.row_factory = namedtuple_factory
96
  if TABLES_CREATED:
97
  return conn
 
1
  import logging
 
2
  import sqlite3
3
  from collections import namedtuple
4
  from contextlib import contextmanager
5
  from typing import Generator
6
 
7
+ from gistillery.config import get_config
8
+
9
  logger = logging.getLogger(__name__)
10
  logger.setLevel(logging.DEBUG)
11
 
 
 
12
 
13
  schema_entries = """
14
  CREATE TABLE entries
 
90
  global TABLES_CREATED
91
 
92
  # sqlite cannot deal with concurrent access, so we set a big timeout
93
+ conn = sqlite3.connect(get_config().db_file_name, timeout=30)
94
  conn.row_factory = namedtuple_factory
95
  if TABLES_CREATED:
96
  return conn
src/gistillery/preprocessing.py CHANGED
@@ -1,16 +1,33 @@
1
  import abc
 
2
  import logging
3
  import re
 
4
 
 
5
  from httpx import Client
6
- from trafilatura import extract
 
7
 
8
  from gistillery.base import JobInput
 
 
9
 
10
  logger = logging.getLogger(__name__)
11
  logger.setLevel(logging.DEBUG)
12
 
13
 
 
 
 
 
 
 
 
 
 
 
 
14
  class Processor(abc.ABC):
15
  def get_name(self) -> str:
16
  return self.__class__.__name__
@@ -40,25 +57,55 @@ class RawTextProcessor(Processor):
40
 
41
 
42
  class DefaultUrlProcessor(Processor):
43
- # uses trafilatura to extract text from html
44
  def __init__(self) -> None:
45
  self.client = Client()
46
- self.regex = re.compile(r"(https?://[^\s]+)")
47
- self.url = None
48
  self.template = "{url}\n\n{content}"
49
 
50
  def match(self, input: JobInput) -> bool:
51
- urls = list(self.regex.findall(input.content.strip()))
52
- if len(urls) == 1:
53
- self.url = urls[0]
54
- return True
55
- return False
 
56
 
57
  def process(self, input: JobInput) -> str:
58
  """Get content of website and return it as string"""
59
- assert isinstance(self.url, str)
 
 
60
  text = self.client.get(self.url).text
61
  assert isinstance(text, str)
62
- extracted = extract(text)
63
  text = self.template.format(url=self.url, content=extracted)
64
- return text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import abc
2
+ import io
3
  import logging
4
  import re
5
+ from typing import Optional
6
 
7
+ import trafilatura
8
  from httpx import Client
9
+
10
+ from PIL import Image
11
 
12
  from gistillery.base import JobInput
13
+ from gistillery.tools import get_agent
14
+
15
 
16
  logger = logging.getLogger(__name__)
17
  logger.setLevel(logging.DEBUG)
18
 
19
 
20
+ RE_URL = re.compile(r"(https?://[^\s]+)")
21
+
22
+
23
+ def get_url(text: str) -> str | None:
24
+ urls: list[str] = list(RE_URL.findall(text))
25
+ if len(urls) == 1:
26
+ url = urls[0]
27
+ return url
28
+ return None
29
+
30
+
31
  class Processor(abc.ABC):
32
  def get_name(self) -> str:
33
  return self.__class__.__name__
 
57
 
58
 
59
  class DefaultUrlProcessor(Processor):
 
60
  def __init__(self) -> None:
61
  self.client = Client()
62
+ self.url = Optional[str]
 
63
  self.template = "{url}\n\n{content}"
64
 
65
  def match(self, input: JobInput) -> bool:
66
+ url = get_url(input.content.strip())
67
+ if url is None:
68
+ return False
69
+
70
+ self.url = url
71
+ return True
72
 
73
  def process(self, input: JobInput) -> str:
74
  """Get content of website and return it as string"""
75
+ if not isinstance(self.url, str):
76
+ raise TypeError("self.url must be a string")
77
+
78
  text = self.client.get(self.url).text
79
  assert isinstance(text, str)
80
+ extracted = trafilatura.extract(text)
81
  text = self.template.format(url=self.url, content=extracted)
82
+ return str(text)
83
+
84
+
85
+ class ImageUrlProcessor(Processor):
86
+ def __init__(self) -> None:
87
+ self.client = Client()
88
+ self.url = Optional[str]
89
+ self.template = "{url}\n\n{content}"
90
+ self.image_suffixes = {'jpg', 'jpeg', 'png', 'gif'}
91
+
92
+ def match(self, input: JobInput) -> bool:
93
+ url = get_url(input.content.strip())
94
+ if url is None:
95
+ return False
96
+
97
+ suffix = url.rsplit(".", 1)[-1].lower()
98
+ if suffix not in self.image_suffixes:
99
+ return False
100
+
101
+ self.url = url
102
+ return True
103
+
104
+ def process(self, input: JobInput) -> str:
105
+ if not isinstance(self.url, str):
106
+ raise TypeError("self.url must be a string")
107
+
108
+ response = self.client.get(self.url)
109
+ image = Image.open(io.BytesIO(response.content)).convert('RGB')
110
+ caption = get_agent().run("Caption the following image", image=image)
111
+ return str(caption)
src/gistillery/registry.py CHANGED
@@ -1,10 +1,14 @@
1
- from gistillery.ml import Summarizer, Tagger
2
- from gistillery.preprocessing import Processor, RawTextProcessor
3
-
4
  from gistillery.base import JobInput
 
 
 
 
 
 
 
5
 
6
 
7
- class MlRegistry:
8
  def __init__(self) -> None:
9
  self.processors: list[Processor] = []
10
  self.summerizer: Summarizer | None = None
@@ -39,3 +43,24 @@ class MlRegistry:
39
  def get_tagger(self) -> Tagger:
40
  assert self.tagger
41
  return self.tagger
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from gistillery.base import JobInput
2
+ from gistillery.tools import Summarizer, Tagger, HfDefaultSummarizer, HfDefaultTagger
3
+ from gistillery.preprocessing import (
4
+ Processor,
5
+ RawTextProcessor,
6
+ ImageUrlProcessor,
7
+ DefaultUrlProcessor,
8
+ )
9
 
10
 
11
+ class ToolRegistry:
12
  def __init__(self) -> None:
13
  self.processors: list[Processor] = []
14
  self.summerizer: Summarizer | None = None
 
43
  def get_tagger(self) -> Tagger:
44
  assert self.tagger
45
  return self.tagger
46
+
47
+
48
+ _registry = None
49
+
50
+
51
+ def get_tool_registry() -> ToolRegistry:
52
+ global _registry
53
+ if _registry is not None:
54
+ return _registry
55
+
56
+ summarizer = HfDefaultSummarizer()
57
+ tagger = HfDefaultTagger()
58
+
59
+ _registry = ToolRegistry()
60
+ _registry.register_processor(ImageUrlProcessor())
61
+ _registry.register_processor(DefaultUrlProcessor())
62
+ _registry.register_processor(RawTextProcessor())
63
+ _registry.register_summarizer(summarizer)
64
+ _registry.register_tagger(tagger)
65
+
66
+ return _registry
src/gistillery/{ml.py → tools.py} RENAMED
@@ -1,17 +1,26 @@
1
  import abc
2
- from typing import Any
3
- import logging
4
 
5
- logger = logging.getLogger(__name__)
6
- logger.setLevel(logging.DEBUG)
 
 
7
 
 
 
 
 
8
 
9
- class Summarizer(abc.ABC):
10
- def __init__(
11
- self, model_name: str, model: Any, tokenizer: Any, generation_config: Any
12
- ) -> None:
13
- raise NotImplementedError
14
 
 
 
 
 
 
 
 
 
 
 
15
  def get_name(self) -> str:
16
  raise NotImplementedError
17
 
@@ -20,12 +29,21 @@ class Summarizer(abc.ABC):
20
  raise NotImplementedError
21
 
22
 
23
- class Tagger(abc.ABC):
24
- def __init__(
25
- self, model_name: str, model: Any, tokenizer: Any, generation_config: Any
26
- ) -> None:
27
- raise NotImplementedError
 
 
 
 
 
 
 
28
 
 
 
29
  def get_name(self) -> str:
30
  raise NotImplementedError
31
 
@@ -34,39 +52,19 @@ class Tagger(abc.ABC):
34
  raise NotImplementedError
35
 
36
 
37
- class HfTransformersSummarizer(Summarizer):
38
- def __init__(
39
- self, model_name: str, model: Any, tokenizer: Any, generation_config: Any
40
- ) -> None:
41
  self.model_name = model_name
42
- self.model = model
43
- self.tokenizer = tokenizer
44
- self.generation_config = generation_config
45
 
46
- self.template = "Summarize the text below in two sentences:\n\n{}"
 
 
 
 
47
 
48
- def __call__(self, x: str) -> str:
49
- text = self.template.format(x)
50
- inputs = self.tokenizer(text, return_tensors="pt")
51
- outputs = self.model.generate(
52
- **inputs, generation_config=self.generation_config
53
- )
54
- output = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
55
- assert isinstance(output, str)
56
- return output
57
-
58
- def get_name(self) -> str:
59
- return f"{self.__class__.__name__}({self.model_name})"
60
-
61
-
62
- class HfTransformersTagger(Tagger):
63
- def __init__(
64
- self, model_name: str, model: Any, tokenizer: Any, generation_config: Any
65
- ) -> None:
66
- self.model_name = model_name
67
- self.model = model
68
- self.tokenizer = tokenizer
69
- self.generation_config = generation_config
70
 
71
  self.template = (
72
  "Create a list of tags for the text below. The tags should be high level "
 
1
  import abc
 
 
2
 
3
+ from huggingface_hub import login
4
+ from transformers.tools import TextSummarizationTool
5
+ from transformers import HfAgent
6
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig
7
 
8
+ from gistillery.config import get_config
9
+
10
+
11
+ agent = None
12
 
 
 
 
 
 
13
 
14
+ def get_agent() -> HfAgent:
15
+ global agent
16
+ if agent is None:
17
+ login(get_config().hf_hub_token)
18
+ agent = HfAgent(get_config().hf_agent)
19
+ return agent
20
+
21
+
22
+ class Summarizer(abc.ABC):
23
+ @abc.abstractmethod
24
  def get_name(self) -> str:
25
  raise NotImplementedError
26
 
 
29
  raise NotImplementedError
30
 
31
 
32
+ class HfDefaultSummarizer(Summarizer):
33
+ def __init__(self) -> None:
34
+ self.summarizer = TextSummarizationTool()
35
+
36
+ def get_name(self) -> str:
37
+ return "hf_default"
38
+
39
+ def __call__(self, x: str) -> str:
40
+ summary = self.summarizer(x)
41
+ assert isinstance(summary, str)
42
+ return summary
43
+
44
 
45
+ class Tagger(abc.ABC):
46
+ @abc.abstractmethod
47
  def get_name(self) -> str:
48
  raise NotImplementedError
49
 
 
52
  raise NotImplementedError
53
 
54
 
55
+ class HfDefaultTagger(Tagger):
56
+ def __init__(self, model_name: str = "google/flan-t5-large") -> None:
 
 
57
  self.model_name = model_name
 
 
 
58
 
59
+ config = GenerationConfig.from_pretrained(self.model_name)
60
+ config.max_new_tokens = 50
61
+ config.min_new_tokens = 25
62
+ # increase the temperature to make the model more creative
63
+ config.temperature = 1.5
64
 
65
+ self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)
66
+ self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
67
+ self.generation_config = config
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
  self.template = (
70
  "Create a list of tags for the text below. The tags should be high level "
src/gistillery/webservice.py CHANGED
@@ -37,8 +37,28 @@ def submit_job(input: RequestInput) -> str:
37
  return f"Submitted job {_id}"
38
 
39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  @app.get("/check_job_status/{_id}")
41
- def check_job_status(_id: str) -> JobStatusResult:
42
  with get_db_cursor() as cursor:
43
  cursor.execute(
44
  "SELECT status, last_updated FROM jobs WHERE entry_id = ?", (_id,)
 
37
  return f"Submitted job {_id}"
38
 
39
 
40
+ @app.get("/check_job_status/")
41
+ def check_job_status() -> str:
42
+ with get_db_cursor() as cursor:
43
+ cursor.execute(
44
+ "SELECT entry_id "
45
+ "FROM jobs WHERE status = 'pending' "
46
+ "ORDER BY last_updated ASC"
47
+ )
48
+ result = cursor.fetchall()
49
+
50
+ if not result:
51
+ return "No pending jobs found"
52
+
53
+ entry_ids = [r.entry_id for r in result]
54
+ num_entries = len(entry_ids)
55
+ if len(entry_ids) > 3:
56
+ entry_ids = entry_ids[:3] + ["..."]
57
+ return f"Found {num_entries} pending job(s): {', '.join(entry_ids)}"
58
+
59
+
60
  @app.get("/check_job_status/{_id}")
61
+ def check_job_status_id(_id: str) -> JobStatusResult:
62
  with get_db_cursor() as cursor:
63
  cursor.execute(
64
  "SELECT status, last_updated FROM jobs WHERE entry_id = ?", (_id,)
src/gistillery/worker.py CHANGED
@@ -3,9 +3,7 @@ from dataclasses import dataclass
3
 
4
  from gistillery.base import JobInput
5
  from gistillery.db import get_db_cursor
6
- from gistillery.ml import HfTransformersSummarizer, HfTransformersTagger
7
- from gistillery.preprocessing import DefaultUrlProcessor, RawTextProcessor
8
- from gistillery.registry import MlRegistry
9
 
10
  SLEEP_INTERVAL = 5
11
 
@@ -13,7 +11,7 @@ SLEEP_INTERVAL = 5
13
  def check_pending_jobs() -> list[JobInput]:
14
  """Check DB for pending jobs"""
15
  with get_db_cursor() as cursor:
16
- # fetch pending jobs, join authro and content from entries table
17
  query = """
18
  SELECT j.entry_id, e.author, e.source
19
  FROM jobs j
@@ -21,7 +19,7 @@ def check_pending_jobs() -> list[JobInput]:
21
  ON j.entry_id = e.id
22
  WHERE j.status = 'pending'
23
  """
24
- res = list(cursor.execute(query))
25
  return [
26
  JobInput(id=_id, author=author, content=content) for _id, author, content in res
27
  ]
@@ -37,7 +35,7 @@ class JobOutput:
37
  tagger_name: str
38
 
39
 
40
- def _process_job(job: JobInput, registry: MlRegistry) -> JobOutput:
41
  processor = registry.get_processor(job)
42
  processor_name = processor.get_name()
43
  processed = processor(job)
@@ -79,7 +77,7 @@ def store(job: JobInput, output: JobOutput) -> None:
79
  )
80
 
81
 
82
- def process_job(job: JobInput, registry: MlRegistry) -> None:
83
  tic = time.perf_counter()
84
  print(f"Processing job for (id={job.id[:8]})")
85
 
@@ -105,41 +103,8 @@ def process_job(job: JobInput, registry: MlRegistry) -> None:
105
  print(f"Finished processing job (id={job.id[:8]}) in {toc - tic:0.3f} seconds")
106
 
107
 
108
- def load_mlregistry(model_name: str) -> MlRegistry:
109
- from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig
110
-
111
- model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
112
- tokenizer = AutoTokenizer.from_pretrained(model_name)
113
-
114
- config_summarizer = GenerationConfig.from_pretrained(model_name)
115
- config_summarizer.max_new_tokens = 200
116
- config_summarizer.min_new_tokens = 100
117
- config_summarizer.top_k = 5
118
- config_summarizer.repetition_penalty = 1.5
119
-
120
- config_tagger = GenerationConfig.from_pretrained(model_name)
121
- config_tagger.max_new_tokens = 50
122
- config_tagger.min_new_tokens = 25
123
- # increase the temperature to make the model more creative
124
- config_tagger.temperature = 1.5
125
-
126
- summarizer = HfTransformersSummarizer(
127
- model_name, model, tokenizer, config_summarizer
128
- )
129
- tagger = HfTransformersTagger(model_name, model, tokenizer, config_tagger)
130
-
131
- registry = MlRegistry()
132
- registry.register_processor(DefaultUrlProcessor())
133
- registry.register_processor(RawTextProcessor())
134
- registry.register_summarizer(summarizer)
135
- registry.register_tagger(tagger)
136
-
137
- return registry
138
-
139
-
140
  def main() -> None:
141
- model_name = "google/flan-t5-large"
142
- registry = load_mlregistry(model_name)
143
 
144
  while True:
145
  jobs = check_pending_jobs()
 
3
 
4
  from gistillery.base import JobInput
5
  from gistillery.db import get_db_cursor
6
+ from gistillery.registry import ToolRegistry, get_tool_registry
 
 
7
 
8
  SLEEP_INTERVAL = 5
9
 
 
11
  def check_pending_jobs() -> list[JobInput]:
12
  """Check DB for pending jobs"""
13
  with get_db_cursor() as cursor:
14
+ # fetch pending jobs, join author and content from entries table
15
  query = """
16
  SELECT j.entry_id, e.author, e.source
17
  FROM jobs j
 
19
  ON j.entry_id = e.id
20
  WHERE j.status = 'pending'
21
  """
22
+ res = cursor.execute(query).fetchall()
23
  return [
24
  JobInput(id=_id, author=author, content=content) for _id, author, content in res
25
  ]
 
35
  tagger_name: str
36
 
37
 
38
+ def _process_job(job: JobInput, registry: ToolRegistry) -> JobOutput:
39
  processor = registry.get_processor(job)
40
  processor_name = processor.get_name()
41
  processed = processor(job)
 
77
  )
78
 
79
 
80
+ def process_job(job: JobInput, registry: ToolRegistry) -> None:
81
  tic = time.perf_counter()
82
  print(f"Processing job for (id={job.id[:8]})")
83
 
 
103
  print(f"Finished processing job (id={job.id[:8]}) in {toc - tic:0.3f} seconds")
104
 
105
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  def main() -> None:
107
+ registry = get_tool_registry()
 
108
 
109
  while True:
110
  jobs = check_pending_jobs()
tests/test_app.py CHANGED
@@ -35,18 +35,14 @@ class TestWebservice:
35
  return client
36
 
37
  @pytest.fixture
38
- def mlregistry(self):
39
  # use dummy models
40
- from gistillery.ml import Summarizer, Tagger
41
  from gistillery.preprocessing import RawTextProcessor
42
- from gistillery.registry import MlRegistry
43
 
44
  class DummySummarizer(Summarizer):
45
  """Returns the first 10 characters of the input"""
46
-
47
- def __init__(self, *args, **kwargs):
48
- pass
49
-
50
  def get_name(self):
51
  return "dummy summarizer"
52
 
@@ -55,24 +51,20 @@ class TestWebservice:
55
 
56
  class DummyTagger(Tagger):
57
  """Returns the first 3 words of the input"""
58
-
59
- def __init__(self, *args, **kwargs):
60
- pass
61
-
62
  def get_name(self):
63
  return "dummy tagger"
64
 
65
  def __call__(self, x):
66
  return ["#" + word for word in x.split(maxsplit=4)[:3]]
67
 
68
- registry = MlRegistry()
69
  registry.register_processor(RawTextProcessor())
70
 
71
  # arguments don't matter for dummy summarizer and tagger
72
- summarizer = DummySummarizer(None, None, None, None)
73
  registry.register_summarizer(summarizer)
74
 
75
- tagger = DummyTagger(None, None, None, None)
76
  registry.register_tagger(tagger)
77
  return registry
78
 
@@ -128,7 +120,7 @@ class TestWebservice:
128
  }
129
  assert last_updated is None
130
 
131
- def test_submitted_job_failed(self, client, mlregistry, monkeypatch):
132
  # monkeypatch uuid4 to return a known value
133
  job_id = "abc1234"
134
  monkeypatch.setattr("uuid.uuid4", lambda: SimpleNamespace(hex=job_id))
@@ -143,7 +135,7 @@ class TestWebservice:
143
  "gistillery.worker._process_job",
144
  lambda job, registry: raise_(RuntimeError("something went wrong")),
145
  )
146
- self.process_jobs(mlregistry)
147
 
148
  resp = client.get(f"/check_job_status/{job_id}")
149
  output = resp.json()
@@ -153,12 +145,12 @@ class TestWebservice:
153
  "status": "failed",
154
  }
155
 
156
- def test_submitted_job_status_done(self, client, mlregistry, monkeypatch):
157
  # monkeypatch uuid4 to return a known value
158
  job_id = "abc1234"
159
  monkeypatch.setattr("uuid.uuid4", lambda: SimpleNamespace(hex=job_id))
160
  client.post("/submit", json={"author": "ben", "content": "this is a test"})
161
- self.process_jobs(mlregistry)
162
 
163
  resp = client.get(f"/check_job_status/{job_id}")
164
  output = resp.json()
@@ -169,7 +161,28 @@ class TestWebservice:
169
  }
170
  assert is_roughly_now(last_updated)
171
 
172
- def test_recent_with_entries(self, client, mlregistry):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
173
  # submit 2 entries
174
  client.post(
175
  "/submit", json={"author": "maxi", "content": "this is a first test"}
@@ -178,7 +191,7 @@ class TestWebservice:
178
  "/submit",
179
  json={"author": "mini", "content": "this would be something else"},
180
  )
181
- self.process_jobs(mlregistry)
182
  resp = client.get("/recent").json()
183
 
184
  # results are sorted by recency but since dummy models are so fast, the
@@ -196,7 +209,7 @@ class TestWebservice:
196
  assert resp1["summary"] == "this would"
197
  assert resp1["tags"] == sorted(["#this", "#would", "#be"])
198
 
199
- def test_recent_tag_with_entries(self, client, mlregistry):
200
  # submit 2 entries
201
  client.post(
202
  "/submit", json={"author": "maxi", "content": "this is a first test"}
@@ -205,7 +218,7 @@ class TestWebservice:
205
  "/submit",
206
  json={"author": "mini", "content": "this would be something else"},
207
  )
208
- self.process_jobs(mlregistry)
209
 
210
  # the "this" tag is in both entries
211
  resp = client.get("/recent/this").json()
@@ -220,22 +233,22 @@ class TestWebservice:
220
  assert resp0["summary"] == "this would"
221
  assert resp0["tags"] == sorted(["#this", "#would", "#be"])
222
 
223
- def test_clear(self, client, cursor, mlregistry):
224
  client.post("/submit", json={"author": "ben", "content": "this is a test"})
225
- self.process_jobs(mlregistry)
226
  assert cursor.execute("SELECT COUNT(*) c FROM entries").fetchone()[0] == 1
227
 
228
  client.get("/clear")
229
  assert cursor.execute("SELECT COUNT(*) c FROM entries").fetchone()[0] == 0
230
 
231
- def test_inputs_stored(self, client, cursor, mlregistry):
232
  client.post("/submit", json={"author": "ben", "content": " this is a test\n"})
233
- self.process_jobs(mlregistry)
234
  rows = cursor.execute("SELECT * FROM inputs").fetchall()
235
  assert len(rows) == 1
236
  assert rows[0].input == "this is a test"
237
 
238
- def test_submit_url(self, client, cursor, mlregistry, monkeypatch):
239
  class MockClient:
240
  """Mock httpx Client, return www.example.com content"""
241
 
@@ -269,7 +282,7 @@ class TestWebservice:
269
  from gistillery.preprocessing import DefaultUrlProcessor
270
 
271
  # register url processor, put it before the default processor
272
- mlregistry.register_processor(DefaultUrlProcessor(), last=False)
273
  client.post(
274
  "/submit",
275
  json={
@@ -277,7 +290,7 @@ class TestWebservice:
277
  "content": "https://en.wikipedia.org/wiki/non-existing-page",
278
  },
279
  )
280
- self.process_jobs(mlregistry)
281
 
282
  rows = cursor.execute("SELECT * FROM inputs").fetchall()
283
  assert len(rows) == 1
 
35
  return client
36
 
37
  @pytest.fixture
38
+ def registry(self):
39
  # use dummy models
40
+ from gistillery.tools import Summarizer, Tagger
41
  from gistillery.preprocessing import RawTextProcessor
42
+ from gistillery.registry import ToolRegistry
43
 
44
  class DummySummarizer(Summarizer):
45
  """Returns the first 10 characters of the input"""
 
 
 
 
46
  def get_name(self):
47
  return "dummy summarizer"
48
 
 
51
 
52
  class DummyTagger(Tagger):
53
  """Returns the first 3 words of the input"""
 
 
 
 
54
  def get_name(self):
55
  return "dummy tagger"
56
 
57
  def __call__(self, x):
58
  return ["#" + word for word in x.split(maxsplit=4)[:3]]
59
 
60
+ registry = ToolRegistry()
61
  registry.register_processor(RawTextProcessor())
62
 
63
  # arguments don't matter for dummy summarizer and tagger
64
+ summarizer = DummySummarizer()
65
  registry.register_summarizer(summarizer)
66
 
67
+ tagger = DummyTagger()
68
  registry.register_tagger(tagger)
69
  return registry
70
 
 
120
  }
121
  assert last_updated is None
122
 
123
+ def test_submitted_job_failed(self, client, registry, monkeypatch):
124
  # monkeypatch uuid4 to return a known value
125
  job_id = "abc1234"
126
  monkeypatch.setattr("uuid.uuid4", lambda: SimpleNamespace(hex=job_id))
 
135
  "gistillery.worker._process_job",
136
  lambda job, registry: raise_(RuntimeError("something went wrong")),
137
  )
138
+ self.process_jobs(registry)
139
 
140
  resp = client.get(f"/check_job_status/{job_id}")
141
  output = resp.json()
 
145
  "status": "failed",
146
  }
147
 
148
+ def test_submitted_job_status_done(self, client, registry, monkeypatch):
149
  # monkeypatch uuid4 to return a known value
150
  job_id = "abc1234"
151
  monkeypatch.setattr("uuid.uuid4", lambda: SimpleNamespace(hex=job_id))
152
  client.post("/submit", json={"author": "ben", "content": "this is a test"})
153
+ self.process_jobs(registry)
154
 
155
  resp = client.get(f"/check_job_status/{job_id}")
156
  output = resp.json()
 
161
  }
162
  assert is_roughly_now(last_updated)
163
 
164
+ def test_status_pending_jobs(self, client, registry, monkeypatch):
165
+ resp = client.get("/check_job_status/")
166
+ output = resp.json()
167
+ assert output == "No pending jobs found"
168
+
169
+ monkeypatch.setattr("uuid.uuid4", lambda: SimpleNamespace(hex="abc0"))
170
+ client.post("/submit", json={"author": "ben", "content": "this is a test"})
171
+ resp = client.get("/check_job_status/")
172
+ output = resp.json()
173
+ expected = "Found 1 pending job(s): abc0"
174
+ assert output == expected
175
+
176
+ for i in range(1, 10):
177
+ monkeypatch.setattr("uuid.uuid4", lambda: SimpleNamespace(hex=f"abc{i}"))
178
+ client.post("/submit", json={"author": "ben", "content": "this is a test"})
179
+
180
+ resp = client.get("/check_job_status/")
181
+ output = resp.json()
182
+ expected = "Found 10 pending job(s): abc0, abc1, abc2, ..."
183
+ assert output == expected
184
+
185
+ def test_recent_with_entries(self, client, registry):
186
  # submit 2 entries
187
  client.post(
188
  "/submit", json={"author": "maxi", "content": "this is a first test"}
 
191
  "/submit",
192
  json={"author": "mini", "content": "this would be something else"},
193
  )
194
+ self.process_jobs(registry)
195
  resp = client.get("/recent").json()
196
 
197
  # results are sorted by recency but since dummy models are so fast, the
 
209
  assert resp1["summary"] == "this would"
210
  assert resp1["tags"] == sorted(["#this", "#would", "#be"])
211
 
212
+ def test_recent_tag_with_entries(self, client, registry):
213
  # submit 2 entries
214
  client.post(
215
  "/submit", json={"author": "maxi", "content": "this is a first test"}
 
218
  "/submit",
219
  json={"author": "mini", "content": "this would be something else"},
220
  )
221
+ self.process_jobs(registry)
222
 
223
  # the "this" tag is in both entries
224
  resp = client.get("/recent/this").json()
 
233
  assert resp0["summary"] == "this would"
234
  assert resp0["tags"] == sorted(["#this", "#would", "#be"])
235
 
236
+ def test_clear(self, client, cursor, registry):
237
  client.post("/submit", json={"author": "ben", "content": "this is a test"})
238
+ self.process_jobs(registry)
239
  assert cursor.execute("SELECT COUNT(*) c FROM entries").fetchone()[0] == 1
240
 
241
  client.get("/clear")
242
  assert cursor.execute("SELECT COUNT(*) c FROM entries").fetchone()[0] == 0
243
 
244
+ def test_inputs_stored(self, client, cursor, registry):
245
  client.post("/submit", json={"author": "ben", "content": " this is a test\n"})
246
+ self.process_jobs(registry)
247
  rows = cursor.execute("SELECT * FROM inputs").fetchall()
248
  assert len(rows) == 1
249
  assert rows[0].input == "this is a test"
250
 
251
+ def test_submit_url(self, client, cursor, registry, monkeypatch):
252
  class MockClient:
253
  """Mock httpx Client, return www.example.com content"""
254
 
 
282
  from gistillery.preprocessing import DefaultUrlProcessor
283
 
284
  # register url processor, put it before the default processor
285
+ registry.register_processor(DefaultUrlProcessor(), last=False)
286
  client.post(
287
  "/submit",
288
  json={
 
290
  "content": "https://en.wikipedia.org/wiki/non-existing-page",
291
  },
292
  )
293
+ self.process_jobs(registry)
294
 
295
  rows = cursor.execute("SELECT * FROM inputs").fetchall()
296
  assert len(rows) == 1