|
import xml.etree.ElementTree as ET
|
|
from elasticsearch import Elasticsearch, helpers
|
|
import torch
|
|
from transformers import CLIPProcessor, CLIPModel
|
|
import numpy as np
|
|
|
|
from server.utils.database import get_db
|
|
from server.utils.model import get_clip_model
|
|
|
|
from server.models.database import DocumentModel
|
|
from server.crud.patent_data import search_data
|
|
|
|
|
|
# Load the CLIP model and its processor once, at import time; the search
# function below reuses these shared module-level handles on every call.
(clip_model, clip_processor) = get_clip_model()
|
|
|
|
|
|
# Shared default database handle for calls that omit ``db``. It is created on
# first use rather than in the function signature, so merely importing this
# module does not call get_db().
_DEFAULT_DB = None


def _default_db():
    """Return the shared default database handle, creating it on first use.

    NOTE(review): if get_db() is a generator-style dependency, this returns the
    generator object rather than a session (same as the original default) —
    confirm what search_data expects.
    """
    global _DEFAULT_DB
    if _DEFAULT_DB is None:
        _DEFAULT_DB = get_db()
    return _DEFAULT_DB


def _embed(get_features, inputs):
    """Run a CLIP feature extractor on ``inputs``; return the first row as floats.

    ``get_features`` is a bound feature method of ``clip_model`` (text or image).
    Only the first row is returned because callers embed a single item.
    """
    # Put the tensors on the model's device: a no-op on CPU, and it avoids a
    # device-mismatch error if get_clip_model() placed the model on a GPU.
    device = clip_model.device
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
    with torch.no_grad():
        features = get_features(**inputs)
    return features[0].cpu().numpy().astype(np.float32).tolist()


def search_data_from_database(
    query=None,
    image_path=None,
    top_k=5,
    db=None,
    index_name="patents",
):
    """Embed a text query or an image with CLIP and search the index with it.

    Args:
        query: Text to embed. Takes precedence over ``image_path`` when truthy.
        image_path: Path of an image file to embed; used only when ``query``
            is falsy.
        top_k: Number of results to request, forwarded to ``search_data``.
        db: Database handle forwarded to ``search_data``. When ``None``, a
            shared handle obtained from ``get_db()`` on first use is reused.
        index_name: Index to search, forwarded to ``search_data``.

    Returns:
        Whatever ``search_data`` returns, or ``[]`` when neither a query nor
        an image path is given.
    """
    if query:
        inputs = clip_processor(
            text=[query], return_tensors="pt", padding=True, truncation=True
        )
        embedding = _embed(clip_model.get_text_features, inputs)
    elif image_path:
        from PIL import Image

        # Release the file handle deterministically; convert() returns a new
        # image, so the pixels remain usable after the source is closed.
        with Image.open(image_path) as source:
            image = source.convert("RGB")
        inputs = clip_processor(images=image, return_tensors="pt")
        embedding = _embed(clip_model.get_image_features, inputs)
    else:
        return []

    if db is None:
        db = _default_db()

    return search_data(
        embedding=embedding,
        top_k=top_k,
        db=db,
        index_name=index_name,
    )