Commit 1a0324b by wenkai
Parent(s): cdf31f1

Update app.py

Files changed (1):
  app.py: +34, -20
app.py CHANGED
@@ -9,17 +9,20 @@ import spaces
 import gradio as gr
 from esm_scripts.extract import run_demo
 from esm import pretrained, FastaBatchedDataset
+
 # from transformers import EsmTokenizer, EsmModel
 
 
 # Load the model
-model = Blip2ProteinMistral(config=FAPMConfig(), esm_size='3b')
-model.load_checkpoint("model/checkpoint_mf2.pth")
-model.to('cuda')
+# model = Blip2ProteinMistral(config=FAPMConfig(), esm_size='3b')
+# model.load_checkpoint("model/checkpoint_mf2.pth")
+# model.to('cuda')
+
+# model_esm, alphabet = pretrained.load_model_and_alphabet('esm2_t36_3B_UR50D')
+# model_esm.to('cuda')
+# model_esm.eval()
+
 
-model_esm, alphabet = pretrained.load_model_and_alphabet('esm2_t36_3B_UR50D')
-model_esm.to('cuda')
-model_esm.eval()
 # tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t36_3B_UR50D")
 # model_esm = EsmModel.from_pretrained("facebook/esm2_t36_3B_UR50D")
 # model_esm.to('cuda')
@@ -32,22 +35,26 @@ def generate_caption(protein, prompt):
     # f.write('>{}\n'.format("protein_name"))
     # f.write('{}\n'.format(protein.strip()))
     # os.system("python esm_scripts/extract.py esm2_t36_3B_UR50D /home/user/app/example.fasta /home/user/app --repr_layers 36 --truncation_seq_length 1024 --include per_tok")
-    # esm_emb = run_demo(protein_name='protein_name', protein_seq=protein,
-    #                    model=model_esm, alphabet=alphabet,
+    # esm_emb = run_demo(protein_name='protein_name', protein_seq=protein,
+    #                    model=model_esm, alphabet=alphabet,
     #                    include='per_tok', repr_layers=[36], truncation_seq_length=1024)
-
-    protein_name='protein_name'
-    protein_seq=protein
-    include='per_tok'
-    repr_layers=[36]
-    truncation_seq_length=1024
-    toks_per_batch=4096
+
+    protein_name = 'protein_name'
+    protein_seq = protein
+    include = 'per_tok'
+    repr_layers = [36]
+    truncation_seq_length = 1024
+    toks_per_batch = 4096
     print("start")
     dataset = FastaBatchedDataset([protein_name], [protein_seq])
     print("dataset prepared")
     batches = dataset.get_batch_indices(toks_per_batch, extra_toks_per_seq=1)
     print("batches prepared")
-
+
+    model_esm, alphabet = pretrained.load_model_and_alphabet('esm2_t36_3B_UR50D')
+    model_esm.to('cuda')
+    model_esm.eval()
+
     data_loader = torch.utils.data.DataLoader(
         dataset, collate_fn=alphabet.get_batch_converter(truncation_seq_length), batch_sampler=batches
     )
@@ -78,12 +85,12 @@ def generate_caption(protein, prompt):
                 # See https://github.com/pytorch/pytorch/issues/1995
                 if "per_tok" in include:
                     result["representations"] = {
-                        layer: t[i, 1 : truncate_len + 1].clone()
+                        layer: t[i, 1: truncate_len + 1].clone()
                         for layer, t in representations.items()
                     }
                 if "mean" in include:
                     result["mean_representations"] = {
-                        layer: t[i, 1 : truncate_len + 1].mean(0).clone()
+                        layer: t[i, 1: truncate_len + 1].mean(0).clone()
                         for layer, t in representations.items()
                     }
                 if "bos" in include:
@@ -106,18 +113,25 @@ def generate_caption(protein, prompt):
                'image': torch.unsqueeze(esm_emb, dim=0),
                'text_input': ['none'],
                'prompt': [prompt]}
+
+    del model_esm
+
+    model = Blip2ProteinMistral(config=FAPMConfig(), esm_size='3b')
+    model.load_checkpoint("model/checkpoint_mf2.pth")
+    model.to('cuda')
     # Generate the output
-    prediction = model.generate(samples, length_penalty=0., num_beams=15, num_captions=10, temperature=1., repetition_penalty=1.0)
+    prediction = model.generate(samples, length_penalty=0., num_beams=15, num_captions=10, temperature=1.,
+                                repetition_penalty=1.0)
 
     return prediction
     # return "test"
 
+
 # Define the FAPM interface
 description = """Quick demonstration of the FAPM model for protein function prediction. Upload an protein sequence to generate a function description. Modify the Prompt to provide the taxonomy information.
 
 The model used in this app is available at [Hugging Face Model Hub](https://huggingface.co/wenkai/FAPM) and the source code can be found on [GitHub](https://github.com/xiangwenkai/FAPM/tree/main)."""
 
-
 iface = gr.Interface(
     fn=generate_caption,
     inputs=[gr.Textbox(type="text", label="Upload sequence"), gr.Textbox(type="text", label="Prompt")],
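Taken together, the second and third hunks inline the per-token embedding logic of esm_scripts/extract.py into generate_caption. The sketch below condenses that path for a single sequence; it follows the fair-esm calls visible in the diff, but embed_esm2 is a hypothetical wrapper name and layer 36 is simply the value the app uses.

import torch
from esm import pretrained, FastaBatchedDataset


def embed_esm2(protein_seq, repr_layer=36, truncation_seq_length=1024):
    # Load ESM-2 3B and move it to the GPU, as the commit now does on demand.
    model_esm, alphabet = pretrained.load_model_and_alphabet('esm2_t36_3B_UR50D')
    model_esm.to('cuda')
    model_esm.eval()

    # One-sequence dataset, batched the same way the diff does it.
    dataset = FastaBatchedDataset(['protein_name'], [protein_seq])
    batches = dataset.get_batch_indices(4096, extra_toks_per_seq=1)
    data_loader = torch.utils.data.DataLoader(
        dataset,
        collate_fn=alphabet.get_batch_converter(truncation_seq_length),
        batch_sampler=batches,
    )

    with torch.no_grad():
        for labels, strs, toks in data_loader:
            out = model_esm(toks.to('cuda'), repr_layers=[repr_layer], return_contacts=False)
            t = out['representations'][repr_layer]
            truncate_len = min(truncation_seq_length, len(strs[0]))
            # Index 0 is the BOS token, hence the 1 : truncate_len + 1 slice
            # that hunk three reformats.
            return t[0, 1:truncate_len + 1].clone()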
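The substantive change is when the two large models occupy the GPU: nothing is loaded at import time any more; generate_caption loads the 3B-parameter ESM-2 encoder, computes esm_emb, deletes the encoder, and only then loads Blip2ProteinMistral, presumably so the two never have to fit in the Space's GPU allocation at once. One hedge on the del model_esm step: in PyTorch, deleting a reference does not by itself return cached VRAM to the driver, so a stricter hand-off would also collect garbage and empty the allocator cache. A minimal sketch of that pattern; the gc/empty_cache calls are an addition, not part of the commit, and Blip2ProteinMistral and FAPMConfig are imported elsewhere in app.py, outside the hunks shown:

import gc
import torch

# model_esm is the encoder loaded earlier inside generate_caption().
del model_esm               # what the commit does: drop the last reference
gc.collect()                # suggested addition: make sure the object is collected
torch.cuda.empty_cache()    # suggested addition: hand cached blocks back to the driver

model = Blip2ProteinMistral(config=FAPMConfig(), esm_size='3b')
model.load_checkpoint("model/checkpoint_mf2.pth")
model.to('cuda')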
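Since the handler now does all of its own loading, it can also be smoke-tested without the UI. A hypothetical check, assuming the two-textbox signature shown above; the sequence is a placeholder and the prompt carries the taxonomy, per the app description:

if __name__ == '__main__':
    # Placeholder sequence; the prompt supplies taxonomy information.
    preds = generate_caption('MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ', 'Homo sapiens')
    print(preds)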