Orpheous1 commited on
Commit
9d9aad0
1 Parent(s): 56ec0e7
Files changed (4) hide show
  1. app.py +8 -1
  2. flagged/Input/0.png +0 -0
  3. flagged/Output/0.png +0 -0
  4. flagged/log.csv +2 -0
app.py CHANGED
@@ -12,18 +12,25 @@ import torch
12
 
13
  # Load Vision Transformer
14
  hf_model = "tanlq/vit-base-patch16-224-in21k-finetuned-cifar10"
 
15
  vit = ViTForImageClassification.from_pretrained(hf_model)
 
16
  vit.eval()
 
17
 
18
  # Load Feature Extractor
19
  feature_extractor = ViTFeatureExtractor.from_pretrained(hf_model, return_tensors="pt")
 
20
  feature_extractor = UnNest(feature_extractor)
 
21
 
22
  # Load Vision DiffMask
23
  diffmask = ImageInterpretationNet.load_from_checkpoint('checkpoints/diffmask.ckpt')
24
  diffmask.set_vision_transformer(vit)
25
  diffmask_imagenet = ImageInterpretationNet.load_from_checkpoint('checkpoints/diffmask_imagenet.ckpt')
26
- diffmask_imagenet.set_vision_transformer(vit)
 
 
27
 
28
  # Define mask plotting functions
29
  def draw_mask(image, mask):
 
12
 
13
  # Load Vision Transformer
14
  hf_model = "tanlq/vit-base-patch16-224-in21k-finetuned-cifar10"
15
+ hf_model_imagenet = "google/vit-base-patch16-224"
16
  vit = ViTForImageClassification.from_pretrained(hf_model)
17
+ vit_imagenet = ViTForImageClassification.from_pretrained(hf_model_imagenet)
18
  vit.eval()
19
+ vit_imagenet.eval()
20
 
21
  # Load Feature Extractor
22
  feature_extractor = ViTFeatureExtractor.from_pretrained(hf_model, return_tensors="pt")
23
+ feature_extractor_imagenet = ViTFeatureExtractor.from_pretrained(hf_model_imagenet, return_tensors="pt")
24
  feature_extractor = UnNest(feature_extractor)
25
+ feature_extractor_imagenet = UnNest(feature_extractor_imagenet)
26
 
27
  # Load Vision DiffMask
28
  diffmask = ImageInterpretationNet.load_from_checkpoint('checkpoints/diffmask.ckpt')
29
  diffmask.set_vision_transformer(vit)
30
  diffmask_imagenet = ImageInterpretationNet.load_from_checkpoint('checkpoints/diffmask_imagenet.ckpt')
31
+ diffmask_imagenet.set_vision_transformer(vit_imagenet)
32
+ diffmask.eval()
33
+ diffmask_imagenet.eval()
34
 
35
  # Define mask plotting functions
36
  def draw_mask(image, mask):
flagged/Input/0.png ADDED
flagged/Output/0.png ADDED
flagged/log.csv ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ 'Input','model_name','Output','Prediction','flag','username','timestamp'
2
+ 'Input/0.png','DiffMask-ImageNet','Output/0.png','airplane','','','2022-06-30 18:09:57.950187'