# train-scripts / comet_api.py
# Uploaded by Ashton2000 using huggingface_hub (commit 981b783, verified).
import argparse
import requests
import time
import os
def get_comet_score(instances: list[dict], timeout=100, max_retries=10, comet_api: "str | None" = None, retry_delay: float = 5):
    """POST a batch of instances to a remote COMET service and return its scores.

    Args:
        instances: Items sent as the JSON payload ``{'instances': instances}``.
            ``main`` builds them as ``{'src': ..., 'mt': ..., 'ref': ...}`` dicts.
        timeout: Per-request timeout in seconds, passed to ``requests.post``.
        max_retries: Maximum number of failed attempts (timeouts or non-200
            responses) before giving up.
        comet_api: ``host[:port]`` of the service. When falsy, the
            ``COMET_API`` environment variable is used instead.
        retry_delay: Seconds to wait between failed attempts.

    Returns:
        The ``'scores'`` field of the service's JSON response.

    Raises:
        RuntimeError: if no endpoint is configured, a non-timeout request
            error occurs, or ``max_retries`` attempts all fail.
    """
    host = comet_api or os.getenv('COMET_API')
    if not host:
        # Fail fast instead of silently requesting "http://None/evaluate".
        raise RuntimeError("No COMET endpoint configured: pass comet_api or set the COMET_API environment variable.")
    url = f"http://{host}/evaluate"
    payload = {'instances': instances}
    retries = 0
    while retries < max_retries:
        try:
            response = requests.post(url, json=payload, timeout=timeout)
            if response.status_code == 200:
                return response.json()['scores']
            # Non-200 responses count as failed attempts too; otherwise the
            # loop would re-send forever with no delay.
            retries += 1
            print(f"Request failed with status code: {response.status_code}. Retrying... ({retries}/{max_retries})")
        except requests.Timeout:
            retries += 1
            print(f"Request timed out. Retrying... ({retries}/{max_retries})")
        except requests.RequestException as e:
            raise RuntimeError(f"Request failed due to: {e}") from e
        if retries < max_retries:
            # Back off between attempts, but not after the final failure.
            time.sleep(retry_delay)
    raise RuntimeError("Max retries exceeded. Request failed.")
def _read_lines(path):
    """Return the lines of the UTF-8 text file *path* with trailing newlines removed."""
    with open(path, 'r', encoding='utf-8') as f:
        return [line.rstrip('\n') for line in f]


def main():
    """Score a translation file against source and reference files using a COMET service.

    The three files are aligned by line number. Each aligned triple is sent as one
    ``{'src', 'mt', 'ref'}`` instance. The script prints the target path and the
    mean segment score, or ``-1.0`` when the service returns no scores.
    """
    parser = argparse.ArgumentParser(description='Compute the average COMET score of a translation file via a remote COMET service.')
    parser.add_argument('--source_file', '-s', type=str, required=True, help='Source sentences, one per line.')
    parser.add_argument('--target_file', '-t', type=str, required=True, help='Translated (MT) sentences, one per line.')
    parser.add_argument('--reference_file', '-r', type=str, required=True, help='Reference translations, one per line.')
    parser.add_argument('--url', '-u', type=str, required=True, help='host[:port] of the COMET service; "/evaluate" is appended.')
    args = parser.parse_args()

    source_lines = _read_lines(args.source_file)
    target_lines = _read_lines(args.target_file)
    reference_lines = _read_lines(args.reference_file)
    # zip() would silently drop trailing lines and misalign segments, so
    # reject inputs of unequal length up front.
    if not len(source_lines) == len(target_lines) == len(reference_lines):
        parser.error(
            f'line counts differ: source={len(source_lines)}, '
            f'target={len(target_lines)}, reference={len(reference_lines)}'
        )

    instances = [
        {'src': src, 'mt': mt, 'ref': ref}
        for src, mt, ref in zip(source_lines, target_lines, reference_lines)
    ]
    line_comet_scores = get_comet_score(instances, comet_api=args.url)
    avg_score = sum(line_comet_scores) / len(line_comet_scores) if line_comet_scores else -1.0
    print(f'{args.target_file}\tscore: {avg_score:.4f}')


if __name__ == '__main__':
    main()