
CLIP-Based Product Defect Detection Model Card

๋ชจ๋ธ ์„ธ๋ถ€์‚ฌํ•ญ

๋ชจ๋ธ ์„ค๋ช…

This model detects product defects using a CLIP-based anomaly detection approach. A pre-trained CLIP model is fine-tuned to identify defects in product images, automating quality control and defect detection on production lines.

  • Developed by: 오석
  • Funded by: 4INLAB INC.
  • Shared by: zhou2023anomalyclip
  • Model type: CLIP based Anomaly Detection
  • Language(s): Python, PyTorch
  • License: Apache 2.0, MIT, GPL-3.0

Technical Limitations

  • The model requires sufficient and diverse training data for defect detection. If the training dataset is small or imbalanced, detection performance may degrade.
  • Real-time defect detection performance depends on the hardware, and detection accuracy may drop at high image resolutions.
  • When defects are very subtle or products are highly similar to one another, the model may fail to detect defects accurately.

Training Details

Hardware

  • CPU: Intel Core i9-13900K (24 Cores, 32 Threads)
  • RAM: 64GB DDR5
  • GPU: NVIDIA RTX 4090Ti 24GB
  • Storage: 1TB NVMe SSD + 2TB HDD
  • Operating System: Windows 11 Pro

Dataset Information

This model is trained on product image data consisting of normal (good) and defective (anomaly) samples collected for few-shot learning. Images are preprocessed and normalized (see the data processing techniques below) before being passed to the CLIP backbone.


  • Data sources: https://huggingface.co/datasets/quandao92/vision-inventory-prediction-data

  • Training size:

    • Round 1: Few-shot learning with anomaly (10 images), good (4 images)
    • Round 2: Few-shot learning with anomaly (10 images), good (10 images)
    • Round 3: Few-shot learning with anomaly (10 images), good (110 images)
  • Time-step: within 5 seconds

  • Data Processing Techniques (a preprocessing sketch follows this list):

    • normalization: standardizes image pixel values with a mean and standard deviation ('Normalize' from 'torchvision.transforms')
    • max_resize: resizes the image while preserving its aspect ratio and adds padding up to the maximum size (custom 'ResizeMaxSize' class)
    • random_resized_crop: randomly crops and resizes images during training to add variation ('RandomResizedCrop' from 'torchvision.transforms')
    • resize: resizes images to a fixed size matching the model input ('Resize' with BICUBIC interpolation)
    • center_crop: crops the central region of the image to the specified size ('CenterCrop')
    • to_tensor: converts the image to a PyTorch tensor ('ToTensor')
    • augmentation (optional): applies various random transforms for data augmentation, configurable via 'AugmentationCfg' (uses the 'timm' library if specified)
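The snippet below is a minimal sketch, assuming standard torchvision transforms and the OpenAI CLIP normalization constants, of how the evaluation-time preprocessing listed above could be assembled; the exact values and image size used by the training code may differ.

```python
from PIL import Image
from torchvision import transforms
from torchvision.transforms import InterpolationMode

IMAGE_SIZE = 518  # matches the --image_size default used later in this card

preprocess = transforms.Compose([
    transforms.Resize(IMAGE_SIZE, interpolation=InterpolationMode.BICUBIC),
    transforms.CenterCrop(IMAGE_SIZE),
    transforms.ToTensor(),
    # OpenAI CLIP normalization constants (assumed, not confirmed from this repo)
    transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073),
                         std=(0.26862954, 0.26130258, 0.27577711)),
])

# dummy image standing in for a product photo
tensor = preprocess(Image.new("RGB", (800, 600)))
print(tuple(tensor.shape))  # (3, 518, 518)
```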

AD-CLIP Model Architecture


  • model:
    • input_layer:
      • image_size: [640, 640, 3] # standard input image size
    • backbone:
      • name: CLIP (ViT-B-32) # uses the CLIP vision transformer as the backbone
      • filters: [32, 64, 128, 256, 512] # filter sizes at each backbone stage
    • neck:
      • name: Anomaly Detection Module # additional module for defect detection
      • method: Contrastive Learning # contrastive learning on CLIP features
    • head:
      • name: Anomaly Detection Head # final output layer for defect detection
      • outputs:
        • anomaly_score: 1 # anomaly score (normal vs. abnormal)
        • class_probabilities: N # per-class probabilities (defective or not)
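As a rough illustration (not the repository's actual module), the sketch below shows how the head described above could turn CLIP image embeddings into an anomaly score and class probabilities; the class name `AnomalyDetectionHead` and the 512-dimensional ViT-B/32 embedding size are assumptions.

```python
import torch
import torch.nn as nn

class AnomalyDetectionHead(nn.Module):
    def __init__(self, embed_dim: int = 512, num_classes: int = 2):
        super().__init__()
        self.anomaly_score = nn.Linear(embed_dim, 1)      # one scalar anomaly score per image
        self.classifier = nn.Linear(embed_dim, num_classes)  # per-class logits (good / defective)

    def forward(self, image_features: torch.Tensor):
        score = torch.sigmoid(self.anomaly_score(image_features)).squeeze(-1)
        class_probs = self.classifier(image_features).softmax(dim=-1)
        return score, class_probs

# usage with stand-in CLIP image embeddings (embed_dim = 512 for ViT-B/32)
features = torch.randn(4, 512)
head = AnomalyDetectionHead()
anomaly_score, class_probabilities = head(features)
```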

Optimizer and Loss Function

  • training:
    • optimizer:
      • name: AdamW # AdamW optimizer (with weight decay)
      • lr: 0.0001 # learning rate
    • loss:
      • classification_loss: 1.0 # classification loss (cross-entropy)
      • anomaly_loss: 1.0 # defect detection loss (anomaly detection term)
      • contrastive_loss: 1.0 # contrastive loss (similarity-based)
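The following is a minimal sketch of how the optimizer and the weighted loss terms above could be combined; the loss-term names and weights come from the config, while the concrete loss functions (cross-entropy, binary cross-entropy) are assumptions about the implementation.

```python
import torch
import torch.nn as nn

def build_optimizer(model: nn.Module, lr: float = 1e-4):
    # AdamW with its default weight decay, matching the config above
    return torch.optim.AdamW(model.parameters(), lr=lr)

def total_loss(cls_logits, cls_targets, anomaly_scores, anomaly_targets, contrastive_term,
               w_cls=1.0, w_anom=1.0, w_con=1.0):
    # weighted sum of the three loss terms listed in the config
    classification_loss = nn.functional.cross_entropy(cls_logits, cls_targets)
    anomaly_loss = nn.functional.binary_cross_entropy(anomaly_scores, anomaly_targets)
    return w_cls * classification_loss + w_anom * anomaly_loss + w_con * contrastive_term
```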

Metrics

  • metrics:
    • Precision # precision
    • Recall # recall
    • mAP # mean average precision
    • F1-Score # F1 score (balanced measure of precision and recall)
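For reference, the listed metrics can be computed with scikit-learn (already in the requirements below); the labels and scores here are made-up examples, not results from this model.

```python
from sklearn.metrics import precision_score, recall_score, f1_score, average_precision_score

y_true  = [0, 1, 1, 0, 1]            # 1 = defective, 0 = good
y_pred  = [0, 1, 0, 0, 1]            # thresholded anomaly predictions
y_score = [0.1, 0.9, 0.4, 0.2, 0.8]  # raw anomaly scores

print("Precision:", precision_score(y_true, y_pred))
print("Recall   :", recall_score(y_true, y_pred))
print("F1-Score :", f1_score(y_true, y_pred))
print("AP (per-class mAP component):", average_precision_score(y_true, y_score))
```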

Training Parameters

Hyperparameter settings

  • Learning Rate: 0.001
  • Batch Size: 8
  • Epochs: 200

Pre-trained CLIP model

Model Download
ViT-B/32 download
ViT-B/16 download
ViT-L/14 download
ViT-L/14@336px download
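A minimal sketch of loading one of the backbones listed above, assuming the OpenAI 'clip' Python package; the training code in this repository may load the weights through its own loader instead.

```python
import torch
import clip  # assumption: pip install git+https://github.com/openai/CLIP.git

device = "cuda" if torch.cuda.is_available() else "cpu"
# also accepts "ViT-B/16", "ViT-L/14", and "ViT-L/14@336px"
model, preprocess = clip.load("ViT-B/32", device=device)
```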

Evaluation Parameters

  • F1-score: 95% or higher.

Training Performance and Test Results

  • Training performance test results and graph: (figure)

  • Training results table: (figure)

  • Test results: (figures)

    Anomaly Product

    Normal Product

Installation and Execution Guidelines

To run this model, Python and the following libraries are required:

  • ftfy==6.2.0: library for fixing text normalization and encoding issues.
  • matplotlib==3.9.0: library for data visualization and plotting.
  • numpy==1.24.3: core library for numerical computation.
  • opencv_python==4.9.0.80: library for image and video processing.
  • pandas==2.2.2: library for data analysis and manipulation.
  • Pillow==10.3.0: library for image file handling and conversion.
  • PyQt5==5.15.10: framework for GUI application development.
  • PyQt5_sip==12.13.0: library providing the interface between PyQt5 and Python.
  • regex==2024.5.15: library for regular expression handling.
  • scikit_learn==1.2.2: library for machine learning and data analysis.
  • scipy==1.9.1: library for scientific and technical computing.
  • setuptools==59.5.0: library for packaging and installing Python distributions.
  • scikit-image: library for image processing and analysis.
  • tabulate==0.9.0: library for printing data in table form.
  • thop==0.1.1.post2209072238: tool for counting FLOPs of PyTorch models.
  • timm==0.6.13: library providing a wide range of state-of-the-art image classification models.
  • torch==2.0.0: PyTorch deep learning framework.
  • torchvision==0.15.1: PyTorch extension library for computer vision tasks.
  • tqdm==4.65.0: library for displaying progress bars.
  • pyautogui: library for GUI automation.

๋ชจ๋ธ ์‹คํ–‰ ๋‹จ๊ณ„:

✅ Prompt generating

  training_lib/prompt_ensemble.py

๐Ÿ‘ Prompts Built in the Code

  1. Normal Prompt: '["{ }"]'
    → Normal Prompt Example: "object"
  2. Anomaly Prompt: '["damaged { }"]'
    → Anomaly Prompt Example: "damaged object"

๐Ÿ‘ Construction Process

  1. 'prompts_pos' (Normal): combines the class name with the normal template
  2. 'prompts_neg' (Anomaly): combines the class name with the anomaly template (see the sketch below)
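A hedged illustration of the construction process above; the actual templates live in training_lib/prompt_ensemble.py and may differ in detail.

```python
# the two templates documented above; "{}" is filled with the class name
normal_templates  = ["{}"]
anomaly_templates = ["damaged {}"]

class_name = "object"
prompts_pos = [t.format(class_name) for t in normal_templates]   # ["object"]
prompts_neg = [t.format(class_name) for t in anomaly_templates]  # ["damaged object"]
print(prompts_pos, prompts_neg)
```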

✅ Initial setting for training

  • Define the paths for the training dataset and for saving model checkpoints
parser.add_argument("--train_data_path", type=str, default="./data/", help="train dataset path")
parser.add_argument("--dataset", type=str, default='smoke_cloud', help="train dataset name")
parser.add_argument("--save_path", type=str, default='./checkpoint/', help='path to save results')

✅ Hyperparameter settings

  • Set the depth parameter: the depth of the embedding learned during prompt training. This affects the model's ability to learn complex features from the data.
parser.add_argument("--depth", type=int, default=9, help="prompt embedding depth")
  • Define the size of the input images used for training (pixels)
parser.add_argument("--image_size", type=int, default=518, help="image size")
  • Set the main training parameters
parser.add_argument("--epoch", type=int, default=500, help="epochs")
parser.add_argument("--learning_rate", type=float, default=0.0001, help="learning rate")
parser.add_argument("--batch_size", type=int, default=8, help="batch size")
  • Size/depth parameter for the DPAM (Deep Prompt Attention Mechanism)
parser.add_argument("--dpam", type=int, default=20, help="dpam size")

1. ViT-B/32 and ViT-B/16: --dpam should be around 10-13
2. ViT-L/14 and ViT-L/14@336px: --dpam should be around 20-24
→ DPAM is used to refine and enhance specific layers of a model, particularly in Vision Transformers (ViT).
→ It helps the model focus on important features within each layer through an attention mechanism.
→ Layers: DPAM is applied across multiple layers, allowing deeper and more detailed feature extraction.
→ The number of layers DPAM influences is adjustable via --dpam, controlling how much of the model is fine-tuned.
→ To refine the entire model, set --dpam to the number of layers in the model (e.g., 12 for ViT-B and 24 for ViT-L).
→ To focus only on the final layers (where the model usually learns the most complex features), choose fewer DPAM layers (see the layer-selection sketch below).
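The layer-selection sketch below only illustrates the counting idea behind --dpam, i.e. fine-tuning the last N transformer blocks of a ViT backbone; the actual DPAM module works on the attention inside those layers, and the timm model name used here is an assumption for illustration.

```python
import timm
import torch

dpam = 10                                        # e.g. 10-13 for ViT-B backbones
model = timm.create_model("vit_base_patch32_224", pretrained=False)

for p in model.parameters():                     # freeze everything first
    p.requires_grad = False
for block in model.blocks[-dpam:]:               # unfreeze only the last `dpam` blocks
    for p in block.parameters():
        p.requires_grad = True

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"trainable parameters in the last {dpam} blocks: {trainable:,}")
```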

✅ Test process

๐Ÿ‘ Load pre-trained and Fine tuned (Checkpoints) models

  1. Pre-trained models (./pre-trained model/):
→ Contains the pre-trained models (ViT-B, ViT-L, ...)
→ Used as the starting point for training the CLIP model
→ A pre-trained model helps speed up and improve training by leveraging previously learned features
  2. Fine-tuned models (./checkpoint/):
→ "epoch_N.pth" files in this folder store the model's states during the fine-tuning process
→ Each ".pth" file represents a version of the model fine-tuned from the pre-trained model
→ These checkpoints can be used to resume fine-tuning, evaluate the model at different stages, or select the best-performing version (see the loading sketch below)
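Finally, a minimal, hedged example of loading a fine-tuned checkpoint from ./checkpoint/; the epoch number and the checkpoint's internal layout are assumptions, so adjust them to match what the training script actually saves.

```python
import torch

checkpoint_path = "./checkpoint/epoch_100.pth"   # hypothetical epoch number
state = torch.load(checkpoint_path, map_location="cpu")

# some training scripts save {"model": state_dict, "epoch": ...}; others save
# the state_dict directly, so handle both layouts
state_dict = state.get("model", state) if isinstance(state, dict) else state

# `model` is the fine-tuned CLIP-based detector built earlier in the test script
# model.load_state_dict(state_dict, strict=False)
```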