cerulianx commited on
Commit
217c009
1 Parent(s): 4c405ef

Upload __init__.py

Browse files
Files changed (1) hide show
  1. __init__.py +18 -0
__init__.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io, requests
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+ from dall_e.encoder import Encoder
6
+ from dall_e.decoder import Decoder
7
+ from dall_e.utils import map_pixels, unmap_pixels
8
+
9
+ def load_model(path: str, device: torch.device = None) -> nn.Module:
10
+ if path.startswith('http://') or path.startswith('https://'):
11
+ resp = requests.get(path)
12
+ resp.raise_for_status()
13
+
14
+ with io.BytesIO(resp.content) as buf:
15
+ return torch.load(buf, map_location=device)
16
+ else:
17
+ with open(path, 'rb') as f:
18
+ return torch.load(f, map_location=device)