File size: 1,544 Bytes
2eafbc4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
from inference.core.env import AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, LAMBDA
from inference.core.models.classification_base import (
    ClassificationBaseOnnxRoboflowInferenceModel,
)


class VitClassification(ClassificationBaseOnnxRoboflowInferenceModel):
    """Classification model wrapper for Vision Transformer (ViT) ONNX weights.

    This subclass of ClassificationBaseOnnxRoboflowInferenceModel adds only
    two things on top of its base class: it records a ``multiclass`` flag at
    construction time and it chooses which ONNX weights file name to use.

    Attributes:
        multiclass: Value stored under the "MULTICLASS" key of
            ``self.environment``; ``False`` when the key is absent.
            NOTE(review): ``self.environment`` is presumably populated by the
            base-class initializer -- confirm against the base class.
    """

    def __init__(self, *args, **kwargs):
        """Build the model and capture the multiclass setting.

        Args:
            *args: Positional arguments passed through unchanged to the
                base-class initializer.
            **kwargs: Keyword arguments passed through unchanged to the
                base-class initializer.
        """
        super().__init__(*args, **kwargs)
        # The base initializer must run first: ``self.environment`` is read
        # here but never assigned by this class.
        env = self.environment
        self.multiclass = env.get("MULTICLASS", False)

    @property
    def weights_file(self) -> str:
        """Name of the ONNX weights file this model should load.

        "weights.onnx" is chosen only when AWS_ACCESS_KEY_ID,
        AWS_SECRET_ACCESS_KEY and LAMBDA are all truthy; if any one of the
        three is falsy the name is "best.onnx".

        Returns:
            str: Either "weights.onnx" or "best.onnx".
        """
        # Guard clause: any missing credential, or not running with LAMBDA
        # enabled, selects the default file name.
        if not AWS_ACCESS_KEY_ID or not AWS_SECRET_ACCESS_KEY or not LAMBDA:
            return "best.onnx"
        return "weights.onnx"