File size: 283 Bytes
02443c1
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
import pytorch_lightning as pl
from apps.project_model2 import UNet

class Segmenter(pl.LightningModule):
    """Lightning module that wraps the project's ``UNet`` network.

    A single ``UNet`` (constructed with its default arguments) is held as
    the submodule ``self.model``. ``forward`` simply delegates to it.
    """

    def __init__(self):
        super().__init__()
        # Keep the attribute name ``model``: submodule attribute names become
        # state-dict key prefixes, so renaming it would break saved checkpoints.
        self.model = UNet()

    def forward(self, data):
        """Pass ``data`` through the wrapped UNet and return its output unchanged."""
        return self.model(data)
# Module-level Segmenter instance, built as a side effect of importing this module.
model = Segmenter()