Spaces:
Sleeping
Sleeping
vardaan123
commited on
Commit
•
e75805e
1
Parent(s):
93d2b0f
Update app.py
Browse files
app.py
CHANGED
@@ -12,6 +12,17 @@ from tqdm import tqdm
|
|
12 |
_DEFAULT_IMAGE_TENSOR_NORMALIZATION_MEAN = [0.485, 0.456, 0.406]
|
13 |
_DEFAULT_IMAGE_TENSOR_NORMALIZATION_STD = [0.229, 0.224, 0.225]
|
14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
# Load necessary data and initialize the model
|
16 |
entity2id = json.load(open('entity2id_subtree.json', 'r'))
|
17 |
id2entity = {v: k for k, v in entity2id.items()}
|
|
|
12 |
_DEFAULT_IMAGE_TENSOR_NORMALIZATION_MEAN = [0.485, 0.456, 0.406]
|
13 |
_DEFAULT_IMAGE_TENSOR_NORMALIZATION_STD = [0.229, 0.224, 0.225]
|
14 |
|
15 |
+
def generate_target_list(data, entity2id):
|
16 |
+
sub = data.loc[(data["datatype_h"] == "image") & (data["datatype_t"] == "id"), ['t']]
|
17 |
+
sub = list(sub['t'])
|
18 |
+
categories = []
|
19 |
+
for item in tqdm(sub):
|
20 |
+
if entity2id[str(int(float(item)))] not in categories:
|
21 |
+
categories.append(entity2id[str(int(float(item)))])
|
22 |
+
# print('categories = {}'.format(categories))
|
23 |
+
# print("No. of target categories = {}".format(len(categories)))
|
24 |
+
return torch.tensor(categories, dtype=torch.long).unsqueeze(-1)
|
25 |
+
|
26 |
# Load necessary data and initialize the model
|
27 |
entity2id = json.load(open('entity2id_subtree.json', 'r'))
|
28 |
id2entity = {v: k for k, v in entity2id.items()}
|