michalilski commited on
Commit
00a31fe
1 Parent(s): 5194cc9

plt5 support

Browse files
Files changed (3) hide show
  1. app.py +21 -0
  2. models.py +40 -0
  3. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ from models import MODELS, PIPELINES
4
+
5
+
6
+ def predict(text: str, model_name: str) -> str:
7
+ return {"text": PIPELINES[model_name](text)}
8
+
9
+
10
+ with gr.Blocks(title="CLARIN-PL Dialogue System Modules") as demo:
11
+ gr.Markdown("Dialogue State Tracking Modules")
12
+ for model_name in MODELS:
13
+ with gr.Row():
14
+ gr.Markdown(f"## {model_name}")
15
+ model_name_component = gr.Textbox(value=model_name, visible=False)
16
+ with gr.Row():
17
+ text_input = gr.Textbox(label="Input Text", value=MODELS[model_name]["default_input"])
18
+ gr.Interface(fn=predict, inputs=[text_input, model_name_component], outputs="text")
19
+
20
+ demo.queue(concurrency_count=3)
21
+ demo.launch()
models.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Any, Dict
3
+
4
+ from transformers import (Pipeline, T5ForConditionalGeneration, T5Tokenizer,
5
+ pipeline)
6
+
7
+ auth_token = os.environ.get("CLARIN_KNEXT")
8
+
9
+ DEFAULT_INPUTS: Dict[str, str] = {
10
+ "polish": (
11
+ "[U] Chciałbym zarezerwować stolik na 4 osoby na piątek o godzinie 18.30. "
12
+ "[Dziedzina] Restauracje: Popularna usługa wyszukiwania i rezerwacji restauracji "
13
+ "[Atrybut] Czas: Wstępny czas rezerwacji restauracji"
14
+ ),
15
+ "english": (
16
+ "[U] I want to book a table for 4 people on Friday, 6.30 pm. "
17
+ "[Domain] Restaurants: A popular restaurant search and reservation service "
18
+ "[Slot] Time: Tentative time of restaurant reservation"
19
+ ),
20
+ }
21
+
22
+ MODELS: Dict[str, Dict[str, Any]] = {
23
+ "plt5-small": {
24
+ "model": T5ForConditionalGeneration.from_pretrained("clarin-knext/plt5-small-dst", use_auth_token=auth_token),
25
+ "tokenizer": T5Tokenizer.from_pretrained("clarin-knext/plt5-small-dst", use_auth_token=auth_token),
26
+ "default_input": DEFAULT_INPUTS["polish"],
27
+ },
28
+ "plt5-base": {
29
+ "model": T5ForConditionalGeneration.from_pretrained("clarin-knext/plt5-base-dst", use_auth_token=auth_token),
30
+ "tokenizer": T5Tokenizer.from_pretrained("clarin-knext/plt5-base-dst", use_auth_token=auth_token),
31
+ "default_input": DEFAULT_INPUTS["polish"],
32
+ },
33
+ }
34
+
35
+ PIPELINES: Dict[str, Pipeline] = {
36
+ model_name: pipeline(
37
+ "text2text-generation", model=MODELS[model_name]["model"], tokenizer=MODELS[model_name]["tokenizer"]
38
+ )
39
+ for model_name in MODELS
40
+ }
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ torch
2
+ transformers
3
+ sentencepiece