TriEightz commited on
Commit
e344930
1 Parent(s): f85ffa7

Upload model

Browse files
Files changed (3) hide show
  1. ConvNextBase.py +43 -0
  2. config.json +4 -0
  3. pytorch_model.bin +1 -1
ConvNextBase.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig, PreTrainedModel
2
+ from typing import List
3
+ from torchvision.models import convnext_base, ConvNeXt_Base_Weights
4
+
5
+
6
+ class ConvNextBaseConfig(PretrainedConfig):
7
+ model_type = "ConvNext"
8
+
9
+ def __init__(
10
+ self,
11
+ **kwargs,
12
+ ):
13
+ super().__init__(**kwargs)
14
+
15
+
16
+ class ConvNextBaseModel(PreTrainedModel):
17
+ config_class = ConvNextBaseConfig
18
+
19
+ def __init__(self, config):
20
+ super().__init__(config)
21
+ self.model = convnext_base()
22
+
23
+ def forward(self, tensor):
24
+ return self.model(tensor)
25
+
26
+
27
+ class ConvNextBaseModelForImageClassification(PreTrainedModel):
28
+ config_class = ConvNextBaseConfig
29
+
30
+ def __init__(self, config):
31
+ super().__init__(config)
32
+ self.model = convnext_base()
33
+
34
+ def forward(self, tensor, labels=None):
35
+ logits = self.model(tensor)
36
+ if labels is not None:
37
+ loss = torch.nn.cross_entropy(logits, labels)
38
+ return {"loss": loss, "logits": logits}
39
+ return {"logits": logits}
40
+
41
+ ConvNextBaseConfig.register_for_auto_class()
42
+ ConvNextBaseModel.register_for_auto_class("AutoModel")
43
+ ConvNextBaseModelForImageClassification.register_for_auto_class("AutoModelForImageClassification")
config.json CHANGED
@@ -2,6 +2,10 @@
2
  "architectures": [
3
  "ConvNextBaseModelForImageClassification"
4
  ],
 
 
 
 
5
  "model_type": "ConvNext",
6
  "torch_dtype": "float32",
7
  "transformers_version": "4.24.0"
 
2
  "architectures": [
3
  "ConvNextBaseModelForImageClassification"
4
  ],
5
+ "auto_map": {
6
+ "AutoConfig": "ConvNextBase.ConvNextBaseConfig",
7
+ "AutoModelForImageClassification": "ConvNextBase.ConvNextBaseModelForImageClassification"
8
+ },
9
  "model_type": "ConvNext",
10
  "torch_dtype": "float32",
11
  "transformers_version": "4.24.0"
pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:0df9805053983ed8a274bdbea1a3321d2f4a2c69deaf8fe7821badb70349281d
3
  size 354474029
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4c59b8953c27c63ced868f8a94fac3d90b0261c8e5a002078ec4cb202d19c85f
3
  size 354474029