fffiloni commited on
Commit
5a972ab
1 Parent(s): 83b6ef3

Create misc.py

Browse files
Files changed (1) hide show
  1. utils/misc.py +122 -0
utils/misc.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import numpy as np
3
+
4
+ def get_prompt_templates():
5
+ prompt_templates = [
6
+ '{}.',
7
+ 'a photo of a {}.',
8
+ 'a bad photo of a {}.',
9
+ 'a photo of many {}.',
10
+ 'a sculpture of a {}.',
11
+ 'a photo of the hard to see {}.',
12
+ 'a low resolution photo of the {}.',
13
+ 'a rendering of a {}.',
14
+ 'graffiti of a {}.',
15
+ 'a bad photo of the {}.',
16
+ 'a cropped photo of the {}.',
17
+ 'a tattoo of a {}.',
18
+ 'the embroidered {}.',
19
+ 'a photo of a hard to see {}.',
20
+ 'a bright photo of a {}.',
21
+ 'a photo of a clean {}.',
22
+ 'a photo of a dirty {}.',
23
+ 'a dark photo of the {}.',
24
+ 'a drawing of a {}.',
25
+ 'a photo of my {}.',
26
+ 'the plastic {}.',
27
+ 'a photo of the cool {}.',
28
+ 'a close-up photo of a {}.',
29
+ 'a black and white photo of the {}.',
30
+ 'a painting of the {}.',
31
+ 'a painting of a {}.',
32
+ 'a pixelated photo of the {}.',
33
+ 'a sculpture of the {}.',
34
+ 'a bright photo of the {}.',
35
+ 'a cropped photo of a {}.',
36
+ 'a plastic {}.',
37
+ 'a photo of the dirty {}.',
38
+ 'a jpeg corrupted photo of a {}.',
39
+ 'a blurry photo of the {}.',
40
+ 'a photo of the {}.',
41
+ 'a good photo of the {}.',
42
+ 'a rendering of the {}.',
43
+ 'a {} in a video game.',
44
+ 'a photo of one {}.',
45
+ 'a doodle of a {}.',
46
+ 'a close-up photo of the {}.',
47
+ 'the origami {}.',
48
+ 'the {} in a video game.',
49
+ 'a sketch of a {}.',
50
+ 'a doodle of the {}.',
51
+ 'a origami {}.',
52
+ 'a low resolution photo of a {}.',
53
+ 'the toy {}.',
54
+ 'a rendition of the {}.',
55
+ 'a photo of the clean {}.',
56
+ 'a photo of a large {}.',
57
+ 'a rendition of a {}.',
58
+ 'a photo of a nice {}.',
59
+ 'a photo of a weird {}.',
60
+ 'a blurry photo of a {}.',
61
+ 'a cartoon {}.',
62
+ 'art of a {}.',
63
+ 'a sketch of the {}.',
64
+ 'a embroidered {}.',
65
+ 'a pixelated photo of a {}.',
66
+ 'itap of the {}.',
67
+ 'a jpeg corrupted photo of the {}.',
68
+ 'a good photo of a {}.',
69
+ 'a plushie {}.',
70
+ 'a photo of the nice {}.',
71
+ 'a photo of the small {}.',
72
+ 'a photo of the weird {}.',
73
+ 'the cartoon {}.',
74
+ 'art of the {}.',
75
+ 'a drawing of the {}.',
76
+ 'a photo of the large {}.',
77
+ 'a black and white photo of a {}.',
78
+ 'the plushie {}.',
79
+ 'a dark photo of a {}.',
80
+ 'itap of a {}.',
81
+ 'graffiti of the {}.',
82
+ 'a toy {}.',
83
+ 'itap of my {}.',
84
+ 'a photo of a cool {}.',
85
+ 'a photo of a small {}.',
86
+ 'a tattoo of the {}.',
87
+ ]
88
+ return prompt_templates
89
+
90
+
91
+ def prompt_engineering(classnames, topk=1, suffix='.'):
92
+ prompt_templates = get_prompt_templates()
93
+ temp_idx = np.random.randint(min(len(prompt_templates), topk))
94
+
95
+ if isinstance(classnames, list):
96
+ classname = random.choice(classnames)
97
+ else:
98
+ classname = classnames
99
+
100
+ return prompt_templates[temp_idx].replace('.', suffix).format(classname.replace(',', '').replace('+', ' '))
101
+
102
+ class AverageMeter(object):
103
+ """Computes and stores the average and current value."""
104
+ def __init__(self):
105
+ self.reset()
106
+
107
+ def reset(self):
108
+ self.val = 0
109
+ self.avg = 0
110
+ self.sum = 0
111
+ self.count = 0
112
+
113
+ def update(self, val, n=1, decay=0):
114
+ self.val = val
115
+ if decay:
116
+ alpha = math.exp(-n / decay) # exponential decay over 100 updates
117
+ self.sum = alpha * self.sum + (1 - alpha) * val * n
118
+ self.count = alpha * self.count + (1 - alpha) * n
119
+ else:
120
+ self.sum += val * n
121
+ self.count += n
122
+ self.avg = self.sum / self.count