import os

import torch
from comet import download_model, load_from_checkpoint

# Store downloaded COMET checkpoints under /tmp.
os.environ["COMET_CACHE"] = "/tmp"

def calculate_comet(source_sentences, translations, references):
    """
    Calculate COMET scores for a list of translations.

    :param source_sentences: List of source sentences.
    :param translations: List of translated sentences (hypotheses).
    :param references: List of reference translations.
    :return: List of COMET scores (one score per sentence pair).
    """
    try:
        # Download the reference-based WMT22 COMET checkpoint and load it.
        model_path = download_model("Unbabel/wmt22-comet-da")
        model = load_from_checkpoint(model_path)

        # Run on a GPU when one is available; COMET's predict() handles
        # device placement itself through the ``gpus`` argument, so passing
        # gpus=0 would silently keep scoring on the CPU even with CUDA present.
        gpus = 1 if torch.cuda.is_available() else 0

        # COMET expects one dict per segment with source ("src"),
        # hypothesis ("mt"), and reference ("ref") text.
        data = [
            {"src": src, "mt": mt, "ref": ref}
            for src, mt, ref in zip(source_sentences, translations, references)
        ]

        results = model.predict(data, batch_size=8, gpus=gpus)
        return results["scores"]
    except Exception as e:
        print(f"COMET calculation error: {str(e)}")
        return [0.0] * len(source_sentences)
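

# Minimal usage sketch with hypothetical example sentences; it assumes the
# COMET checkpoint can be downloaded (network access) on the first run.
if __name__ == "__main__":
    sources = ["Dem Feuer konnte Einhalt geboten werden."]
    hypotheses = ["The fire could be stopped."]
    references = ["They were able to control the fire."]

    # One score per (source, hypothesis, reference) triple.
    scores = calculate_comet(sources, hypotheses, references)
    print(scores)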