# Extraction artifact from the original blob view (kept for provenance):
# File size: 2,920 Bytes — commit 8896a5f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# (Stray line-number gutter from the scraped blob view removed — lines 1-91.)
import os, sys
import torch

from .models.embedding import FullyConnectedEmbed, SkipLSTM
from .models.contact import ContactCNN
from .models.interaction import ModelInteraction


def build_lm_1(state_dict_path):
    """
    Build the v1 language model and load its trained weights.

    :meta private:

    :param state_dict_path: Path to the serialized state dictionary
    :type state_dict_path: str
    :return: ``SkipLSTM`` language model in eval mode
    """
    # Constructor arguments must match the shipped checkpoint's architecture
    # (presumably: 21-token amino-acid alphabet, 100-d input embedding,
    # 1024 hidden units, 3 layers — confirm against SkipLSTM's signature).
    model = SkipLSTM(21, 100, 1024, 3)
    # map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only
    # machine; load_state_dict copies the tensors into the freshly built
    # (CPU) model either way, so GPU users are unaffected.
    state_dict = torch.load(state_dict_path, map_location="cpu")
    model.load_state_dict(state_dict)
    # Inference-only model: disable dropout/batch-norm training behavior.
    model.eval()
    return model


def build_human_1(state_dict_path):
    """
    Build the v1 human interaction model and load its trained weights.

    :meta private:

    :param state_dict_path: Path to the serialized state dictionary
    :type state_dict_path: str
    :return: ``ModelInteraction`` model in eval mode
    """
    # Component dimensions must match the shipped checkpoint's architecture
    # (embedding 6165 -> 100 with dropout 0.5; contact CNN 100/50, kernel 7 —
    # confirm against the component constructors).
    embModel = FullyConnectedEmbed(6165, 100, 0.5)
    conModel = ContactCNN(100, 50, 7)
    model = ModelInteraction(embModel, conModel, use_W=True, pool_size=9)
    # map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only
    # machine; load_state_dict copies the tensors into the freshly built
    # (CPU) model either way, so GPU users are unaffected.
    state_dict = torch.load(state_dict_path, map_location="cpu")
    model.load_state_dict(state_dict)
    # Inference-only model: disable dropout/batch-norm training behavior.
    model.eval()
    return model


# Registry mapping public model version tags to their builder functions.
# get_pretrained() looks versions up here, so a tag is valid iff it is a key.
VALID_MODELS = {
    "lm_v1": build_lm_1,
    "human_v1": build_human_1,
}


def get_state_dict(version="human_v1", verbose=True):
    """
    Download a pre-trained model if it does not already exist on the local device.

    :param version: Version of trained model to download [default: human_v1]
    :type version: str
    :param verbose: Print model download status on stdout [default: True]
    :type verbose: bool
    :return: Path to state dictionary for pre-trained language model
    :rtype: str
    """
    state_dict_basename = f"dscript_{version}.pt"
    # Models are cached next to this module.
    state_dict_basedir = os.path.dirname(os.path.realpath(__file__))
    state_dict_fullname = f"{state_dict_basedir}/{state_dict_basename}"
    state_dict_url = f"http://cb.csail.mit.edu/cb/dscript/data/models/{state_dict_basename}"
    if not os.path.exists(state_dict_fullname):
        try:
            import shutil
            import tempfile
            import urllib.request

            if verbose:
                print(f"Downloading model {version} from {state_dict_url}...")
            # Download into a temporary file and atomically rename on success,
            # so an interrupted transfer never leaves a truncated .pt file that
            # the os.path.exists check above would treat as a complete model.
            with urllib.request.urlopen(state_dict_url) as response:
                with tempfile.NamedTemporaryFile(
                    dir=state_dict_basedir, suffix=".pt.part", delete=False
                ) as tmp_file:
                    shutil.copyfileobj(response, tmp_file)
            os.replace(tmp_file.name, state_dict_fullname)
        except Exception as e:
            # Best-effort download: report and exit rather than raise, matching
            # the module's existing CLI-style error handling.
            print("Unable to download model - {}".format(e))
            sys.exit(1)
    return state_dict_fullname


def get_pretrained(version="human_v1"):
    """
    Get pre-trained model object.

    Currently Available Models
    ==========================

    See the `documentation <https://d-script.readthedocs.io/en/main/data.html#trained-models>`_ for most up-to-date list.

    - ``lm_v1`` - Language model from `Bepler & Berger <https://github.com/tbepler/protein-sequence-embedding-iclr2019>`_.
    - ``human_v1`` - Human trained model from D-SCRIPT manuscript.

    Default: ``human_v1``

    :param version: Version of pre-trained model to get
    :type version: str
    :return: Pre-trained model
    :rtype: dscript.models.*
    :raises ValueError: If ``version`` is not a key of ``VALID_MODELS``
    """
    if version not in VALID_MODELS:
        raise ValueError("Model {} does not exist".format(version))

    # Fetch (downloading if necessary) the weights, then build the model.
    state_dict_path = get_state_dict(version)
    return VALID_MODELS[version](state_dict_path)