# dall-e/__init__.py
# Uploaded by cerulianx ("Upload __init__.py"), commit 217c009 (595 bytes at upload).
import io
import os
from typing import Optional, Union

import requests
import torch
import torch.nn as nn

from dall_e.encoder import Encoder
from dall_e.decoder import Decoder
from dall_e.utils import map_pixels, unmap_pixels


def load_model(
    path: Union[str, os.PathLike],
    device: Union[str, torch.device, None] = None,
    *,
    timeout: Optional[float] = 60.0,
) -> nn.Module:
    """Load a serialized PyTorch object from an HTTP(S) URL or a local file.

    Args:
        path: An ``http://`` or ``https://`` URL to download, or a local
            filesystem path (``str`` or any ``os.PathLike``).
        device: Forwarded to ``torch.load`` as ``map_location``. Stored
            tensors are remapped onto this device. ``None`` keeps the
            locations recorded in the file.
        timeout: Seconds ``requests`` waits for the connection and between
            received bytes when ``path`` is a URL. It is not a cap on the
            total download time. ``None`` waits forever. Ignored for local
            paths.

    Returns:
        The object deserialized by ``torch.load``. This is presumably an
        ``nn.Module`` for the DALL-E checkpoints, but that is not checked here.

    Raises:
        requests.RequestException: a subclass such as ``HTTPError`` (error
            status) or ``Timeout`` when downloading a URL fails.
        OSError: if a local file cannot be opened.
    """
    # Accept pathlib.Path and other PathLike objects, not just str.
    path = os.fspath(path)

    # NOTE: torch.load unpickles arbitrary Python objects; only load from
    # trusted sources.
    # NOTE(review): torch>=2.6 defaults to weights_only=True, which may reject
    # fully pickled nn.Module checkpoints. Confirm against the installed torch.
    if path.startswith(('http://', 'https://')):
        # Without a timeout, requests can block forever on a stalled server.
        resp = requests.get(path, timeout=timeout)
        resp.raise_for_status()
        with io.BytesIO(resp.content) as buf:
            return torch.load(buf, map_location=device)

    with open(path, 'rb') as f:
        return torch.load(f, map_location=device)