# translation-report / config.py
# Alex Cabrera · config · 8ceca60 · 4.08 kB
# (Header text carried over from the hosting page, kept as comments so the
# module parses; navigation links omitted.)
"""Config for analyzing GPT-MT."""
from __future__ import annotations
from collections.abc import Callable
from dataclasses import dataclass
from zeno_build.evaluation.text_features.capitalization import input_capital_char_ratio
from zeno_build.evaluation.text_features.exact_match import avg_exact_match, exact_match
from zeno_build.evaluation.text_features.frequency import output_max_word_freq
from zeno_build.evaluation.text_features.length import (
doc_context_length,
input_length,
label_length,
output_length,
)
from zeno_build.evaluation.text_metrics.critique import (
avg_bert_score,
avg_chrf,
avg_comet,
avg_length_ratio,
bert_score,
chrf,
comet,
length_ratio,
)
from zeno_build.experiments import search_space
from modeling import remove_leading_language
# Named groups of language-pair codes, looked up by group name.
# NOTE(review): each code looks like a source-language code followed by a
# target-language code (e.g. "deen") — confirm against the data loader.
lang_pairs: dict[str, list[str]] = {
    # Superset: every language pair that appears in any experiment.
    "all_lang_pairs": [
        "csen",
        "deen",
        "defr",
        "encs",
        "ende",
        "enha",
        "enis",
        "enja",
        "enru",
        "enuk",
        "enzh",
        "frde",
        "haen",
        "isen",
        "jaen",
        "ruen",
        "uken",
        "zhen",
    ],
    # Subset for the experiments that are run on only a limited number of
    # language pairs; every entry also appears in "all_lang_pairs".
    "limited_lang_pairs": [
        "deen",
        "defr",
        "ende",
        "enru",
        "enzh",
        "frde",
        "ruen",
        "zhen",
    ],
}
# Search space for the main experiments: one constant dimension (the
# language-pair group) combined with one categorical dimension (the preset).
main_space = search_space.CombinatorialSearchSpace(
    {
        # Always the "all_lang_pairs" group, i.e. a key of `lang_pairs`.
        "lang_pairs": search_space.Constant("all_lang_pairs"),
        # Candidate presets; each name is a key of `model_configs`.
        "model_preset": search_space.Categorical(
            [
                "text-davinci-003-zeroshot",
                "text-davinci-003-RR-1-shot",
                "text-davinci-003-RR-5-shot",
                "text-davinci-003-QR-1-shot",
                "text-davinci-003-QR-5-shot",
                "gpt-3.5-turbo-0301-zeroshot",
                "gpt-4-0314-zeroshot",
                "gpt-4-0314-zeroshot-postprocess",
                "MS-Translator",
                "google-cloud",
                "wmt-best",
            ]
        ),
    }
)
@dataclass(frozen=True)
class GptMtConfig:
    """Frozen record describing one system preset in the GPT-MT analysis.

    Only ``path`` and ``base_model`` are required; every other field
    defaults to ``None``. Attributes cannot be reassigned after
    construction because the dataclass is frozen.
    """

    # Identifier such as "text-davinci-003/RR/1-shot".
    # NOTE(review): presumably a relative path to the system's outputs —
    # confirm against the code that loads them.
    path: str

    # Name of the underlying model or translation system.
    base_model: str

    # Prompt strategy label ("RR" or "QR" in the presets defined in this
    # file); None when no strategy is given.
    prompt_strategy: str | None = None

    # Shot count for the prompt (0 for the zero-shot presets); None when
    # not given.
    prompt_shots: int | None = None

    # Optional list of str -> str callables.
    # NOTE(review): presumably applied to each system output — confirm.
    post_processors: list[Callable[[str], str]] | None = None
# Details of every model preset, keyed by preset name. Arguments are spelled
# out by keyword so each entry reads without consulting the field order.
model_configs: dict[str, GptMtConfig] = {
    # text-davinci-003 with the "RR" and "QR" prompt strategies (1 and 5 shots).
    "text-davinci-003-RR-1-shot": GptMtConfig(
        path="text-davinci-003/RR/1-shot",
        base_model="text-davinci-003",
        prompt_strategy="RR",
        prompt_shots=1,
    ),
    "text-davinci-003-RR-5-shot": GptMtConfig(
        path="text-davinci-003/RR/5-shot",
        base_model="text-davinci-003",
        prompt_strategy="RR",
        prompt_shots=5,
    ),
    "text-davinci-003-QR-1-shot": GptMtConfig(
        path="text-davinci-003/QR/1-shot",
        base_model="text-davinci-003",
        prompt_strategy="QR",
        prompt_shots=1,
    ),
    "text-davinci-003-QR-5-shot": GptMtConfig(
        path="text-davinci-003/QR/5-shot",
        base_model="text-davinci-003",
        prompt_strategy="QR",
        prompt_shots=5,
    ),
    # Zero-shot presets: no prompt strategy, zero shots.
    "text-davinci-003-zeroshot": GptMtConfig(
        path="text-davinci-003/zeroshot",
        base_model="text-davinci-003",
        prompt_strategy=None,
        prompt_shots=0,
    ),
    "gpt-3.5-turbo-0301-zeroshot": GptMtConfig(
        path="gpt-3.5-turbo-0301/zeroshot",
        base_model="gpt-3.5-turbo-0301",
        prompt_strategy=None,
        prompt_shots=0,
    ),
    "gpt-4-0314-zeroshot": GptMtConfig(
        path="gpt-4-0314/zeroshot",
        base_model="gpt-4-0314",
        prompt_strategy=None,
        prompt_shots=0,
    ),
    # Same path and model as "gpt-4-0314-zeroshot", with
    # `remove_leading_language` as the only post-processor.
    "gpt-4-0314-zeroshot-postprocess": GptMtConfig(
        path="gpt-4-0314/zeroshot",
        base_model="gpt-4-0314",
        prompt_strategy=None,
        prompt_shots=0,
        post_processors=[remove_leading_language],
    ),
    # Entries with only a path and a base model; prompt fields stay None.
    "MS-Translator": GptMtConfig(path="MS-Translator", base_model="MS-Translator"),
    "google-cloud": GptMtConfig(path="google-cloud", base_model="google-cloud"),
    "wmt-best": GptMtConfig(path="wmt-best", base_model="wmt-best"),
}
# Sweep settings: `chrf` is the only distill function and `avg_chrf` is the
# metric function.
# NOTE(review): presumably read by the code that runs and ranks the sweep —
# confirm against the caller.
sweep_distill_functions = [chrf]
sweep_metric_function = avg_chrf
# Functions handed to the Zeno visualization, in display order: first the
# per-example functions, then their `avg_*` counterparts.
zeno_distill_and_metric_functions = [
    # Imported from zeno_build.evaluation.text_features (length,
    # capitalization, frequency).
    output_length,
    input_length,
    label_length,
    doc_context_length,
    input_capital_char_ratio,
    output_max_word_freq,
    # Imported from zeno_build.evaluation.text_metrics.critique, plus
    # exact_match from text_features.
    chrf,
    comet,
    length_ratio,
    bert_score,
    exact_match,
    # `avg_*` functions matching the five entries directly above.
    avg_chrf,
    avg_comet,
    avg_length_ratio,
    avg_bert_score,
    avg_exact_match,
]