Spaces:
Runtime error
Runtime error
File size: 5,718 Bytes
54199b6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 |
import requests
import torch
from PIL import Image
import hashlib
import tempfile
import unittest
from io import BytesIO
from pathlib import Path
from unittest.mock import patch
from urllib3 import HTTPResponse
from urllib3._collections import HTTPHeaderDict
import open_clip
from open_clip.pretrained import download_pretrained_from_url
class DownloadPretrainedTests(unittest.TestCase):
    """Tests for downloading and caching pretrained open_clip weights.

    The URL-based tests patch ``open_clip.pretrained.urllib`` so no real network
    traffic happens: ``urlopen`` returns an in-memory urllib3 response built by
    :meth:`create_response`. The expected SHA256 (full digest for
    openaipublic, 8-hex-digit prefix for mlfoundations) is embedded in the URL,
    and each test checks whether a download was attempted and whether a
    mismatching payload is rejected.
    """

    @staticmethod
    def _openaipublic_url(data):
        """Return an openaipublic URL whose path embeds the full SHA256 of *data*."""
        digest = hashlib.sha256(data).hexdigest()
        return f'https://openaipublic.azureedge.net/clip/models/{digest}/RN50.pt'

    @staticmethod
    def _mlfoundations_url(data):
        """Return an mlfoundations URL whose filename ends with the first 8 hex digits of *data*'s SHA256."""
        digest = hashlib.sha256(data).hexdigest()[:8]
        return f'https://github.com/mlfoundations/download/v0.2-weights/rn50-quickgelu-{digest}.pt'

    def create_response(self, data, status_code=200, content_type='application/octet-stream'):
        """Build an in-memory urllib3 ``HTTPResponse`` that streams *data*.

        Used as the return value of the mocked ``urllib.request.urlopen``.

        Args:
            data: Response body bytes.
            status_code: HTTP status code to report.
            content_type: Value of the ``Content-Type`` header.

        Returns:
            An unread ``HTTPResponse`` (``preload_content=False``) whose
            ``Content-Length`` header equals ``len(data)``.
        """
        headers = HTTPHeaderDict({
            'Content-Type': content_type,
            'Content-Length': str(len(data)),
        })
        return HTTPResponse(BytesIO(data), preload_content=False, headers=headers, status=status_code)

    @patch('open_clip.pretrained.urllib')
    def test_download_pretrained_from_url_from_openaipublic(self, urllib):
        """A fresh download whose bytes match the URL's hash succeeds and is cached."""
        file_contents = b'pretrained model weights'
        urllib.request.urlopen.return_value = self.create_response(file_contents)
        with tempfile.TemporaryDirectory() as root:
            download_pretrained_from_url(self._openaipublic_url(file_contents), root)
            # The payload must be written to the cache under the URL's basename.
            self.assertEqual((Path(root) / 'RN50.pt').read_bytes(), file_contents)
        urllib.request.urlopen.assert_called_once()

    @patch('open_clip.pretrained.urllib')
    def test_download_pretrained_from_url_from_openaipublic_corrupted(self, urllib):
        """A downloaded payload that does not match the URL's hash raises RuntimeError."""
        file_contents = b'pretrained model weights'
        urllib.request.urlopen.return_value = self.create_response(b'corrupted pretrained model')
        with tempfile.TemporaryDirectory() as root:
            url = self._openaipublic_url(file_contents)
            # The regex mirrors the library's error message verbatim (including "not not").
            with self.assertRaisesRegex(RuntimeError, r'checksum does not not match'):
                download_pretrained_from_url(url, root)
        urllib.request.urlopen.assert_called_once()

    @patch('open_clip.pretrained.urllib')
    def test_download_pretrained_from_url_from_openaipublic_valid_cache(self, urllib):
        """A cached file whose hash matches the URL is reused without any download."""
        file_contents = b'pretrained model weights'
        urllib.request.urlopen.return_value = self.create_response(file_contents)
        with tempfile.TemporaryDirectory() as root:
            (Path(root) / 'RN50.pt').write_bytes(file_contents)
            download_pretrained_from_url(self._openaipublic_url(file_contents), root)
        urllib.request.urlopen.assert_not_called()

    @patch('open_clip.pretrained.urllib')
    def test_download_pretrained_from_url_from_openaipublic_corrupted_cache(self, urllib):
        """A cached file with the wrong hash triggers a re-download that repairs the cache."""
        file_contents = b'pretrained model weights'
        urllib.request.urlopen.return_value = self.create_response(file_contents)
        with tempfile.TemporaryDirectory() as root:
            local_file = Path(root) / 'RN50.pt'
            local_file.write_bytes(b'corrupted pretrained model')
            download_pretrained_from_url(self._openaipublic_url(file_contents), root)
            # The corrupted cache entry must have been replaced by the fresh payload.
            self.assertEqual(local_file.read_bytes(), file_contents)
        urllib.request.urlopen.assert_called_once()

    @patch('open_clip.pretrained.urllib')
    def test_download_pretrained_from_url_from_mlfoundations(self, urllib):
        """An mlfoundations download matching the 8-digit hash prefix succeeds and is cached."""
        file_contents = b'pretrained model weights'
        urllib.request.urlopen.return_value = self.create_response(file_contents)
        with tempfile.TemporaryDirectory() as root:
            url = self._mlfoundations_url(file_contents)
            download_pretrained_from_url(url, root)
            cached_name = url.rsplit('/', 1)[-1]
            self.assertEqual((Path(root) / cached_name).read_bytes(), file_contents)
        urllib.request.urlopen.assert_called_once()

    @patch('open_clip.pretrained.urllib')
    def test_download_pretrained_from_url_from_mlfoundations_corrupted(self, urllib):
        """An mlfoundations payload not matching the hash prefix raises RuntimeError."""
        file_contents = b'pretrained model weights'
        urllib.request.urlopen.return_value = self.create_response(b'corrupted pretrained model')
        with tempfile.TemporaryDirectory() as root:
            url = self._mlfoundations_url(file_contents)
            # The regex mirrors the library's error message verbatim (including "not not").
            with self.assertRaisesRegex(RuntimeError, r'checksum does not not match'):
                download_pretrained_from_url(url, root)
        urllib.request.urlopen.assert_called_once()

    @patch('open_clip.pretrained.urllib')
    def test_download_pretrained_from_hfh(self, urllib):
        """Load a tiny model from the Hugging Face Hub and check zero-shot text probabilities.

        Requires network access. ``urllib`` is patched so the final
        ``assert_not_called`` verifies the HF-Hub path never goes through the
        URL downloader.
        """
        model_name = 'hf-hub:hf-internal-testing/tiny-open-clip-model'
        model, _, preprocess = open_clip.create_model_and_transforms(model_name)
        tokenizer = open_clip.get_tokenizer(model_name)

        img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/coco_sample.png"
        # Fetch the whole (content-decoded) body with a timeout and fail loudly on
        # HTTP errors, rather than reading the never-closed, undecoded Response.raw.
        response = requests.get(img_url, timeout=60)
        response.raise_for_status()
        with Image.open(BytesIO(response.content)) as raw_image:
            image = preprocess(raw_image).unsqueeze(0)
        text = tokenizer(["a diagram", "a dog", "a cat"])

        with torch.no_grad():
            image_features = model.encode_image(image)
            text_features = model.encode_text(text)
            image_features /= image_features.norm(dim=-1, keepdim=True)
            text_features /= text_features.norm(dim=-1, keepdim=True)
            text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)

        self.assertTrue(torch.allclose(text_probs, torch.tensor([[0.0597, 0.6349, 0.3053]]), rtol=1e-3))
        urllib.request.urlopen.assert_not_called()
|