[#9] idiomifier:m-1-3 is ready. main_deploy.py is updated accordingly
Browse files
- explore/explore_bart_tokenizer_decode_idiom_special_tokens.py +14 -0
- idiomify/pipeline.py +6 -1
- main_deploy.py +9 -6
- main_eval.py +2 -2
- main_infer.py +5 -6
explore/explore_bart_tokenizer_decode_idiom_special_tokens.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from idiomify.fetchers import fetch_tokenizer
|
2 |
+
|
3 |
+
|
4 |
+
def main():
|
5 |
+
tokenizer = fetch_tokenizer("t-1-1")
|
6 |
+
sent = "There will always be a <idiom> silver lining </idiom> even when things look pitch black"
|
7 |
+
ids = tokenizer(sent)['input_ids']
|
8 |
+
print(ids)
|
9 |
+
decoded = tokenizer.decode(ids)
|
10 |
+
print(decoded)
|
11 |
+
|
12 |
+
|
13 |
+
if __name__ == '__main__':
|
14 |
+
main()
|
idiomify/pipeline.py
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
from typing import List
|
2 |
from transformers import BartTokenizer
|
3 |
from idiomify.builders import SourcesBuilder
|
@@ -18,5 +19,9 @@ class Pipeline:
|
|
18 |
decoder_start_token_id=self.model.hparams['bos_token_id'],
|
19 |
max_length=max_length,
|
20 |
) # -> (N, L_t)
|
21 |
-
tgts = self.builder.tokenizer.batch_decode(pred_ids, skip_special_tokens=
|
|
|
|
|
|
|
|
|
22 |
return tgts
|
|
|
1 |
+
import re
|
2 |
from typing import List
|
3 |
from transformers import BartTokenizer
|
4 |
from idiomify.builders import SourcesBuilder
|
|
|
19 |
decoder_start_token_id=self.model.hparams['bos_token_id'],
|
20 |
max_length=max_length,
|
21 |
) # -> (N, L_t)
|
22 |
+
tgts = self.builder.tokenizer.batch_decode(pred_ids, skip_special_tokens=False)
|
23 |
+
tgts = [
|
24 |
+
re.sub(r"<s>|</s>", "", tgt)
|
25 |
+
for tgt in tgts
|
26 |
+
]
|
27 |
return tgts
|
main_deploy.py
CHANGED
@@ -1,9 +1,9 @@
|
|
1 |
"""
|
2 |
we deploy the pipeline via streamlit.
|
3 |
"""
|
|
|
4 |
import streamlit as st
|
5 |
-
from
|
6 |
-
from idiomify.fetchers import fetch_config, fetch_idiomifier, fetch_idioms
|
7 |
from idiomify.pipeline import Pipeline
|
8 |
|
9 |
|
@@ -11,7 +11,7 @@ from idiomify.pipeline import Pipeline
|
|
11 |
def fetch_resources() -> tuple:
|
12 |
config = fetch_config()['idiomifier']
|
13 |
model = fetch_idiomifier(config['ver'])
|
14 |
-
tokenizer =
|
15 |
idioms = fetch_idioms(config['idioms_ver'])
|
16 |
return config, model, tokenizer, idioms
|
17 |
|
@@ -23,17 +23,20 @@ def main():
|
|
23 |
pipeline = Pipeline(model, tokenizer)
|
24 |
st.title("Idiomify Demo")
|
25 |
text = st.text_area("Type sentences here",
|
26 |
-
value="Just remember there will always be a hope even when things look
|
27 |
with st.sidebar:
|
28 |
st.subheader("Supported idioms")
|
|
|
29 |
st.write(" / ".join(idioms))
|
30 |
|
31 |
if st.button(label="Idiomify"):
|
32 |
with st.spinner("Please wait..."):
|
33 |
sents = [sent for sent in text.split(".") if sent]
|
34 |
-
|
35 |
# highlight the rule & honorifics that were applied
|
36 |
-
|
|
|
|
|
37 |
|
38 |
|
39 |
if __name__ == '__main__':
|
|
|
1 |
"""
|
2 |
we deploy the pipeline via streamlit.
|
3 |
"""
|
4 |
+
import re
|
5 |
import streamlit as st
|
6 |
+
from idiomify.fetchers import fetch_config, fetch_idiomifier, fetch_idioms, fetch_tokenizer
|
|
|
7 |
from idiomify.pipeline import Pipeline
|
8 |
|
9 |
|
|
|
11 |
def fetch_resources() -> tuple:
|
12 |
config = fetch_config()['idiomifier']
|
13 |
model = fetch_idiomifier(config['ver'])
|
14 |
+
tokenizer = fetch_tokenizer(config['tokenizer_ver'])
|
15 |
idioms = fetch_idioms(config['idioms_ver'])
|
16 |
return config, model, tokenizer, idioms
|
17 |
|
|
|
23 |
pipeline = Pipeline(model, tokenizer)
|
24 |
st.title("Idiomify Demo")
|
25 |
text = st.text_area("Type sentences here",
|
26 |
+
value="Just remember that there will always be a hope even when things look hopeless")
|
27 |
with st.sidebar:
|
28 |
st.subheader("Supported idioms")
|
29 |
+
idioms = [row["Idiom"] for _, row in idioms.iterrows()]
|
30 |
st.write(" / ".join(idioms))
|
31 |
|
32 |
if st.button(label="Idiomify"):
|
33 |
with st.spinner("Please wait..."):
|
34 |
sents = [sent for sent in text.split(".") if sent]
|
35 |
+
preds = pipeline(sents, max_length=200)
|
36 |
# highlight the rule & honorifics that were applied
|
37 |
+
preds = [re.sub(r"<idiom>|</idiom>", "`", pred)
|
38 |
+
for pred in preds]
|
39 |
+
st.markdown(". ".join(preds))
|
40 |
|
41 |
|
42 |
if __name__ == '__main__':
|
main_eval.py
CHANGED
@@ -6,7 +6,7 @@ import pytorch_lightning as pl
|
|
6 |
from pytorch_lightning.loggers import WandbLogger
|
7 |
from transformers import BartTokenizer
|
8 |
from idiomify.datamodules import IdiomifyDataModule
|
9 |
-
from idiomify.fetchers import fetch_config, fetch_idiomifier
|
10 |
from idiomify.paths import ROOT_DIR
|
11 |
|
12 |
|
@@ -17,10 +17,10 @@ def main():
|
|
17 |
args = parser.parse_args()
|
18 |
config = fetch_config()['idiomifier']
|
19 |
config.update(vars(args))
|
20 |
-
tokenizer = BartTokenizer.from_pretrained(config['bart'])
|
21 |
# prepare the datamodule
|
22 |
with wandb.init(entity="eubinecto", project="idiomify", config=config) as run:
|
23 |
model = fetch_idiomifier(config['ver'], run) # fetch a pre-trained model
|
|
|
24 |
datamodule = IdiomifyDataModule(config, tokenizer, run)
|
25 |
logger = WandbLogger(log_model=False)
|
26 |
trainer = pl.Trainer(fast_dev_run=config['fast_dev_run'],
|
|
|
6 |
from pytorch_lightning.loggers import WandbLogger
|
7 |
from transformers import BartTokenizer
|
8 |
from idiomify.datamodules import IdiomifyDataModule
|
9 |
+
from idiomify.fetchers import fetch_config, fetch_idiomifier, fetch_tokenizer
|
10 |
from idiomify.paths import ROOT_DIR
|
11 |
|
12 |
|
|
|
17 |
args = parser.parse_args()
|
18 |
config = fetch_config()['idiomifier']
|
19 |
config.update(vars(args))
|
|
|
20 |
# prepare the datamodule
|
21 |
with wandb.init(entity="eubinecto", project="idiomify", config=config) as run:
|
22 |
model = fetch_idiomifier(config['ver'], run) # fetch a pre-trained model
|
23 |
+
tokenizer = fetch_tokenizer(config['tokenizer_ver'], run)
|
24 |
datamodule = IdiomifyDataModule(config, tokenizer, run)
|
25 |
logger = WandbLogger(log_model=False)
|
26 |
trainer = pl.Trainer(fast_dev_run=config['fast_dev_run'],
|
main_infer.py
CHANGED
@@ -3,25 +3,24 @@ This is for just a simple sanity check on the inference.
|
|
3 |
"""
|
4 |
import argparse
|
5 |
from idiomify.pipeline import Pipeline
|
6 |
-
from idiomify.fetchers import fetch_config, fetch_idiomifier
|
7 |
from transformers import BartTokenizer
|
8 |
|
9 |
|
10 |
def main():
|
11 |
parser = argparse.ArgumentParser()
|
12 |
parser.add_argument("--sent", type=str,
|
13 |
-
default="
|
14 |
-
" it's that I'll now be able to go to school full-time and finish my degree earlier.")
|
15 |
args = parser.parse_args()
|
16 |
config = fetch_config()['idiomifier']
|
17 |
config.update(vars(args))
|
18 |
model = fetch_idiomifier(config['ver'])
|
|
|
19 |
model.eval() # this is crucial
|
20 |
-
tokenizer = BartTokenizer.from_pretrained(config['bart'])
|
21 |
pipeline = Pipeline(model, tokenizer)
|
22 |
src = config['sent']
|
23 |
-
|
24 |
-
print(src, "\n->",
|
25 |
|
26 |
|
27 |
if __name__ == '__main__':
|
|
|
3 |
"""
|
4 |
import argparse
|
5 |
from idiomify.pipeline import Pipeline
|
6 |
+
from idiomify.fetchers import fetch_config, fetch_idiomifier, fetch_tokenizer
|
7 |
from transformers import BartTokenizer
|
8 |
|
9 |
|
10 |
def main():
|
11 |
parser = argparse.ArgumentParser()
|
12 |
parser.add_argument("--sent", type=str,
|
13 |
+
default="Just remember that there will always be a hope even when things look hopeless")
|
|
|
14 |
args = parser.parse_args()
|
15 |
config = fetch_config()['idiomifier']
|
16 |
config.update(vars(args))
|
17 |
model = fetch_idiomifier(config['ver'])
|
18 |
+
tokenizer = fetch_tokenizer(config['tokenizer_ver'])
|
19 |
model.eval() # this is crucial
|
|
|
20 |
pipeline = Pipeline(model, tokenizer)
|
21 |
src = config['sent']
|
22 |
+
tgts = pipeline(sents=[src])
|
23 |
+
print(src, "\n->", tgts[0])
|
24 |
|
25 |
|
26 |
if __name__ == '__main__':
|