juanpablo4l committed on
Commit
d0fad25
1 Parent(s): 857aac4

Added NLG models

Browse files
Files changed (2) hide show
  1. app.py +17 -3
  2. models.py +57 -12
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import gradio as gr
2
 
3
- from models import MODELS, PIPELINES
4
 
5
 
6
  def predict(text: str, model_name: str) -> str:
@@ -9,16 +9,30 @@ def predict(text: str, model_name: str) -> str:
9
 
10
  with gr.Blocks(title="CLARIN-PL Dialogue System Modules") as demo:
11
  gr.Markdown("Dialogue State Tracking Modules")
12
- for model_name in MODELS:
13
  with gr.Row():
14
  gr.Markdown(f"## {model_name}")
15
  model_name_component = gr.Textbox(value=model_name, visible=False)
16
  with gr.Row():
17
- text_input = gr.Textbox(label="Input Text", value=MODELS[model_name]["default_input"])
18
  output = gr.Textbox(label="Slot Value", value="")
19
  with gr.Row():
20
  predict_button = gr.Button("Predict")
21
  predict_button.click(fn=predict, inputs=[text_input, model_name_component], outputs=output)
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  demo.queue(concurrency_count=3)
24
  demo.launch()
 
1
  import gradio as gr
2
 
3
+ from models import DST_MODELS, NLG_MODELS, PIPELINES
4
 
5
 
6
  def predict(text: str, model_name: str) -> str:
 
9
 
10
with gr.Blocks(title="CLARIN-PL Dialogue System Modules") as demo:

    def _model_section(model_name: str, default_input: str, output_label: str) -> None:
        """Render one model card: title, hidden model-name carrier, input/output
        textboxes and a Predict button wired to ``predict``.

        The hidden textbox smuggles the model name into ``predict`` as a regular
        Gradio input, so one callback serves every model.
        """
        with gr.Row():
            gr.Markdown(f"## {model_name}")
            model_name_component = gr.Textbox(value=model_name, visible=False)
        with gr.Row():
            text_input = gr.Textbox(label="Input Text", value=default_input)
            output = gr.Textbox(label=output_label, value="")
        with gr.Row():
            predict_button = gr.Button("Predict")
        predict_button.click(fn=predict, inputs=[text_input, model_name_component], outputs=output)

    gr.Markdown("Dialogue State Tracking Modules")
    for model_name in DST_MODELS:
        _model_section(model_name, DST_MODELS[model_name]["default_input"], "Slot Value")

    gr.Markdown("Natural Language Generation / Paraphrasing Modules")
    for model_name in NLG_MODELS:
        # BUG FIX: this section's output box was labelled "Slot Value" (copied from
        # the DST section); NLG models emit rewritten/paraphrased text, not slots.
        _model_section(model_name, NLG_MODELS[model_name]["default_input"], "Generated Text")

demo.queue(concurrency_count=3)
demo.launch()
models.py CHANGED
@@ -2,11 +2,12 @@ import os
2
  from typing import Any, Dict
3
 
4
  from transformers import (Pipeline, T5ForConditionalGeneration, T5Tokenizer,
5
- pipeline)
6
 
7
  auth_token = os.environ.get("CLARIN_KNEXT")
8
 
9
- DEFAULT_INPUTS: Dict[str, str] = {
 
10
  "polish": (
11
  "[U] Chciałbym zarezerwować stolik na 4 osoby na piątek o godzinie 18:30. "
12
  "[Dziedzina] Restauracje: Popularna usługa wyszukiwania i rezerwacji restauracji "
@@ -19,47 +20,91 @@ DEFAULT_INPUTS: Dict[str, str] = {
19
  ),
20
  }
21
 
22
- MODELS: Dict[str, Dict[str, Any]] = {
 
23
  "plt5-small": {
24
  "model": T5ForConditionalGeneration.from_pretrained("clarin-knext/plt5-small-dst", use_auth_token=auth_token),
25
  "tokenizer": T5Tokenizer.from_pretrained("clarin-knext/plt5-small-dst", use_auth_token=auth_token),
26
- "default_input": DEFAULT_INPUTS["polish"],
27
  },
28
  "plt5-base": {
29
  "model": T5ForConditionalGeneration.from_pretrained("clarin-knext/plt5-base-dst", use_auth_token=auth_token),
30
  "tokenizer": T5Tokenizer.from_pretrained("clarin-knext/plt5-base-dst", use_auth_token=auth_token),
31
- "default_input": DEFAULT_INPUTS["polish"],
32
  },
33
  "plt5-base-poquad-dst-v2": {
34
  "model": T5ForConditionalGeneration.from_pretrained("clarin-knext/plt5-base-poquad-dst-v2", use_auth_token=auth_token),
35
  "tokenizer": T5Tokenizer.from_pretrained("clarin-knext/plt5-base-poquad-dst-v2", use_auth_token=auth_token),
36
- "default_input": DEFAULT_INPUTS["polish"],
37
  },
38
  "t5-small": {
39
  "model": T5ForConditionalGeneration.from_pretrained("clarin-knext/t5-small-dst", use_auth_token=auth_token),
40
  "tokenizer": T5Tokenizer.from_pretrained("clarin-knext/t5-small-dst", use_auth_token=auth_token),
41
- "default_input": DEFAULT_INPUTS["english"],
42
  },
43
  "t5-base": {
44
  "model": T5ForConditionalGeneration.from_pretrained("clarin-knext/t5-base-dst", use_auth_token=auth_token),
45
  "tokenizer": T5Tokenizer.from_pretrained("clarin-knext/t5-base-dst", use_auth_token=auth_token),
46
- "default_input": DEFAULT_INPUTS["english"],
47
  },
48
  "flant5-small [EN/PL]": {
49
  "model": T5ForConditionalGeneration.from_pretrained("clarin-knext/flant5-small-dst", use_auth_token=auth_token),
50
  "tokenizer": T5Tokenizer.from_pretrained("clarin-knext/flant5-small-dst", use_auth_token=auth_token),
51
- "default_input": DEFAULT_INPUTS["english"],
52
  },
53
  "flant5-base [EN/PL]": {
54
  "model": T5ForConditionalGeneration.from_pretrained("clarin-knext/flant5-base-dst", use_auth_token=auth_token),
55
  "tokenizer": T5Tokenizer.from_pretrained("clarin-knext/flant5-base-dst", use_auth_token=auth_token),
56
- "default_input": DEFAULT_INPUTS["english"],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  },
58
  }
59
 
60
  PIPELINES: Dict[str, Pipeline] = {
61
  model_name: pipeline(
62
- "text2text-generation", model=MODELS[model_name]["model"], tokenizer=MODELS[model_name]["tokenizer"]
63
  )
64
- for model_name in MODELS
65
  }
 
2
  from typing import Any, Dict
3
 
4
  from transformers import (Pipeline, T5ForConditionalGeneration, T5Tokenizer,
5
+ pipeline, AutoModelForSeq2SeqLM, AutoModelForCausalLM, AutoTokenizer)
6
 
7
  auth_token = os.environ.get("CLARIN_KNEXT")
8
 
9
+
10
+ DEFAULT_DST_INPUTS: Dict[str, str] = {
11
  "polish": (
12
  "[U] Chciałbym zarezerwować stolik na 4 osoby na piątek o godzinie 18:30. "
13
  "[Dziedzina] Restauracje: Popularna usługa wyszukiwania i rezerwacji restauracji "
 
20
  ),
21
  }
22
 
23
+
24
def _dst_entry(repo: str, language: str) -> Dict[str, Any]:
    """Build one DST model entry from the Hub repo *repo*.

    Returns a dict with the loaded T5 model, its tokenizer, and the
    language-appropriate example input shown in the demo UI.
    """
    return {
        "model": T5ForConditionalGeneration.from_pretrained(repo, use_auth_token=auth_token),
        "tokenizer": T5Tokenizer.from_pretrained(repo, use_auth_token=auth_token),
        "default_input": DEFAULT_DST_INPUTS[language],
    }


# Dialogue State Tracking checkpoints; keys are the names shown in the UI.
DST_MODELS: Dict[str, Dict[str, Any]] = {
    "plt5-small": _dst_entry("clarin-knext/plt5-small-dst", "polish"),
    "plt5-base": _dst_entry("clarin-knext/plt5-base-dst", "polish"),
    "plt5-base-poquad-dst-v2": _dst_entry("clarin-knext/plt5-base-poquad-dst-v2", "polish"),
    "t5-small": _dst_entry("clarin-knext/t5-small-dst", "english"),
    "t5-base": _dst_entry("clarin-knext/t5-base-dst", "english"),
    "flant5-small [EN/PL]": _dst_entry("clarin-knext/flant5-small-dst", "english"),
    "flant5-base [EN/PL]": _dst_entry("clarin-knext/flant5-base-dst", "english"),
}
61
+
62
+
63
# Example inputs for the NLG / paraphrasing demos. The decoder-only (GPT-2 style)
# checkpoints expect the same text wrapped in [BOS]...[SEP] prompt markers.
DEFAULT_ENCODER_DECODER_INPUT_EN = "The alarm is set for 6 am. The alarm's name is name \"Get up\"."
DEFAULT_DECODER_ONLY_INPUT_EN = f"[BOS]{DEFAULT_ENCODER_DECODER_INPUT_EN}[SEP]"
DEFAULT_ENCODER_DECODER_INPUT_PL = "Alarm jest o godzinie 6 rano. Alarm ma nazwę \"Obudź się\"."
DEFAULT_DECODER_ONLY_INPUT_PL = f"[BOS]{DEFAULT_ENCODER_DECODER_INPUT_PL}[SEP]"


def _nlg_entry(repo: str, default_input: str, *, causal: bool = False) -> Dict[str, Any]:
    """Build one NLG model entry from the Hub repo *repo*.

    *causal* selects the decoder-only loader (GPT-2 style checkpoints);
    otherwise an encoder-decoder (seq2seq) loader is used.
    """
    loader = AutoModelForCausalLM if causal else AutoModelForSeq2SeqLM
    return {
        "model": loader.from_pretrained(repo, use_auth_token=auth_token),
        "tokenizer": AutoTokenizer.from_pretrained(repo, use_auth_token=auth_token),
        "default_input": default_input,
    }


# Utterance-rewriting / paraphrasing checkpoints; keys are the names shown in the UI.
NLG_MODELS: Dict[str, Dict[str, Any]] = {
    # English
    "t5-large": _nlg_entry("clarin-knext/utterance-rewriting-t5-large", DEFAULT_ENCODER_DECODER_INPUT_EN),
    "en-mt5-large": _nlg_entry("clarin-knext/utterance-rewriting-en-mt5-large", DEFAULT_ENCODER_DECODER_INPUT_EN),
    "gpt2": _nlg_entry("clarin-knext/utterance-rewriting-gpt2", DEFAULT_DECODER_ONLY_INPUT_EN, causal=True),
    # Polish
    "pt5-large": _nlg_entry("clarin-knext/utterance-rewriting-pt5-large", DEFAULT_ENCODER_DECODER_INPUT_PL),
    "pl-mt5-large": _nlg_entry("clarin-knext/utterance-rewriting-pl-mt5-large", DEFAULT_ENCODER_DECODER_INPUT_PL),
    "polish-gpt2": _nlg_entry("clarin-knext/utterance-rewriting-polish-gpt2", DEFAULT_DECODER_ONLY_INPUT_PL, causal=True),
}
104
 
105
# One ready-to-call pipeline per UI model name.
# BUG FIX: this commit added NLG_MODELS but still built pipelines only for
# DST_MODELS, so selecting any NLG model raised KeyError at predict time.
# Also, the task was hard-coded to "text2text-generation", which is wrong for
# the causal (GPT-2 style) checkpoints — choose the task from the model config.
PIPELINES: Dict[str, Pipeline] = {
    model_name: pipeline(
        "text2text-generation" if spec["model"].config.is_encoder_decoder else "text-generation",
        model=spec["model"],
        tokenizer=spec["tokenizer"],
    )
    for model_name, spec in {**DST_MODELS, **NLG_MODELS}.items()
}