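# Hydra config for AnomalyCLIPModule: optimizer, LR scheduler, the AnomalyCLIP
# network (CLIP ViT backbone, learnable prompts, adapter, fusion) and its losses.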
_target_: src.models.anomaly_clip_module.AnomalyCLIPModule

optimizer:
  _target_: torch.optim.AdamW
  _partial_: true
  lr: 0.001
  weight_decay: 0.2

scheduler:
  _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
  _partial_: true
  mode: min
  factor: 0.1
  patience: 5
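# _partial_: true makes Hydra return functools.partial objects instead of instances;
# the training module presumably completes them later by passing in the model
# parameters (optimizer) and the optimizer instance (scheduler).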

net:
  _target_: src.models.components.anomaly_clip.AnomalyCLIP
  arch: ViT-L/14@336px
  image_size: 336
  class_names: ["object"]
  # class_names: ${prompt.class_names}
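  # "object" presumably serves as a class-agnostic placeholder; the commented line
  # above would instead interpolate class names from the prompt config group.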
  temperature: 0.07                     # softmax temperature
  prompt_length: 12                     # length of learnable prompts
  context_length: 77                    # default 77 for OpenAI CLIP
  truncate: false
  feature_map_idx: [5, 11, 17, 23]      # indices of ViT residual attention blocks to tap; alternatives: [0, 12, 23], [6, 12, 18]
  share_weight: false                   # whether the adapter shares weights for different feature maps
  # state_template: ${prompt.state_template}
  state_template: 
    normal: ["{}"]
    anomaly: ["damaged {}"]
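    # the templates are format strings filled with each class name,
    # e.g. "object" and "damaged object" for the default class list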
  tokenizer:
    _target_: src.models.components.clip.simple_tokenizer.SimpleTokenizer
  adapter:
    # _target_: torch.nn.Linear
    # in_features: 1024                   # CLIP ViT feature dim, default 1024 for OpenAI CLIP
    # out_features: 1024
    # bias: false
    _target_: src.models.components.adapter.Linear
    in_features: 1024                     # CLIP ViT feature dim, default 1024 for OpenAI CLIP
    out_features: 1024
    hidden_features: null                 # null behaves like plain nn.Linear (no hidden layer)
    dropout_prob: 0.0
    bias: false
  fusion:
    _target_: src.models.components.cross_modal.DotProductFusion
  embedding_dim: null                     # CLIP fusion feature dim; only effective for learnable fusion modules

loss:
  cross_entropy:
    _target_: torch.nn.CrossEntropyLoss
  focal:
    _target_: src.models.components.loss.FocalLoss
  dice:
    _target_: src.models.components.loss.BinaryDiceLoss
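  # cross_entropy presumably scores image-level predictions, while focal and dice
  # target pixel-level anomaly maps; how they are weighted is defined in the module.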

k_shot: false

filter: true

enable_validation: false

compile: false