|
import io, requests |
|
import torch |
|
import torch.nn as nn |
|
|
|
from dall_e.encoder import Encoder |
|
from dall_e.decoder import Decoder |
|
from dall_e.utils import map_pixels, unmap_pixels |
|
|
|
def load_model(path: str, device: torch.device = None, *, timeout: float = 60.0) -> nn.Module:
    """Load a serialized object with ``torch.load`` from a URL or a local file.

    Args:
        path: Either an ``http://`` / ``https://`` URL or a filesystem path.
        device: Forwarded to ``torch.load`` as ``map_location``; ``None``
            leaves tensors on the device they were saved from.
        timeout: Seconds to wait for the server to connect / send data when
            ``path`` is a URL. Ignored for local files.

    Returns:
        Whatever object ``torch.load`` deserializes from the source.

    Raises:
        requests.HTTPError: If the download returns an error status.
        requests.Timeout: If the server does not respond within ``timeout``.
        OSError: If the local file cannot be opened.
    """
    # NOTE(review): torch.load deserializes pickled data; only point this at
    # checkpoints from a trusted source, especially in the URL branch.
    if path.startswith(('http://', 'https://')):
        # Without a timeout, requests waits forever on a stalled connection.
        resp = requests.get(path, timeout=timeout)
        resp.raise_for_status()

        # The whole payload is held in memory and handed to torch.load
        # through a file-like wrapper.
        with io.BytesIO(resp.content) as buf:
            return torch.load(buf, map_location=device)

    with open(path, 'rb') as f:
        return torch.load(f, map_location=device)
|
|