tcm03
commited on
Commit
•
b6ff56b
1
Parent(s):
abf214e
Add model checkpoint with Git LFS
Browse files- .gitattributes +1 -0
- pipeline.py → handler.py +1 -1
- model/tsbir_model_final.pt +3 -0
- model/tsbir_model_final.pt:Zone.Identifier +4 -0
- test.py +27 -0
.gitattributes
CHANGED
@@ -1 +1,2 @@
|
|
1 |
code/clip/bpe_simple_vocab_16e6.txt.gz filter=lfs diff=lfs merge=lfs -text
|
|
|
|
1 |
code/clip/bpe_simple_vocab_16e6.txt.gz filter=lfs diff=lfs merge=lfs -text
|
2 |
+
model/tsbir_model_final.pt filter=lfs diff=lfs merge=lfs -text
|
pipeline.py → handler.py
RENAMED
@@ -13,7 +13,7 @@ from clip.clip import _transform, tokenize
|
|
13 |
|
14 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
15 |
|
16 |
-
class
|
17 |
def __init__(self, path: str = ""):
|
18 |
"""
|
19 |
Initialize the pipeline by loading the model.
|
|
|
13 |
|
14 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
15 |
|
16 |
+
class EndpointHandler:
|
17 |
def __init__(self, path: str = ""):
|
18 |
"""
|
19 |
Initialize the pipeline by loading the model.
|
model/tsbir_model_final.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:7bc305f6a75e448861ab8b02ebbed143d2ba09723c158646bd53f13a86934b3f
|
3 |
+
size 2713984713
|
model/tsbir_model_final.pt:Zone.Identifier
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[ZoneTransfer]
|
2 |
+
ZoneId=3
|
3 |
+
ReferrerUrl=https://github.com/janesjanes/tsbir?tab=readme-ov-file
|
4 |
+
HostUrl=https://patsorn.me/projects/tsbir/data/tsbir_model_final.pt
|
test.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from handler import EndpointHandler
|
2 |
+
import base64
|
3 |
+
|
4 |
+
# Helper function to encode an image to Base64
|
5 |
+
def encode_image_to_base64(image_path):
|
6 |
+
with open(image_path, "rb") as image_file:
|
7 |
+
return base64.b64encode(image_file.read()).decode("utf-8")
|
8 |
+
|
9 |
+
# Initialize the handler
|
10 |
+
handler = EndpointHandler(path=".")
|
11 |
+
|
12 |
+
# Prepare sample inputs
|
13 |
+
image_path = "path_to_your_sketch_image.jpg" # Replace with your image path
|
14 |
+
base64_image = encode_image_to_base64(image_path)
|
15 |
+
text_query = "A pink flower"
|
16 |
+
|
17 |
+
# Create payload
|
18 |
+
payload = {
|
19 |
+
"image": base64_image,
|
20 |
+
"text": text_query
|
21 |
+
}
|
22 |
+
|
23 |
+
# Run the handler
|
24 |
+
response = handler(payload)
|
25 |
+
|
26 |
+
# Show results
|
27 |
+
print("Fused Embedding:", response)
|