File size: 466 Bytes
7dbe662
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import torch
from torch import nn
from typing import Union, Dict
from dataclasses import dataclass

@dataclass
class Output:
    """Placeholder container for pipeline results.

    No fields are declared yet; ``@dataclass`` (whose ``repr`` option is on by
    default) generates ``__init__``, ``__repr__`` and ``__eq__`` for it.
    """

class Pipeline(nn.Module):
    """Skeleton ``nn.Module`` for a model pipeline.

    The constructor only records the arguments it receives; checkpoint loading
    and the forward pass are stubs that return ``None``.
    """

    def __init__(self, *args, **kwargs):
        """Initialise the module and keep the raw constructor arguments.

        Args:
            *args: Positional arguments, stored unchanged as ``self.args``.
            **kwargs: Keyword arguments, stored unchanged as ``self.kwargs``.
        """
        # Run nn.Module's initialiser first so its internal state exists
        # before any attributes are set on the instance.
        super().__init__()
        self.args, self.kwargs = args, kwargs

    @classmethod
    def from_pretrained(cls, ckpt_path, device='cuda', *args, **kwargs):
        """Alternate-constructor stub for building a pipeline from a checkpoint.

        Currently every argument is ignored and ``None`` is returned.
        Because ``device`` precedes ``*args``, a second positional argument
        binds to ``device`` rather than landing in ``args``.

        NOTE(review): presumably intended to construct ``cls`` from
        ``ckpt_path`` on ``device`` -- confirm and implement; callers get
        ``None`` today.
        """
        return None

    def forward(self, *args, **kwargs):
        """Forward-pass stub: accepts any arguments and returns ``None``."""
        return None