ljw20180420 commited on
Commit
0f6c5c5
1 Parent(s): 5893a5b

Upload pipeline.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. pipeline.py +30 -0
pipeline.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from diffusers import DiffusionPipeline
3
+ import torch.nn.functional as F
4
+
5
+ class LindelPipeline(DiffusionPipeline):
6
+ def __init__(self, indel_model, ins_model, del_model):
7
+ super().__init__()
8
+
9
+ self.register_modules(indel_model=indel_model, ins_model=ins_model, del_model=del_model)
10
+ Lindel_dlen = int(round((-7 + (49 + 4 * (8 + 2 * self.del_model.linear.weight.shape[0])) ** 0.5) / 2))
11
+ self.dstarts, self.dends = [], []
12
+ for dlen in range(Lindel_dlen - 1, 0, -1):
13
+ for dstart in range(-dlen - 1, 3):
14
+ self.dstarts.append(dstart)
15
+ self.dends.append(dstart + dlen)
16
+
17
+ @torch.no_grad()
18
+ def __call__(self, batch):
19
+ indel_proba = F.softmax(self.indel_model(batch["input_indel"].to(self.indel_model.device))["logit"], dim=1)
20
+ ins_base_proba = F.softmax(self.ins_model(batch["input_ins"].to(self.ins_model.device))["logit"], dim=1)
21
+ del_pos_proba = F.softmax(self.del_model(batch["input_del"].to(self.del_model.device))["logit"], dim=1)
22
+ return {
23
+ "del_proba": indel_proba[:, 0],
24
+ "ins_proba": indel_proba[:, 1],
25
+ "ins_base": ["A", "C", "G", "T", "AA", "AC", "AG", "AT", "CA", "CC", "CG", "CT", "GA", "GC", "GG", "GT", "TA", "TC", "TG", "TT", ">2"],
26
+ "ins_base_proba": ins_base_proba,
27
+ "dstart": self.dstarts,
28
+ "dend": self.dends,
29
+ "del_pos_proba": del_pos_proba
30
+ }