"""Chatbots using API-based services."""
from __future__ import annotations

import os
import re
from dataclasses import dataclass

import config


@dataclass(frozen=True)
class GptMtInstance:
    """An instance from the GPT-MT dataset.

    Attributes:
        data: The input sentence.
        label: The output sentence.
        doc_id: The document ID.
        lang_pair: The language pair.
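
    Example (hypothetical values):
        GptMtInstance(
            data="Hello.", label="Hallo.", doc_id="doc1", lang_pair="ende"
        )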
    """

    data: str
    label: str
    doc_id: str
    lang_pair: str


def process_data(
    input_dir: str,
    lang_pairs: list[str],
) -> list[GptMtInstance]:
    """Load data."""
    # Load the data
    data: list[GptMtInstance] = []
    eval_dir = os.path.join(input_dir, "evaluation", "testset")
    for lang_pair in lang_pairs:
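        # Language pairs are concatenated two-letter codes, e.g. "ende" -> ("en", "de").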
        src_lang, trg_lang = lang_pair[:2], lang_pair[2:]
        src_file = os.path.join(
            eval_dir, "wmt-testset", lang_pair, f"test.{src_lang}-{trg_lang}.{src_lang}"
        )
        trg_file = os.path.join(
            eval_dir, "wmt-testset", lang_pair, f"test.{src_lang}-{trg_lang}.{trg_lang}"
        )
        doc_file = os.path.join(
            eval_dir,
            "wmt-testset-docids",
            lang_pair,
            f"test.{src_lang}-{trg_lang}.docids",
        )
        with open(src_file, "r", encoding="utf-8") as src_in, open(
            trg_file, "r", encoding="utf-8"
        ) as trg_in, open(doc_file, "r", encoding="utf-8") as doc_in:
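            # The three files are line-aligned: line i of the source, target, and
            # doc-id files all describe the same test sentence.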
            for src_line, trg_line, doc_line in zip(src_in, trg_in, doc_in):
                data.append(
                    GptMtInstance(
                        src_line.strip(), trg_line.strip(), doc_line.strip(), lang_pair
                    )
                )
    return data


def remove_leading_language(line: str) -> str:
    """Remove a language at the beginning of the string.

    Some zero-shot models output the name of the language at the beginning of the
    string. This is a manual post-processing function that removes the language name
    (partly as an example of how you can do simple fixes to issues that come up during
    analysis using Zeno).

    Args:
        line: The line to process.

    Returns:
        The line with the language removed.
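
    Example (hypothetical input):
        >>> remove_leading_language("French: Bonjour le monde")
        'Bonjour le monde'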
    """
    return re.sub(
        r"^(English|Japanese|Chinese|Hausa|Icelandic|French|German|Russian|Ukranian): ",
        "",
        line,
    )


def process_output(
    input_dir: str,
    lang_pairs: list[str],
    model_preset: str,
) -> list[str]:
    """Load model outputs."""
    # Load the data
    data: list[str] = []
    model_config = config.model_configs[model_preset]
    model_path = model_config.path
    system_dir = os.path.join(input_dir, "evaluation", "system-outputs", model_path)
    for lang_pair in lang_pairs:
        src_lang, trg_lang = lang_pair[:2], lang_pair[2:]
        sys_file = os.path.join(
            system_dir, lang_pair, f"test.{src_lang}-{trg_lang}.{trg_lang}"
        )
        with open(sys_file, "r", encoding="utf-8") as sys_in:
            for sys_line in sys_in:
                sys_line = sys_line.strip()
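                # Apply any model-specific post-processors from the config
                # (e.g. remove_leading_language for zero-shot models).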
                if model_config.post_processors is not None:
                    for postprocessor in model_config.post_processors:
                        sys_line = postprocessor(sys_line)
                data.append(sys_line)
    return data
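

# A minimal usage sketch, not part of the original module: it assumes the
# gpt-MT evaluation data has been checked out to "./gpt-MT" and that
# config.model_configs contains a "text-davinci-003" preset; both names are
# illustrative placeholders rather than values confirmed by this repo.
if __name__ == "__main__":
    lang_pairs = ["ende", "zhen"]  # hypothetical subset of the test language pairs
    instances = process_data("gpt-MT", lang_pairs)
    outputs = process_output("gpt-MT", lang_pairs, "text-davinci-003")
    # Each system output line should correspond to one dataset instance.
    print(f"Loaded {len(instances)} instances and {len(outputs)} outputs.")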