"""Smoke test for the local CLIP score metric script (``./clip_score.py``).

Downloads two sample images, scores them against captions (both matched and
mismatched pairs) via ``evaluate``, and checks each score against a reference
value within ``TOLERANCE``.
"""

import io

import evaluate
import requests
from PIL import Image

metric = evaluate.load("./clip_score.py")

# Maximum allowed absolute deviation from the reference scores below.
# NOTE(review): 0.1 is loose compared with the ~0.2 gap between the matched and
# mismatched reference scores; consider tightening once the values are pinned.
TOLERANCE = 0.1


def download_image(image_path, timeout=30):
    """Open an image from an HTTP(S) URL or a local file path.

    Args:
        image_path: A URL (any string starting with ``"http"``) or a
            filesystem path.
        timeout: Seconds to wait for the HTTP server before giving up. Only
            used when ``image_path`` is a URL.

    Returns:
        The opened ``PIL.Image.Image``.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        requests.Timeout: If the server does not respond within ``timeout``.
        PIL.UnidentifiedImageError: If the data is not a readable image.
    """
    if image_path.startswith("http"):
        # Without a timeout, requests can block forever on a stalled server.
        with requests.get(image_path, timeout=timeout) as response:
            # Fail loudly on HTTP errors instead of handing an HTML error page
            # to PIL, which would surface as a confusing decode error.
            response.raise_for_status()
            # Use the decoded body rather than ``response.raw``: the raw stream
            # is not decompressed when the server applies Content-Encoding.
            # Buffering the bytes also lets the connection be released here.
            return Image.open(io.BytesIO(response.content))
    return Image.open(image_path)


def _as_list(value):
    """Wrap a single item in a list; return lists unchanged."""
    return value if isinstance(value, list) else [value]


def compute_clip_score(image, text):
    """Score image(s) against caption(s) with the loaded CLIP score metric.

    Args:
        image: A single image or a list of images; passed to the metric as
            ``references``.
        text: A single caption or a list of captions; passed to the metric as
            ``predictions``.

    Returns:
        The ``"clip_score"`` entry of the dict returned by ``metric.compute``.
    """
    results = metric.compute(
        predictions=_as_list(text),
        references=_as_list(image),
    )
    return results["clip_score"]


predictions = ["A cat sitting on a couch", "A scenic view of mountains during sunset"]
references = [
    "https://images.unsplash.com/photo-1720539222585-346e73f01536",
    "https://images.unsplash.com/photo-1694253987647-4eebcf679974",
]
references = [download_image(url) for url in references]

# Matched caption/image pairs are expected to score noticeably higher than
# mismatched ones (compare cases 1-2 with cases 3-4).
test_cases = [
    {
        "predictions": predictions,
        "references": references,
        "result": {"clip_score": 0.307},
    },
    {
        "predictions": predictions[0],
        "references": references[0],
        "result": {"clip_score": 0.304},
    },
    {
        "predictions": predictions[1],
        "references": references[1],
        "result": {"clip_score": 0.310},
    },
    {
        "predictions": predictions[0],
        "references": references[1],
        "result": {"clip_score": 0.106},
    },
    {
        "predictions": predictions[1],
        "references": references[0],
        "result": {"clip_score": 0.134},
    },
]

for i, test_case in enumerate(test_cases):
    result = compute_clip_score(test_case["references"], test_case["predictions"])
    expected = test_case["result"]["clip_score"]
    error = abs(result - expected)
    # Report the actual numbers so a failure is diagnosable without a debugger.
    assert error < TOLERANCE, (
        f"Test case {i} failed: got {result}, expected {expected} "
        f"(abs error {error}, tolerance {TOLERANCE})"
    )