etrotta commited on
Commit
63a1db6
·
1 Parent(s): b6be18b

Change the vector database used and embed the embeddings within the program

Browse files
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ .kanjidb/**/* filter=lfs diff=lfs merge=lfs -text
37
+ *.lance filter=lfs diff=lfs merge=lfs -text
.kanjidb/kanji.lance/_latest.manifest ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d870cc02d1237f025526565eba9054fdecb7e5f29f96230707b6418a1594f8e7
3
+ size 279
.kanjidb/kanji.lance/_transactions/0-199d3b3b-6378-4c03-9465-cad215d61adf.txn ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bb34cb00ef08393e705a75aa78c4068c8da6c5385fc118afd506f18e5c956594
3
+ size 165
.kanjidb/kanji.lance/_transactions/1-a2c30a19-a8de-47a2-ae7b-08ce8f374e71.txn ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9963f78e3fbb2f07e662ecb999ddc0e6f276649ae8cc65d4a9159576fbf79e17
3
+ size 100
.kanjidb/kanji.lance/_versions/1.manifest ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1abfe73ec38efbad945e6146850736e4572d7bd364b4361c2c4c595be35cf9cc
3
+ size 222
.kanjidb/kanji.lance/_versions/2.manifest ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d870cc02d1237f025526565eba9054fdecb7e5f29f96230707b6418a1594f8e7
3
+ size 279
.kanjidb/kanji.lance/data/866ef445-f6ad-45be-8ff1-620646e5f41f.lance ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:41f3f57790cfa577b99f2283d5a470914aed2082449f42fcb37d9b62634b4c6f
3
+ size 238451553
README.md CHANGED
@@ -4,9 +4,9 @@ emoji: 👀
4
  colorFrom: purple
5
  colorTo: gray
6
  sdk: gradio
7
- sdk_version: 4.18.0
8
  app_file: app.py
9
- pinned: false
10
  license: mit
11
  ---
12
 
 
4
  colorFrom: purple
5
  colorTo: gray
6
  sdk: gradio
7
+ sdk_version: 4.31.4
8
  app_file: app.py
9
+ pinned: true
10
  license: mit
11
  ---
12
 
app.py CHANGED
@@ -6,16 +6,15 @@ from config import (
6
  article,
7
  )
8
  from encode import get_embeddings
9
- from database import search_vector, format_search_results
10
 
11
  def search_images(values):
12
  image = Image.new("RGBA", values["composite"].size, (255, 255, 255, 255))
13
  image.paste(values["composite"], mask=values["composite"])
14
  embedding = get_embeddings([image])[0]
15
  results = search_vector(embedding, limit=100)
16
- formatted = format_search_results(results)
17
 
18
- _deduplicated = '\t'.join(dict.fromkeys(result.kanji for result in formatted))
19
  # TODO Format the results better
20
  # Huge boxes using the right font for each of them?
21
 
@@ -27,7 +26,7 @@ input_image = gr.ImageEditor(
27
  show_label=False,
28
  type="pil",
29
  brush=gr.Brush(
30
- default_size=3,
31
  color_mode="fixed",
32
  colors=["#000000", "#ffffff"],
33
  ),
 
6
  article,
7
  )
8
  from encode import get_embeddings
9
+ from database import search_vector
10
 
11
  def search_images(values):
12
  image = Image.new("RGBA", values["composite"].size, (255, 255, 255, 255))
13
  image.paste(values["composite"], mask=values["composite"])
14
  embedding = get_embeddings([image])[0]
15
  results = search_vector(embedding, limit=100)
 
16
 
17
+ _deduplicated = '\t'.join(dict.fromkeys(result.kanji for result in results))
18
  # TODO Format the results better
19
  # Huge boxes using the right font for each of them?
20
 
 
26
  show_label=False,
27
  type="pil",
28
  brush=gr.Brush(
29
+ default_size=16,
30
  color_mode="fixed",
31
  colors=["#000000", "#ffffff"],
32
  ),
config.py CHANGED
@@ -1,7 +1,6 @@
1
  import os
2
 
3
- qdrant_location = os.getenv('QDRANT_URL', "localhost")
4
- qdrant_api_key = os.getenv('QDRANT_API_KEY')
5
 
6
  description = """This is a Kanji image search demo. Draw or upload an image of an individual Kanji character."""
7
 
@@ -16,7 +15,7 @@ The results is sorted by estimated distance from the input, but will rarely give
16
 
17
  ### About this project
18
 
19
- It uses the "kha-white/manga-ocr-base" Vision Transformer Encoder model to create embeddings, then uses a vector database (qdrant) to find similar characters.
20
 
21
  You can find the code used to create the embeddings as well as more information in https://github.com/etrotta/kanji_lookup
22
 
 
1
  import os
2
 
3
+ lancedb_location = os.getenv('DATABASE_FILE', ".kanjidb")
 
4
 
5
  description = """This is a Kanji image search demo. Draw or upload an image of an individual Kanji character."""
6
 
 
15
 
16
  ### About this project
17
 
18
+ It uses the "kha-white/manga-ocr-base" Vision Transformer Encoder model to create embeddings, then uses a vector database (lancedb) to find similar characters.
19
 
20
  You can find the code used to create the embeddings as well as more information in https://github.com/etrotta/kanji_lookup
21
 
database.py CHANGED
@@ -1,39 +1,27 @@
1
- import dataclasses
2
  import torch
3
- from qdrant_client import QdrantClient, models
 
 
 
4
 
5
- from config import qdrant_location, qdrant_api_key
6
 
7
- qdrant = QdrantClient(
8
- qdrant_location,
9
- api_key=qdrant_api_key,
10
- port=443,
11
- timeout=30,
12
- )
13
 
14
- def search_vector(query_vector: torch.Tensor, limit: int=20) -> list[models.ScoredPoint]:
15
- hits = qdrant.search(
16
- collection_name="kanji",
17
- # query_vector=query_vector,
18
- query_vector=query_vector.numpy(),
19
- limit=limit,
20
- with_payload=True,
21
- )
22
- return hits
23
-
24
- @dataclasses.dataclass
25
- class SearchResult:
26
  kanji: str
27
- font: str
28
- score: float
29
 
30
- def format_search_results(hits: list[models.ScoredPoint]) -> list[SearchResult]:
31
- formatted = []
32
- for point in hits:
33
- kanji, font = point.payload["kanji"], point.payload["font"]
34
- formatted.append(SearchResult(
35
- kanji = kanji,
36
- font = font,
37
- score = point.score,
38
- ))
39
- return formatted
 
 
 
 
1
  import torch
2
+ import lancedb
3
+ from lancedb.pydantic import LanceModel
4
+ import pydantic
5
+ # import time
6
 
7
+ from config import lancedb_location
8
 
9
+ db = lancedb.connect(lancedb_location)
10
+ table = db.open_table("kanji")
 
 
 
 
11
 
12
+ class SearchResult(LanceModel):
 
 
 
 
 
 
 
 
 
 
 
13
  kanji: str
14
+ distance: float = pydantic.Field(validation_alias=pydantic.AliasChoices('distance', '_distance'))
 
15
 
16
+ def search_vector(query_vector: torch.Tensor, limit: int=20) -> list[SearchResult]:
17
+ # start = time.perf_counter()
18
+ results = (
19
+ table
20
+ .search(query_vector.numpy(), vector_column_name="vector", query_type="vector")
21
+ .limit(limit)
22
+ # .to_pydantic(SearchResult) # type: ignore
23
+ .to_list()
24
+ )
25
+ # end = time.perf_counter()
26
+ # print(f"Searched in {end - start:.3f}")
27
+ return [SearchResult.model_validate(result) for result in results]
requirements.txt CHANGED
Binary files a/requirements.txt and b/requirements.txt differ