# Dialogue-state-tracking (DST) model registry: loads the clarin-knext
# T5-family DST checkpoints from the Hugging Face Hub at import time and wraps
# each one in a "text2text-generation" pipeline.
import os
from typing import Any, Dict, Tuple

from transformers import (
    Pipeline,
    T5ForConditionalGeneration,
    T5Tokenizer,
    pipeline,
)

# Hub access token taken from the CLARIN_KNEXT environment variable; this is
# None when the variable is unset and is passed to from_pretrained as-is.
auth_token = os.environ.get("CLARIN_KNEXT")

# Example model inputs keyed by language. Each input concatenates a user
# utterance ("[U]"), a domain description and a slot description; the Polish
# variant uses the "[Dziedzina]" (domain) and "[Atrybut]" (slot) markers.
DEFAULT_DST_INPUTS: Dict[str, str] = {
    "polish": (
        "[U] Chciałbym zarezerwować stolik na 4 osoby na piątek o godzinie 18:30. "
        "[Dziedzina] Restauracje: Popularna usługa wyszukiwania i rezerwacji restauracji "
        "[Atrybut] Czas: Wstępny czas rezerwacji restauracji"
    ),
    "english": (
        "[U] I want to book a table for 4 people on Friday, 6:30 pm. "
        "[Domain] Restaurants: A popular restaurant search and reservation service "
        "[Slot] Time: Tentative time of restaurant reservation"
    ),
}

# Single source of truth for the registry: each row is
# (display name, Hub repository id, DEFAULT_DST_INPUTS key).
# The model and its tokenizer always come from the same repository, so the
# repo id is written once per checkpoint. Row order fixes the order in which
# the models are loaded and the key order of DST_MODELS / PIPELINES.
_DST_CHECKPOINTS: Tuple[Tuple[str, str, str], ...] = (
    ("plt5-small", "clarin-knext/plt5-small-dst", "polish"),
    ("plt5-base", "clarin-knext/plt5-base-dst", "polish"),
    ("plt5-base-poquad-dst-v2", "clarin-knext/plt5-base-poquad-dst-v2", "polish"),
    ("t5-small", "clarin-knext/t5-small-dst", "english"),
    ("t5-base", "clarin-knext/t5-base-dst", "english"),
    ("flant5-small [EN/PL]", "clarin-knext/flant5-small-dst", "english"),
    ("flant5-base [EN/PL]", "clarin-knext/flant5-base-dst", "english"),
)


def _load_dst_model(repo_id: str, language: str) -> Dict[str, Any]:
    """Load one DST checkpoint and its tokenizer from the Hugging Face Hub.

    Args:
        repo_id: Hub repository id used for both the model and the tokenizer.
        language: Key into ``DEFAULT_DST_INPUTS`` selecting the example input.

    Returns:
        A dict with ``"model"``, ``"tokenizer"`` and ``"default_input"`` keys.
        The model is loaded before the tokenizer.
    """
    # NOTE(review): recent transformers releases deprecate ``use_auth_token``
    # in favour of ``token``. Confirm the targeted transformers version before
    # switching the keyword.
    return {
        "model": T5ForConditionalGeneration.from_pretrained(
            repo_id, use_auth_token=auth_token
        ),
        "tokenizer": T5Tokenizer.from_pretrained(repo_id, use_auth_token=auth_token),
        "default_input": DEFAULT_DST_INPUTS[language],
    }


# Display name -> {"model", "tokenizer", "default_input"}; every checkpoint is
# downloaded/loaded eagerly when this module is imported.
DST_MODELS: Dict[str, Dict[str, Any]] = {
    name: _load_dst_model(repo_id, language)
    for name, repo_id, language in _DST_CHECKPOINTS
}

# Display name -> ready-to-use text2text-generation pipeline built from the
# already-loaded model and tokenizer (no second download).
PIPELINES: Dict[str, Pipeline] = {
    model_name: pipeline(
        "text2text-generation",
        model=entry["model"],
        tokenizer=entry["tokenizer"],
    )
    for model_name, entry in DST_MODELS.items()
}