Spaces:
Sleeping
Sleeping
File size: 3,216 Bytes
382191a 95b368a 382191a 95b368a 382191a 95b368a 382191a 95b368a 382191a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 |
"""Chatbots using API-based services."""
from __future__ import annotations
import os
import re
from dataclasses import dataclass
import config
@dataclass(frozen=True)
class GptMtInstance:
    """A single parallel-sentence instance from the GPT-MT dataset.

    Frozen (immutable and hashable) so instances can be safely shared,
    deduplicated, or used as dict/set keys.

    Attributes:
        data: The input (source-language) sentence.
        label: The reference output (target-language) sentence.
        doc_id: The ID of the document this sentence belongs to.
        lang_pair: The language pair as a 4-letter code, e.g. "ende".
    """

    data: str
    label: str
    doc_id: str
    lang_pair: str
def process_data(
    input_dir: str,
    lang_pairs: list[str],
) -> list[GptMtInstance]:
    """Load GPT-MT test-set instances for the given language pairs.

    Args:
        input_dir: Root of the GPT-MT data directory; test data is read
            from ``<input_dir>/evaluation/testset``.
        lang_pairs: Language pairs as 4-letter codes (two 2-letter ISO
            codes concatenated), e.g. ``"ende"`` for English->German.

    Returns:
        One ``GptMtInstance`` per parallel sentence, in file order,
        concatenated across language pairs.
    """
    data: list[GptMtInstance] = []
    eval_dir = os.path.join(input_dir, "evaluation", "testset")
    for lang_pair in lang_pairs:
        # "ende" -> ("en", "de")
        src_lang, trg_lang = lang_pair[:2], lang_pair[2:]
        src_file = os.path.join(
            eval_dir, "wmt-testset", lang_pair, f"test.{src_lang}-{trg_lang}.{src_lang}"
        )
        trg_file = os.path.join(
            eval_dir, "wmt-testset", lang_pair, f"test.{src_lang}-{trg_lang}.{trg_lang}"
        )
        doc_file = os.path.join(
            eval_dir,
            "wmt-testset-docids",
            lang_pair,
            f"test.{src_lang}-{trg_lang}.docids",
        )
        # Explicit UTF-8: the test sets are UTF-8 encoded, and relying on the
        # platform default encoding would mis-decode them on e.g. Windows.
        with open(src_file, "r", encoding="utf-8") as src_in, open(
            trg_file, "r", encoding="utf-8"
        ) as trg_in, open(doc_file, "r", encoding="utf-8") as doc_in:
            # The three files are expected to be line-aligned; zip pairs them up.
            for src_line, trg_line, doc_line in zip(src_in, trg_in, doc_in):
                data.append(
                    GptMtInstance(
                        src_line.strip(), trg_line.strip(), doc_line.strip(), lang_pair
                    )
                )
    return data
def remove_leading_language(line: str) -> str:
    """Remove a language name at the beginning of the string.

    Some zero-shot models output the name of the language at the beginning of
    the string. This is a manual post-processing function that removes the
    language name (partly as an example of how you can do simple fixes to
    issues that come up during analysis using Zeno).

    Both the correct spelling "Ukrainian" and the historical misspelling
    "Ukranian" are accepted, so output matched by older versions of this
    function is still stripped.

    Args:
        line: The line to process.

    Returns:
        The line with the leading "<Language>: " prefix removed, or the line
        unchanged if no such prefix is present.
    """
    return re.sub(
        r"^(English|Japanese|Chinese|Hausa|Icelandic|French|German|Russian"
        r"|Ukrainian|Ukranian): ",
        "",
        line,
    )
def process_output(
    input_dir: str,
    lang_pairs: list[str],
    model_preset: str,
) -> list[str]:
    """Load model output translations for the given language pairs.

    Args:
        input_dir: Root of the GPT-MT data directory; outputs are read from
            ``<input_dir>/evaluation/system-outputs/<model path>``.
        lang_pairs: Language pairs as 4-letter codes, e.g. ``"ende"``.
        model_preset: Key into ``config.model_configs`` selecting which
            system's outputs to load.

    Returns:
        One output line per sentence, post-processed (if the model config
        specifies post-processors), concatenated across language pairs in
        the same order as ``process_data``.
    """
    data: list[str] = []
    model_config = config.model_configs[model_preset]
    model_path = model_config.path
    system_dir = os.path.join(input_dir, "evaluation", "system-outputs", model_path)
    for lang_pair in lang_pairs:
        # "ende" -> ("en", "de")
        src_lang, trg_lang = lang_pair[:2], lang_pair[2:]
        sys_file = os.path.join(
            system_dir, lang_pair, f"test.{src_lang}-{trg_lang}.{trg_lang}"
        )
        # Explicit UTF-8 to avoid depending on the platform default encoding.
        with open(sys_file, "r", encoding="utf-8") as sys_in:
            for sys_line in sys_in:
                sys_line = sys_line.strip()
                # Apply any model-specific fixes (e.g. stripping a leading
                # language name) declared in the model's config.
                if model_config.post_processors is not None:
                    for postprocessor in model_config.post_processors:
                        sys_line = postprocessor(sys_line)
                data.append(sys_line)
    return data
|