eubinecto committed
Commit 47e4017
1 Parent(s): cffca27

[#9] `fetch_pipeline` has been added. Fixed the bug where `<pad>` tokens would appear in the final output.

Files changed (4)
  1. idiomify/fetchers.py +15 -0
  2. idiomify/pipeline.py +5 -2
  3. main_deploy.py +5 -11
  4. main_infer.py +4 -12
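
Why `<pad>` leaks into the output in the first place: `generate()` pads every sequence in a batch up to the longest one, and the pipeline decodes with `skip_special_tokens=False` so that the `<idiom>` markers survive, which means the pad tokens survive too. A minimal sketch of the effect, assuming a stock `facebook/bart-base` tokenizer (the checkpoint name and sentences are illustrative, not taken from this repo):

    from transformers import BartTokenizer

    tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
    # padding=True pads the shorter sentence up to the longer one,
    # just like generate() does for a batch of predictions
    batch = tokenizer(
        ["a short one", "a noticeably longer sentence than the first"],
        padding=True, return_tensors="pt",
    )
    decoded = tokenizer.batch_decode(batch["input_ids"], skip_special_tokens=False)
    print(decoded[0])  # the shorter sentence ends in a run of '<pad>' tokens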
idiomify/fetchers.py CHANGED
@@ -8,6 +8,7 @@ from idiomify.paths import CONFIG_YAML, idioms_dir, literal2idiomatic, idiomifie
 from idiomify.urls import PIE_URL
 from transformers import AutoModelForSeq2SeqLM, AutoConfig, BartTokenizer
 from idiomify.models import Idiomifier
+from idiomify.pipeline import Pipeline
 
 
 # --- from the web --- #
@@ -75,6 +76,20 @@ def fetch_tokenizer(ver: str, run: Run = None) -> BartTokenizer:
     return tokenizer
 
 
+def fetch_pipeline() -> Pipeline:
+    """
+    fetch a pipeline of the version stated in config.yaml
+    """
+    config = fetch_config()['idiomifier']
+    model = fetch_idiomifier(config['ver'])
+    tokenizer = fetch_tokenizer(config['tokenizer_ver'])
+    idioms = fetch_idioms(config['idioms_ver'])
+    model.eval()  # this is crucial to obtain consistent results
+    pipeline = Pipeline(model, tokenizer, idioms)
+    return pipeline
+
+
+# --- from local --- #
 def fetch_config() -> dict:
     with open(str(CONFIG_YAML), 'r', encoding="utf-8") as fh:
         return yaml.safe_load(fh)
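
With the new helper in place, callers no longer assemble the model, tokenizer, and idioms by hand. A minimal usage sketch (the input sentence is invented):

    from idiomify.fetchers import fetch_pipeline

    pipeline = fetch_pipeline()  # resolves versions from config.yaml, then fetches everything
    print(pipeline(sents=["He decided to face the hard truth"])[0])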
idiomify/pipeline.py CHANGED
@@ -1,4 +1,5 @@
 import re
+import pandas as pd
 from typing import List
 from transformers import BartTokenizer
 from idiomify.builders import SourcesBuilder
@@ -7,9 +8,10 @@ from idiomify.models import Idiomifier
 
 class Pipeline:
 
-    def __init__(self, model: Idiomifier, tokenizer: BartTokenizer):
+    def __init__(self, model: Idiomifier, tokenizer: BartTokenizer, idioms: pd.DataFrame):
         self.model = model
         self.builder = SourcesBuilder(tokenizer)
+        self.idioms = idioms
 
     def __call__(self, sents: List[str], max_length=100) -> List[str]:
         srcs = self.builder(literal2idiomatic=[(sent, "") for sent in sents])
@@ -19,9 +21,10 @@ class Pipeline:
             decoder_start_token_id=self.model.hparams['bos_token_id'],
             max_length=max_length,
         )  # -> (N, L_t)
+        # we don't skip special tokens because we have to keep <idiom> & </idiom> for highlighting idioms.
         tgts = self.builder.tokenizer.batch_decode(pred_ids, skip_special_tokens=False)
         tgts = [
-            re.sub(r"<s>|</s>", "", tgt)
+            re.sub(r"<s>|</s>|<pad>", "", tgt)
             for tgt in tgts
         ]
         return tgts
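
The widened regex is the actual fix: since `batch_decode` must keep special tokens to preserve `<idiom>` & `</idiom>`, the boilerplate tokens `<s>`, `</s>`, and now `<pad>` are stripped by hand instead. A self-contained illustration (the decoded string is invented):

    import re

    decoded = "<s>He decided to <idiom>bite the bullet</idiom></s><pad><pad>"
    print(re.sub(r"<s>|</s>|<pad>", "", decoded))
    # -> 'He decided to <idiom>bite the bullet</idiom>'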
main_deploy.py CHANGED
@@ -3,30 +3,24 @@ we deploy the pipeline via streamlit.
 """
 import re
 import streamlit as st
-from idiomify.fetchers import fetch_config, fetch_idiomifier, fetch_idioms, fetch_tokenizer
+from idiomify.fetchers import fetch_pipeline
 from idiomify.pipeline import Pipeline
 
 
 @st.cache(allow_output_mutation=True)
-def fetch_resources() -> tuple:
-    config = fetch_config()['idiomifier']
-    model = fetch_idiomifier(config['ver'])
-    tokenizer = fetch_tokenizer(config['tokenizer_ver'])
-    idioms = fetch_idioms(config['idioms_ver'])
-    return config, model, tokenizer, idioms
+def cache_pipeline() -> Pipeline:
+    return fetch_pipeline()
 
 
 def main():
     # fetch a pre-trained model
-    config, model, tokenizer, idioms = fetch_resources()
-    model.eval()
-    pipeline = Pipeline(model, tokenizer)
+    pipeline = cache_pipeline()
     st.title("Idiomify Demo")
     text = st.text_area("Type sentences here",
                         value="Just remember that there will always be a hope even when things look hopeless")
     with st.sidebar:
         st.subheader("Supported idioms")
-        idioms = [row["Idiom"] for _, row in idioms.iterrows()]
+        idioms = [row["Idiom"] for _, row in pipeline.idioms.iterrows()]
         st.write(" / ".join(idioms))
 
     if st.button(label="Idiomify"):
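
Note on the decorator that `cache_pipeline` sits under: Streamlit's (legacy) `st.cache` re-hashes a cached return value on each rerun to detect mutation, and a torch-backed pipeline is not reliably hashable, so `allow_output_mutation=True` disables that check and lets the app reuse one heavyweight pipeline object across reruns instead of rebuilding it.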
main_infer.py CHANGED
@@ -2,9 +2,7 @@
 This is for just a simple sanity check on the inference.
 """
 import argparse
-from idiomify.pipeline import Pipeline
-from idiomify.fetchers import fetch_config, fetch_idiomifier, fetch_tokenizer
-from transformers import BartTokenizer
+from idiomify.fetchers import fetch_pipeline
 
 
 def main():
@@ -12,15 +10,9 @@ def main():
     parser.add_argument("--sent", type=str,
                         default="Just remember that there will always be a hope even when things look hopeless")
     args = parser.parse_args()
-    config = fetch_config()['idiomifier']
-    config.update(vars(args))
-    model = fetch_idiomifier(config['ver'])
-    tokenizer = fetch_tokenizer(config['tokenizer_ver'])
-    model.eval()  # this is crucial
-    pipeline = Pipeline(model, tokenizer)
-    src = config['sent']
-    tgts = pipeline(sents=[src])
-    print(src, "\n->", tgts[0])
+    pipeline = fetch_pipeline()
+    tgts = pipeline(sents=[args.sent])
+    print(args.sent, "\n->", tgts[0])
 
 
 if __name__ == '__main__':
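
The slimmed-down script keeps the same CLI, so a sanity check still looks like `python main_infer.py --sent "..."`, printing the source sentence and then the idiomified prediction on a `->` line. Dropping `config.update(vars(args))` is safe here because `--sent` is now read directly from `args` rather than from the merged config.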