vardaan123 commited on
Commit
e75805e
1 Parent(s): 93d2b0f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -0
app.py CHANGED
@@ -12,6 +12,17 @@ from tqdm import tqdm
12
  _DEFAULT_IMAGE_TENSOR_NORMALIZATION_MEAN = [0.485, 0.456, 0.406]
13
  _DEFAULT_IMAGE_TENSOR_NORMALIZATION_STD = [0.229, 0.224, 0.225]
14
 
 
 
 
 
 
 
 
 
 
 
 
15
  # Load necessary data and initialize the model
16
  entity2id = json.load(open('entity2id_subtree.json', 'r'))
17
  id2entity = {v: k for k, v in entity2id.items()}
 
12
  _DEFAULT_IMAGE_TENSOR_NORMALIZATION_MEAN = [0.485, 0.456, 0.406]
13
  _DEFAULT_IMAGE_TENSOR_NORMALIZATION_STD = [0.229, 0.224, 0.225]
14
 
15
+ def generate_target_list(data, entity2id):
16
+ sub = data.loc[(data["datatype_h"] == "image") & (data["datatype_t"] == "id"), ['t']]
17
+ sub = list(sub['t'])
18
+ categories = []
19
+ for item in tqdm(sub):
20
+ if entity2id[str(int(float(item)))] not in categories:
21
+ categories.append(entity2id[str(int(float(item)))])
22
+ # print('categories = {}'.format(categories))
23
+ # print("No. of target categories = {}".format(len(categories)))
24
+ return torch.tensor(categories, dtype=torch.long).unsqueeze(-1)
25
+
26
  # Load necessary data and initialize the model
27
  entity2id = json.load(open('entity2id_subtree.json', 'r'))
28
  id2entity = {v: k for k, v in entity2id.items()}