Upload __init__.py
Browse files- __init__.py +18 -0
__init__.py
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import io, requests
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
|
5 |
+
from dall_e.encoder import Encoder
|
6 |
+
from dall_e.decoder import Decoder
|
7 |
+
from dall_e.utils import map_pixels, unmap_pixels
|
8 |
+
|
9 |
+
def load_model(path: str, device: torch.device = None) -> nn.Module:
|
10 |
+
if path.startswith('http://') or path.startswith('https://'):
|
11 |
+
resp = requests.get(path)
|
12 |
+
resp.raise_for_status()
|
13 |
+
|
14 |
+
with io.BytesIO(resp.content) as buf:
|
15 |
+
return torch.load(buf, map_location=device)
|
16 |
+
else:
|
17 |
+
with open(path, 'rb') as f:
|
18 |
+
return torch.load(f, map_location=device)
|