Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
@@ -14,7 +14,8 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
14 |
image_size = 384
|
15 |
|
16 |
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
17 |
-
std=[0.229, 0.224, 0.225
|
|
|
18 |
transform = transforms.Compose([transforms.Resize((image_size, image_size)),transforms.ToTensor(),normalize])
|
19 |
|
20 |
#######Tag2Text Model
|
@@ -41,7 +42,8 @@ def inference(raw_image, model_n , input_tag):
|
|
41 |
image = transform(raw_image).unsqueeze(0).to(device)
|
42 |
if model_n == 'Recognize Anything Model':
|
43 |
model = model_ram
|
44 |
-
|
|
|
45 |
return tags[0],tags_chinese[0], 'none'
|
46 |
else:
|
47 |
model = model_tag2text
|
|
|
14 |
image_size = 384
|
15 |
|
16 |
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
17 |
+
std=[0.229, 0.224, 0.225
|
18 |
+
])
|
19 |
transform = transforms.Compose([transforms.Resize((image_size, image_size)),transforms.ToTensor(),normalize])
|
20 |
|
21 |
#######Tag2Text Model
|
|
|
42 |
image = transform(raw_image).unsqueeze(0).to(device)
|
43 |
if model_n == 'Recognize Anything Model':
|
44 |
model = model_ram
|
45 |
+
with torch.no_grad():
|
46 |
+
tags, tags_chinese = model.generate_tag(image)
|
47 |
return tags[0],tags_chinese[0], 'none'
|
48 |
else:
|
49 |
model = model_tag2text
|