import functools
import logging
import os

import torch
from comet import download_model, load_from_checkpoint


# Checkpoints are several GB, so keep at most a couple of loaded models alive.
@functools.lru_cache(maxsize=2)
def _load_comet_model(model_name):
    """Download (if needed) and load a COMET checkpoint, cached per model name.

    Loading a checkpoint is expensive, so the loaded model is reused by later
    calls in the same process instead of being reloaded every time.

    :param model_name: Model identifier passed to ``download_model``
        (e.g. ``"Unbabel/wmt22-comet-da"``).
    :return: Tuple ``(model, device)`` where ``device`` is ``"cuda"`` when a
        GPU is available and ``"cpu"`` otherwise.
    """
    # Keep the historical /tmp cache location, but do not clobber a value the
    # caller already configured.
    # NOTE(review): assumes the installed comet version reads COMET_CACHE;
    # newer releases may use the Hugging Face cache instead -- confirm.
    os.environ.setdefault("COMET_CACHE", "/tmp")
    model_path = download_model(model_name)
    model = load_from_checkpoint(model_path)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    return model, device


def calculate_comet(
    source_sentences,
    translations,
    references,
    *,
    model_name="Unbabel/wmt22-comet-da",
    batch_size=8,
):
    """
    Calculate segment-level COMET scores using the local COMET installation.

    The model is loaded once per process via ``_load_comet_model`` and run on
    the GPU when one is available.

    :param source_sentences: Iterable of source sentences
    :param translations: Iterable of translated (MT) sentences, aligned with
        ``source_sentences``
    :param references: Iterable of reference translations, aligned with
        ``source_sentences``
    :param model_name: COMET model to load (default ``Unbabel/wmt22-comet-da``)
    :param batch_size: Number of segments scored per prediction batch
    :return: List of COMET scores, one per source sentence. If scoring fails at
        runtime (download, load or prediction error), the error is logged and a
        list of ``0.0`` with one entry per source sentence is returned instead.
    :raises ValueError: If the three inputs do not have the same length.
    """
    # Materialize once so generators work and len() is always available.
    src_list = list(source_sentences)
    mt_list = list(translations)
    ref_list = list(references)

    # zip() would silently truncate and misalign scores with their sources.
    if not len(src_list) == len(mt_list) == len(ref_list):
        raise ValueError(
            "source_sentences, translations and references must have the same "
            f"length (got {len(src_list)}, {len(mt_list)}, {len(ref_list)})"
        )
    # Nothing to score: avoid downloading/loading the model at all.
    if not src_list:
        return []

    try:
        model, device = _load_comet_model(model_name)
        # COMET expects one dict per segment with "src", "mt" and "ref" keys.
        data = [
            {"src": src, "mt": mt, "ref": ref}
            for src, mt, ref in zip(src_list, mt_list, ref_list)
        ]
        results = model.predict(
            data, batch_size=batch_size, gpus=1 if device == "cuda" else 0
        )
        return list(results["scores"])
    except Exception as exc:
        # Best-effort by design: callers get neutral scores instead of a crash,
        # but the full traceback is logged so the failure is not silent.
        logging.getLogger(__name__).exception("COMET Error: %s", exc)
        return [0.0] * len(src_list)