michalilski committed on
Commit
dc7ce01
1 Parent(s): 6ef6e5f

NLG models removal

Browse files
Files changed (2) hide show
  1. app.py +2 -16
  2. models.py +1 -43
app.py CHANGED
@@ -1,14 +1,13 @@
1
  import gradio as gr
2
 
3
- from models import DST_MODELS, NLG_MODELS, PIPELINES
4
 
5
 
6
  def predict(text: str, model_name: str) -> str:
7
  return PIPELINES[model_name](text)
8
 
9
 
10
- with gr.Blocks(title="CLARIN-PL Dialogue System Modules") as demo:
11
- gr.Markdown("Dialogue State Tracking Modules")
12
  for model_name in DST_MODELS:
13
  with gr.Row():
14
  gr.Markdown(f"## {model_name}")
@@ -21,18 +20,5 @@ with gr.Blocks(title="CLARIN-PL Dialogue System Modules") as demo:
21
  predict_button.click(fn=predict, inputs=[text_input, model_name_component], outputs=output)
22
 
23
 
24
- gr.Markdown("Natural Language Generation / Paraphrasing Modules")
25
- for model_name in NLG_MODELS:
26
- with gr.Row():
27
- gr.Markdown(f"## {model_name}")
28
- model_name_component = gr.Textbox(value=model_name, visible=False)
29
- with gr.Row():
30
- text_input = gr.Textbox(label="Input Text", value=NLG_MODELS[model_name]["default_input"])
31
- output = gr.Textbox(label="Slot Value", value="")
32
- with gr.Row():
33
- predict_button = gr.Button("Predict")
34
- predict_button.click(fn=predict, inputs=[text_input, model_name_component], outputs=output)
35
-
36
-
37
  demo.queue(concurrency_count=3)
38
  demo.launch()
 
1
  import gradio as gr
2
 
3
+ from models import DST_MODELS, PIPELINES
4
 
5
 
6
  def predict(text: str, model_name: str) -> str:
7
  return PIPELINES[model_name](text)
8
 
9
 
10
+ with gr.Blocks(title="CLARIN-PL DST Modules") as demo:
 
11
  for model_name in DST_MODELS:
12
  with gr.Row():
13
  gr.Markdown(f"## {model_name}")
 
20
  predict_button.click(fn=predict, inputs=[text_input, model_name_component], outputs=output)
21
 
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  demo.queue(concurrency_count=3)
24
  demo.launch()
models.py CHANGED
@@ -2,7 +2,7 @@ import os
2
  from typing import Any, Dict
3
 
4
  from transformers import (Pipeline, T5ForConditionalGeneration, T5Tokenizer,
5
- pipeline, AutoModelForSeq2SeqLM, AutoModelForCausalLM, AutoTokenizer)
6
 
7
  auth_token = os.environ.get("CLARIN_KNEXT")
8
 
@@ -60,48 +60,6 @@ DST_MODELS: Dict[str, Dict[str, Any]] = {
60
  }
61
 
62
 
63
- DEFAULT_ENCODER_DECODER_INPUT_EN = "The alarm is set for 6 am. The alarm's name is name \"Get up\"."
64
- DEFAULT_DECODER_ONLY_INPUT_EN = f"[BOS]{DEFAULT_ENCODER_DECODER_INPUT_EN}[SEP]"
65
- DEFAULT_ENCODER_DECODER_INPUT_PL = "Alarm jest o godzinie 6 rano. Alarm ma nazwę \"Obudź się\"."
66
- DEFAULT_DECODER_ONLY_INPUT_PL = f"[BOS]{DEFAULT_ENCODER_DECODER_INPUT_PL}[SEP]"
67
-
68
-
69
-
70
- NLG_MODELS: Dict[str, Dict[str, Any]] = {
71
- # English
72
- "t5-large": {
73
- "model": AutoModelForSeq2SeqLM.from_pretrained("clarin-knext/utterance-rewriting-t5-large", use_auth_token=auth_token),
74
- "tokenizer": AutoTokenizer.from_pretrained("clarin-knext/utterance-rewriting-t5-large", use_auth_token=auth_token),
75
- "default_input": DEFAULT_ENCODER_DECODER_INPUT_EN,
76
- },
77
- "en-mt5-large": {
78
- "model": AutoModelForSeq2SeqLM.from_pretrained("clarin-knext/utterance-rewriting-en-mt5-large", use_auth_token=auth_token),
79
- "tokenizer": AutoTokenizer.from_pretrained("clarin-knext/utterance-rewriting-en-mt5-large", use_auth_token=auth_token),
80
- "default_input": DEFAULT_ENCODER_DECODER_INPUT_EN,
81
- },
82
- "gpt2": {
83
- "model": AutoModelForCausalLM.from_pretrained("clarin-knext/utterance-rewriting-gpt2", use_auth_token=auth_token),
84
- "tokenizer": AutoTokenizer.from_pretrained("clarin-knext/utterance-rewriting-gpt2", use_auth_token=auth_token),
85
- "default_input": DEFAULT_DECODER_ONLY_INPUT_EN,
86
- },
87
-
88
- "pt5-large": {
89
- "model": AutoModelForSeq2SeqLM.from_pretrained("clarin-knext/utterance-rewriting-pt5-large", use_auth_token=auth_token),
90
- "tokenizer": AutoTokenizer.from_pretrained("clarin-knext/utterance-rewriting-pt5-large", use_auth_token=auth_token),
91
- "default_input": DEFAULT_ENCODER_DECODER_INPUT_PL,
92
- },
93
- "pl-mt5-large": {
94
- "model": AutoModelForSeq2SeqLM.from_pretrained("clarin-knext/utterance-rewriting-pl-mt5-large", use_auth_token=auth_token),
95
- "tokenizer": AutoTokenizer.from_pretrained("clarin-knext/utterance-rewriting-pl-mt5-large", use_auth_token=auth_token),
96
- "default_input": DEFAULT_ENCODER_DECODER_INPUT_PL,
97
- },
98
- "polish-gpt2": {
99
- "model": AutoModelForCausalLM.from_pretrained("clarin-knext/utterance-rewriting-polish-gpt2", use_auth_token=auth_token),
100
- "tokenizer": AutoTokenizer.from_pretrained("clarin-knext/utterance-rewriting-polish-gpt2", use_auth_token=auth_token),
101
- "default_input": DEFAULT_DECODER_ONLY_INPUT_PL,
102
- },
103
- }
104
-
105
  PIPELINES: Dict[str, Pipeline] = {
106
  model_name: pipeline(
107
  "text2text-generation", model=DST_MODELS[model_name]["model"], tokenizer=DST_MODELS[model_name]["tokenizer"]
 
2
  from typing import Any, Dict
3
 
4
  from transformers import (Pipeline, T5ForConditionalGeneration, T5Tokenizer,
5
+ pipeline)
6
 
7
  auth_token = os.environ.get("CLARIN_KNEXT")
8
 
 
60
  }
61
 
62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  PIPELINES: Dict[str, Pipeline] = {
64
  model_name: pipeline(
65
  "text2text-generation", model=DST_MODELS[model_name]["model"], tokenizer=DST_MODELS[model_name]["tokenizer"]