colbyford commited on
Commit
4a7c05c
1 Parent(s): 93de69d

ZeroGPU test 3

Browse files
Files changed (1) hide show
  1. app.py +8 -13
app.py CHANGED
@@ -9,24 +9,19 @@ from evodiff.generate import generate_oaardm, generate_d3pm
9
  from evodiff.generate_msa import generate_query_oadm_msa_simple
10
  from evodiff.conditional_generation import inpaint_simple, generate_scaffold
11
 
12
- @spaces.GPU
13
- def get_device():
14
- if torch.cuda.is_available():
15
- return "cuda"
16
- else:
17
- return "cpu"
18
-
19
- @spaces.GPU
20
  def make_uncond_seq(seq_len, model_type):
21
  if model_type == "EvoDiff-Seq-OADM 38M":
22
  checkpoint = OA_DM_38M()
23
  model, collater, tokenizer, scheme = checkpoint
24
- tokeinzed_sample, generated_sequence = generate_oaardm(model, tokenizer, int(seq_len), batch_size=1, device=get_device())
25
 
26
  if model_type == "EvoDiff-D3PM-Uniform 38M":
27
  checkpoint = D3PM_UNIFORM_38M(return_all=True)
28
  model, collater, tokenizer, scheme, timestep, Q_bar, Q = checkpoint
29
- tokeinzed_sample, generated_sequence = generate_d3pm(model, tokenizer, Q, Q_bar, timestep, int(seq_len), batch_size=1, device=get_device())
30
 
31
  return generated_sequence
32
 
@@ -35,7 +30,7 @@ def make_cond_seq(seq_len, msa_file, n_sequences, model_type):
35
  checkpoint = MSA_OA_DM_MAXSUB()
36
  model, collater, tokenizer, scheme = checkpoint
37
  print(f"MSA File Path: {msa_file.name}")
38
- tokeinzed_sample, generated_sequence = generate_query_oadm_msa_simple(msa_file.name, model, tokenizer, int(n_sequences), seq_length=int(seq_len), device=get_device(), selection_type='random')
39
 
40
  return generated_sequence
41
 
@@ -43,7 +38,7 @@ def make_inpainted_idrs(sequence, start_idx, end_idx, model_type):
43
  if model_type == "EvoDiff-Seq":
44
  checkpoint = OA_DM_38M()
45
  model, collater, tokenizer, scheme = checkpoint
46
- sample, entire_sequence, generated_idr = inpaint_simple(model, sequence, int(start_idx), int(end_idx), tokenizer=tokenizer, device=get_device())
47
 
48
  generated_idr_output = {
49
  "original_sequence": sequence,
@@ -63,7 +58,7 @@ def make_inpainted_idrs(sequence, start_idx, end_idx, model_type):
63
  # # print("Folders in User Cache Directory:", os.listdir("/home/user/.cache"))
64
  # start_idx = list(map(int, start_idx.strip('][').split(',')))
65
  # end_idx = list(map(int, end_idx.strip('][').split(',')))
66
- # generated_sequence, new_start_idx, new_end_idx = generate_scaffold(model, pdb_code, start_idx, end_idx, scaffold_length, data_top_dir, tokenizer, device=get_device())
67
 
68
  # generated_scaffold_output = {
69
  # "generated_sequence": generated_sequence,
 
9
  from evodiff.generate_msa import generate_query_oadm_msa_simple
10
  from evodiff.conditional_generation import inpaint_simple, generate_scaffold
11
 
12
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
13
+
14
+ @spaces.GPU()
 
 
 
 
 
15
  def make_uncond_seq(seq_len, model_type):
16
  if model_type == "EvoDiff-Seq-OADM 38M":
17
  checkpoint = OA_DM_38M()
18
  model, collater, tokenizer, scheme = checkpoint
19
+ tokeinzed_sample, generated_sequence = generate_oaardm(model, tokenizer, int(seq_len), batch_size=1, device=device)
20
 
21
  if model_type == "EvoDiff-D3PM-Uniform 38M":
22
  checkpoint = D3PM_UNIFORM_38M(return_all=True)
23
  model, collater, tokenizer, scheme, timestep, Q_bar, Q = checkpoint
24
+ tokeinzed_sample, generated_sequence = generate_d3pm(model, tokenizer, Q, Q_bar, timestep, int(seq_len), batch_size=1, device=device)
25
 
26
  return generated_sequence
27
 
 
30
  checkpoint = MSA_OA_DM_MAXSUB()
31
  model, collater, tokenizer, scheme = checkpoint
32
  print(f"MSA File Path: {msa_file.name}")
33
+ tokeinzed_sample, generated_sequence = generate_query_oadm_msa_simple(msa_file.name, model, tokenizer, int(n_sequences), seq_length=int(seq_len), device=device, selection_type='random')
34
 
35
  return generated_sequence
36
 
 
38
  if model_type == "EvoDiff-Seq":
39
  checkpoint = OA_DM_38M()
40
  model, collater, tokenizer, scheme = checkpoint
41
+ sample, entire_sequence, generated_idr = inpaint_simple(model, sequence, int(start_idx), int(end_idx), tokenizer=tokenizer, device=device)
42
 
43
  generated_idr_output = {
44
  "original_sequence": sequence,
 
58
  # # print("Folders in User Cache Directory:", os.listdir("/home/user/.cache"))
59
  # start_idx = list(map(int, start_idx.strip('][').split(',')))
60
  # end_idx = list(map(int, end_idx.strip('][').split(',')))
61
+ # generated_sequence, new_start_idx, new_end_idx = generate_scaffold(model, pdb_code, start_idx, end_idx, scaffold_length, data_top_dir, tokenizer, device=device)
62
 
63
  # generated_scaffold_output = {
64
  # "generated_sequence": generated_sequence,