Update esm_scripts/extract.py

esm_scripts/extract.py  CHANGED  +69 -0

@@ -131,6 +131,75 @@ def run(args):
                 )


+def run_demo(model_location, fasta_file, output_dir, include, nogpu,
+             repr_layers=(-1,), truncation_seq_length=1022, toks_per_batch=4096):
+    model, alphabet = pretrained.load_model_and_alphabet(model_location)
+    model.eval()
+    if isinstance(model, MSATransformer):
+        raise ValueError(
+            "This script currently does not handle models with MSA input (MSA Transformer)."
+        )
+    if torch.cuda.is_available() and not nogpu:
+        model = model.cuda()
+        print("Transferred model to GPU")
+
+    dataset = FastaBatchedDataset.from_file(fasta_file)
+    batches = dataset.get_batch_indices(toks_per_batch, extra_toks_per_seq=1)
+    data_loader = torch.utils.data.DataLoader(
+        dataset, collate_fn=alphabet.get_batch_converter(truncation_seq_length), batch_sampler=batches
+    )
+    print(f"Read {fasta_file} with {len(dataset)} sequences")
+
+    output_dir.mkdir(parents=True, exist_ok=True)
+    return_contacts = "contacts" in include
+
+    # Normalize negative layer indices (e.g. -1 for the final layer) to absolute ones.
+    assert all(-(model.num_layers + 1) <= i <= model.num_layers for i in repr_layers)
+    repr_layers = [(i + model.num_layers + 1) % (model.num_layers + 1) for i in repr_layers]
+
+    with torch.no_grad():
+        for batch_idx, (labels, strs, toks) in enumerate(data_loader):
+            print(
+                f"Processing {batch_idx + 1} of {len(batches)} batches ({toks.size(0)} sequences)"
+            )
+            if torch.cuda.is_available() and not nogpu:
+                toks = toks.to(device="cuda", non_blocking=True)
+
+            out = model(toks, repr_layers=repr_layers, return_contacts=return_contacts)
+
+            logits = out["logits"].to(device="cpu")
+            representations = {
+                layer: t.to(device="cpu") for layer, t in out["representations"].items()
+            }
+            if return_contacts:
+                contacts = out["contacts"].to(device="cpu")
+
+            for i, label in enumerate(labels):
+                result = {"label": label}
+                truncate_len = min(truncation_seq_length, len(strs[i]))
+                # Call clone on tensors to ensure tensors are not views into a larger representation
+                # See https://github.com/pytorch/pytorch/issues/1995
+                if "per_tok" in include:
+                    result["representations"] = {
+                        layer: t[i, 1 : truncate_len + 1].clone()
+                        for layer, t in representations.items()
+                    }
+                if "mean" in include:
+                    result["mean_representations"] = {
+                        layer: t[i, 1 : truncate_len + 1].mean(0).clone()
+                        for layer, t in representations.items()
+                    }
+                if "bos" in include:
+                    result["bos_representations"] = {
+                        layer: t[i, 0].clone() for layer, t in representations.items()
+                    }
+                if return_contacts:
+                    result["contacts"] = contacts[i, : truncate_len, : truncate_len].clone()
+
+                # Write this sequence's results to disk, one file per FASTA record.
+                torch.save(result, output_dir / f"{label}.pt")
+
+
 def main():
     parser = create_parser()
     args = parser.parse_args()
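
A note on the repr_layers handling above: the assert/modulo pair accepts negative layer indices the way Python sequence indexing does, over the num_layers + 1 available outputs (0 is the embedding output, num_layers is the final transformer layer). A quick check of the arithmetic, using a hypothetical 33-layer model:

    num_layers = 33  # illustrative only; e.g. the "t33" ESM checkpoints
    for i in (-1, 0, 33, -34):
        print(i, "->", (i + num_layers + 1) % (num_layers + 1))
    # -1 -> 33 (final layer), 0 -> 0 (embedding output), 33 -> 33, -34 -> 0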
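
For reference, a minimal sketch of how the new run_demo entry point might be driven from Python; the import path, model name, and file paths below are illustrative assumptions, not part of this change:

    from pathlib import Path

    from esm_scripts.extract import run_demo  # assumes the repo root is on sys.path

    run_demo(
        model_location="esm2_t33_650M_UR50D",  # any name pretrained.load_model_and_alphabet accepts
        fasta_file=Path("proteins.fasta"),     # placeholder input
        output_dir=Path("embeddings"),         # receives one <label>.pt file per sequence
        include=("mean", "per_tok", "contacts"),
        nogpu=False,
        repr_layers=(-1,),                     # final-layer representations only
    )

Note that output_dir must be a pathlib.Path (mkdir is called on it), and include can be any container that supports membership tests for the strings "per_tok", "mean", "bos", and "contacts".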