# Page banner captured together with this file (not part of the program): "Spaces: Runtime error".
import os
from typing import Any, Dict

from transformers import (
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    Pipeline,
    T5ForConditionalGeneration,
    T5Tokenizer,
    pipeline,
)
# Access token passed as `use_auth_token` to every `from_pretrained` call in
# this module. It is read from the CLARIN_KNEXT environment variable and is
# None when that variable is not set.
auth_token = os.environ.get("CLARIN_KNEXT")
# Example DST prompts, keyed by language. Each prompt is a single string made
# of three tagged parts: the user utterance ([U]), a domain description and a
# slot/attribute description. The adjacent literals are concatenated as-is, so
# the trailing space inside each fragment is significant.
DEFAULT_DST_INPUTS: Dict[str, str] = {
    "polish": (
        "[U] Chciałbym zarezerwować stolik na 4 osoby na piątek o godzinie 18:30. "
        "[Dziedzina] Restauracje: Popularna usługa wyszukiwania i rezerwacji restauracji "
        "[Atrybut] Czas: Wstępny czas rezerwacji restauracji"
    ),
    "english": (
        "[U] I want to book a table for 4 people on Friday, 6:30 pm. "
        "[Domain] Restaurants: A popular restaurant search and reservation service "
        "[Slot] Time: Tentative time of restaurant reservation"
    ),
}
# One row per DST checkpoint: (display name used as the DST_MODELS key,
# Hub repository id, key into DEFAULT_DST_INPUTS for the example prompt).
# Row order is the insertion order of DST_MODELS.
_DST_CHECKPOINTS = (
    ("plt5-small", "clarin-knext/plt5-small-dst", "polish"),
    ("plt5-base", "clarin-knext/plt5-base-dst", "polish"),
    ("plt5-base-poquad-dst-v2", "clarin-knext/plt5-base-poquad-dst-v2", "polish"),
    ("t5-small", "clarin-knext/t5-small-dst", "english"),
    ("t5-base", "clarin-knext/t5-base-dst", "english"),
    ("flant5-small [EN/PL]", "clarin-knext/flant5-small-dst", "english"),
    ("flant5-base [EN/PL]", "clarin-knext/flant5-base-dst", "english"),
)


def _load_dst_entry(repo_id: str, language: str) -> Dict[str, Any]:
    """Load one DST checkpoint and pair it with its example prompt.

    Args:
        repo_id: Repository id handed to both ``from_pretrained`` calls.
        language: Key into ``DEFAULT_DST_INPUTS`` selecting the example prompt.

    Returns:
        A dict with the keys ``"model"``, ``"tokenizer"`` and ``"default_input"``.
    """
    # NOTE(review): recent transformers releases prefer `token=` over
    # `use_auth_token=`; kept unchanged here -- confirm against the pinned version.
    return {
        "model": T5ForConditionalGeneration.from_pretrained(repo_id, use_auth_token=auth_token),
        "tokenizer": T5Tokenizer.from_pretrained(repo_id, use_auth_token=auth_token),
        "default_input": DEFAULT_DST_INPUTS[language],
    }


# Every checkpoint is loaded eagerly while this module is being imported.
DST_MODELS: Dict[str, Dict[str, Any]] = {
    display_name: _load_dst_entry(repo_id, language)
    for display_name, repo_id, language in _DST_CHECKPOINTS
}
# Example inputs for the NLG models. The decoder-only variants wrap the same
# sentence in literal "[BOS]" / "[SEP]" markers.
# NOTE(review): "name is name" in the English sentence looks like a typo; the
# text is left untouched because it is a runtime string -- confirm before editing.
DEFAULT_ENCODER_DECODER_INPUT_EN = "The alarm is set for 6 am. The alarm's name is name \"Get up\"."
DEFAULT_DECODER_ONLY_INPUT_EN = f"[BOS]{DEFAULT_ENCODER_DECODER_INPUT_EN}[SEP]"

DEFAULT_ENCODER_DECODER_INPUT_PL = "Alarm jest o godzinie 6 rano. Alarm ma nazwę \"Obudź się\"."
DEFAULT_DECODER_ONLY_INPUT_PL = f"[BOS]{DEFAULT_ENCODER_DECODER_INPUT_PL}[SEP]"
# One row per NLG checkpoint: (display name used as the NLG_MODELS key,
# model class whose `from_pretrained` is called, Hub repository id,
# example input). Row order is the insertion order of NLG_MODELS.
_NLG_CHECKPOINTS = (
    # English
    ("t5-large", AutoModelForSeq2SeqLM,
     "clarin-knext/utterance-rewriting-t5-large", DEFAULT_ENCODER_DECODER_INPUT_EN),
    ("en-mt5-large", AutoModelForSeq2SeqLM,
     "clarin-knext/utterance-rewriting-en-mt5-large", DEFAULT_ENCODER_DECODER_INPUT_EN),
    ("gpt2", AutoModelForCausalLM,
     "clarin-knext/utterance-rewriting-gpt2", DEFAULT_DECODER_ONLY_INPUT_EN),
    # Polish
    ("pt5-large", AutoModelForSeq2SeqLM,
     "clarin-knext/utterance-rewriting-pt5-large", DEFAULT_ENCODER_DECODER_INPUT_PL),
    ("pl-mt5-large", AutoModelForSeq2SeqLM,
     "clarin-knext/utterance-rewriting-pl-mt5-large", DEFAULT_ENCODER_DECODER_INPUT_PL),
    ("polish-gpt2", AutoModelForCausalLM,
     "clarin-knext/utterance-rewriting-polish-gpt2", DEFAULT_DECODER_ONLY_INPUT_PL),
)


def _load_nlg_entry(model_cls: Any, repo_id: str, default_input: str) -> Dict[str, Any]:
    """Load one NLG checkpoint and pair it with its example input.

    Args:
        model_cls: Class providing ``from_pretrained`` for the model weights
            (``AutoModelForSeq2SeqLM`` or ``AutoModelForCausalLM`` in this file).
        repo_id: Repository id handed to both ``from_pretrained`` calls.
        default_input: Example input string stored next to the model.

    Returns:
        A dict with the keys ``"model"``, ``"tokenizer"`` and ``"default_input"``.
    """
    # NOTE(review): recent transformers releases prefer `token=` over
    # `use_auth_token=`; kept unchanged here -- confirm against the pinned version.
    return {
        "model": model_cls.from_pretrained(repo_id, use_auth_token=auth_token),
        "tokenizer": AutoTokenizer.from_pretrained(repo_id, use_auth_token=auth_token),
        "default_input": default_input,
    }


# Every checkpoint is loaded eagerly while this module is being imported.
NLG_MODELS: Dict[str, Dict[str, Any]] = {
    display_name: _load_nlg_entry(model_cls, repo_id, default_input)
    for display_name, model_cls, repo_id, default_input in _NLG_CHECKPOINTS
}
# One "text2text-generation" pipeline per entry of DST_MODELS, stored under
# the same display name and built from that entry's model and tokenizer.
PIPELINES: Dict[str, Pipeline] = {
    model_name: pipeline(
        "text2text-generation",
        model=entry["model"],
        tokenizer=entry["tokenizer"],
    )
    for model_name, entry in DST_MODELS.items()
}