Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
import os
|
2 |
-
|
3 |
-
|
|
|
4 |
import clip
|
5 |
import os
|
6 |
from torch import nn
|
@@ -234,9 +235,9 @@ def inference(img,model_name):
|
|
234 |
model = ClipCaptionModel(prefix_length)
|
235 |
|
236 |
if model_name == "COCO":
|
237 |
-
model_path =
|
238 |
else:
|
239 |
-
model_path =
|
240 |
model.load_state_dict(torch.load(model_path, map_location=CPU))
|
241 |
model = model.eval()
|
242 |
device = CUDA(0) if is_gpu else "cpu"
|
|
|
1 |
import os
|
2 |
+
from huggingface_hub import hf_hub_download
|
3 |
+
conceptual_weight = hf_hub_download(repo_id="akhaliq/CLIP-prefix-captioning-conceptual-weights", filename="conceptual_weights.pt")
|
4 |
+
coco_weight = hf_hub_download(repo_id="akhaliq/CLIP-prefix-captioning-COCO-weights", filename="coco_weights.pt")
|
5 |
import clip
|
6 |
import os
|
7 |
from torch import nn
|
|
|
235 |
model = ClipCaptionModel(prefix_length)
|
236 |
|
237 |
if model_name == "COCO":
|
238 |
+
model_path = coco_weight
|
239 |
else:
|
240 |
+
model_path = conceptual_weight
|
241 |
model.load_state_dict(torch.load(model_path, map_location=CPU))
|
242 |
model = model.eval()
|
243 |
device = CUDA(0) if is_gpu else "cpu"
|