"""Chatbots using API-based services.""" from __future__ import annotations import os import re from dataclasses import dataclass import config @dataclass(frozen=True) class GptMtInstance: """An instance from the GPT-MT dataset. Attributes: data: The input sentence. label: The output sentence. doc_id: The document ID. lang_pair: The language pair. """ data: str label: str doc_id: str lang_pair: str def process_data( input_dir: str, lang_pairs: list[str], ) -> list[GptMtInstance]: """Load data.""" # Load the data data: list[GptMtInstance] = [] eval_dir = os.path.join(input_dir, "evaluation", "testset") for lang_pair in lang_pairs: src_lang, trg_lang = lang_pair[:2], lang_pair[2:] src_file = os.path.join( eval_dir, "wmt-testset", lang_pair, f"test.{src_lang}-{trg_lang}.{src_lang}" ) trg_file = os.path.join( eval_dir, "wmt-testset", lang_pair, f"test.{src_lang}-{trg_lang}.{trg_lang}" ) doc_file = os.path.join( eval_dir, "wmt-testset-docids", lang_pair, f"test.{src_lang}-{trg_lang}.docids", ) with open(src_file, "r") as src_in, open(trg_file, "r") as trg_in, open( doc_file, "r" ) as doc_in: for src_line, trg_line, doc_line in zip(src_in, trg_in, doc_in): data.append( GptMtInstance( src_line.strip(), trg_line.strip(), doc_line.strip(), lang_pair ) ) return data def remove_leading_language(line: str) -> str: """Remove a language at the beginning of the string. Some zero-shot models output the name of the language at the beginning of the string. This is a manual post-processing function that removes the language name (partly as an example of how you can do simple fixes to issues that come up during analysis using Zeno). Args: line: The line to process. Returns: The line with the language removed. """ return re.sub( r"^(English|Japanese|Chinese|Hausa|Icelandic|French|German|Russian|Ukranian): ", "", line, ) def process_output( input_dir: str, lang_pairs: list[str], model_preset: str, ) -> list[str]: """Load model outputs.""" # Load the data data: list[str] = [] model_config = config.model_configs[model_preset] model_path = model_config.path system_dir = os.path.join(input_dir, "evaluation", "system-outputs", model_path) for lang_pair in lang_pairs: src_lang, trg_lang = lang_pair[:2], lang_pair[2:] sys_file = os.path.join( system_dir, lang_pair, f"test.{src_lang}-{trg_lang}.{trg_lang}" ) with open(sys_file, "r") as sys_in: for sys_line in sys_in: sys_line = sys_line.strip() if model_config.post_processors is not None: for postprocessor in model_config.post_processors: sys_line = postprocessor(sys_line) data.append(sys_line) return data