tcm03 commited on
Commit
b6ff56b
1 Parent(s): abf214e

Add model checkpoint with Git LFS

Browse files
.gitattributes CHANGED
@@ -1 +1,2 @@
1
  code/clip/bpe_simple_vocab_16e6.txt.gz filter=lfs diff=lfs merge=lfs -text
 
 
1
  code/clip/bpe_simple_vocab_16e6.txt.gz filter=lfs diff=lfs merge=lfs -text
2
+ model/tsbir_model_final.pt filter=lfs diff=lfs merge=lfs -text
pipeline.py → handler.py RENAMED
@@ -13,7 +13,7 @@ from clip.clip import _transform, tokenize
13
 
14
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
 
16
- class PreTrainedPipeline:
17
  def __init__(self, path: str = ""):
18
  """
19
  Initialize the pipeline by loading the model.
 
13
 
14
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
 
16
+ class EndpointHandler:
17
  def __init__(self, path: str = ""):
18
  """
19
  Initialize the pipeline by loading the model.
model/tsbir_model_final.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7bc305f6a75e448861ab8b02ebbed143d2ba09723c158646bd53f13a86934b3f
3
+ size 2713984713
model/tsbir_model_final.pt:Zone.Identifier ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ [ZoneTransfer]
2
+ ZoneId=3
3
+ ReferrerUrl=https://github.com/janesjanes/tsbir?tab=readme-ov-file
4
+ HostUrl=https://patsorn.me/projects/tsbir/data/tsbir_model_final.pt
test.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from handler import EndpointHandler
2
+ import base64
3
+
4
+ # Helper function to encode an image to Base64
5
+ def encode_image_to_base64(image_path):
6
+ with open(image_path, "rb") as image_file:
7
+ return base64.b64encode(image_file.read()).decode("utf-8")
8
+
9
+ # Initialize the handler
10
+ handler = EndpointHandler(path=".")
11
+
12
+ # Prepare sample inputs
13
+ image_path = "path_to_your_sketch_image.jpg" # Replace with your image path
14
+ base64_image = encode_image_to_base64(image_path)
15
+ text_query = "A pink flower"
16
+
17
+ # Create payload
18
+ payload = {
19
+ "image": base64_image,
20
+ "text": text_query
21
+ }
22
+
23
+ # Run the handler
24
+ response = handler(payload)
25
+
26
+ # Show results
27
+ print("Fused Embedding:", response)