File size: 413 Bytes
be053b4
 
 
 
 
780954b
be053b4
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
from abc import ABC, abstractmethod


class AbstractLLMModel(ABC):
    """Abstract base class for LLM model wrappers.

    Records the model identifier, the device setting and the cache directory
    on the instance, and requires concrete subclasses to implement
    :meth:`generate`. It cannot be instantiated directly because
    ``generate`` is abstract.
    """

    def __init__(
        self,
        model_id: str,
        device: str = "auto",
        cache_dir: str = "cache",
        **kwargs,
    ):
        """Announce the model on stdout and store the configuration.

        No loading happens in this base constructor; it only prints a
        message and records its arguments as attributes.

        Args:
            model_id: Identifier of the model; stored as ``self.model_id``.
            device: Device setting; stored unchanged as ``self.device``.
                NOTE(review): how ``"auto"`` is interpreted is not defined
                here -- confirm against the subclasses.
            cache_dir: Cache directory; stored unchanged as
                ``self.cache_dir``.
            **kwargs: Accepted and ignored by this constructor.
        """
        print(f"Loading LLM model {model_id}...")
        self.model_id, self.device, self.cache_dir = model_id, device, cache_dir

    @abstractmethod
    def generate(self, prompt: str, **kwargs) -> str:
        """Produce a string for ``prompt``; subclasses must override this.

        Args:
            prompt: Input text.
            **kwargs: Extra options; not used by this base declaration.

        Returns:
            A string, per the annotation. This base implementation has no
            body and returns ``None`` if called explicitly.
        """