wenkai committed on
Commit 9cc264c
1 Parent(s): 5b42959

Update app.py

Files changed (1)
  1. app.py +7 -9
app.py CHANGED
@@ -15,6 +15,9 @@ model = Blip2ProteinMistral(config=FAPMConfig(), esm_size='3b')
 model.load_checkpoint("model/checkpoint_mf2.pth")
 model.to('cuda')
 
+# model_esm, alphabet = pretrained.load_model_and_alphabet('esm2_t36_3B_UR50D')
+# model_esm.to('cuda')
+# model_esm.eval()
 
 @spaces.GPU
 def generate_caption(protein, prompt):
@@ -37,19 +40,15 @@ def generate_caption(protein, prompt):
     print("dataset prepared")
     batches = dataset.get_batch_indices(toks_per_batch, extra_toks_per_seq=1)
     print("batches prepared")
-
-    model_esm, alphabet = pretrained.load_model_and_alphabet('esm2_t36_3B_UR50D')
-    model_esm.to('cuda')
-    model_esm.eval()
 
     data_loader = torch.utils.data.DataLoader(
-        dataset, collate_fn=alphabet.get_batch_converter(truncation_seq_length), batch_sampler=batches
+        dataset, collate_fn=model.alphabet.get_batch_converter(truncation_seq_length), batch_sampler=batches
     )
     print(f"Read sequences")
     return_contacts = "contacts" in include
 
-    assert all(-(model_esm.num_layers + 1) <= i <= model_esm.num_layers for i in repr_layers)
-    repr_layers = [(i + model_esm.num_layers + 1) % (model_esm.num_layers + 1) for i in repr_layers]
+    assert all(-(model.model_esm.num_layers + 1) <= i <= model.model_esm.num_layers for i in repr_layers)
+    repr_layers = [(i + model.model_esm.num_layers + 1) % (model.model_esm.num_layers + 1) for i in repr_layers]
 
     with torch.no_grad():
         for batch_idx, (labels, strs, toks) in enumerate(data_loader):
@@ -58,8 +57,7 @@ def generate_caption(protein, prompt):
             )
             if torch.cuda.is_available():
                 toks = toks.to(device="cuda", non_blocking=True)
-            out = model_esm(toks, repr_layers=repr_layers, return_contacts=return_contacts)
-            del model_esm
+            out = model.model_esm(toks, repr_layers=repr_layers, return_contacts=return_contacts)
             logits = out["logits"].to(device="cpu")
             representations = {
                 layer: t.to(device="cpu") for layer, t in out["representations"].items()
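In short, the change stops calling pretrained.load_model_and_alphabet('esm2_t36_3B_UR50D') inside generate_caption on every request and instead reuses the ESM-2 encoder and alphabet that the FAPM model already carries. A minimal sketch of that pattern, assuming (not shown in this diff) that Blip2ProteinMistral populates model_esm and alphabet when the checkpoint is loaded; the helper name embed_sequences is hypothetical:

    # Sketch only: attribute names (model.model_esm, model.alphabet) follow the diff,
    # their initialization inside Blip2ProteinMistral is assumed.
    import torch

    def embed_sequences(model, dataset, batches, truncation_seq_length, repr_layers, include=()):
        # Reuse the resident ESM-2 encoder instead of reloading the 3B-parameter
        # checkpoint on every call.
        data_loader = torch.utils.data.DataLoader(
            dataset,
            collate_fn=model.alphabet.get_batch_converter(truncation_seq_length),
            batch_sampler=batches,
        )
        return_contacts = "contacts" in include

        # Normalize negative layer indices against the shared encoder's depth.
        num_layers = model.model_esm.num_layers
        assert all(-(num_layers + 1) <= i <= num_layers for i in repr_layers)
        repr_layers = [(i + num_layers + 1) % (num_layers + 1) for i in repr_layers]

        results = []
        with torch.no_grad():
            for labels, strs, toks in data_loader:
                if torch.cuda.is_available():
                    toks = toks.to(device="cuda", non_blocking=True)
                # The encoder stays resident across batches; no per-batch `del model_esm`.
                out = model.model_esm(toks, repr_layers=repr_layers, return_contacts=return_contacts)
                results.append({layer: t.to("cpu") for layer, t in out["representations"].items()})
        return results

Keeping a single resident encoder and dropping the per-call load and `del model_esm` is what the +7/-9 summary above reflects.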