YiHuan commited on
Commit
858c0d4
1 Parent(s): fade163

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -49
app.py CHANGED
@@ -1,71 +1,82 @@
1
- import chromadb
2
- from chromadb.api.types import Documents, EmbeddingFunction, Embeddings
3
  import json
4
- from ast import literal_eval
5
- from chromadb.config import Settings
6
- from paddlenlp import Taskflow
7
  import requests
8
- from io import BytesIO
9
- from PIL import Image
10
  import gradio as gr
 
 
 
11
 
12
- vision_language=Taskflow("feature_extraction", model='PaddlePaddle/ernie_vil-2.0-base-zh')
13
-
14
- def getImageTestFeture(content):
15
- if content.startswith("http"):
16
- response = requests.get(content)
17
- x=BytesIO(response.content)
18
- f_embeds = vision_language(Image.open(x))
19
- else:
20
- f_embeds = vision_language(content)
21
- text_features = f_embeds["features"][0]
22
- return text_features
23
 
24
- class MyEmbeddingFunction(EmbeddingFunction):
25
- def __call__(self, texts: Documents) -> Embeddings:
26
- qr=[]
27
- for doc in texts:
28
- text_embeds = getImageTestFeture(doc)
29
- #print(len(text_features))
30
- bedx=text_embeds.tolist()
31
- qr.append(bedx)
32
- return qr
33
 
34
- client = chromadb.Client(Settings(
35
- chroma_db_impl="duckdb+parquet",
36
- persist_directory="x/" # Optional, defaults to .chromadb/ in the current directory
37
- ))
 
 
 
 
 
 
 
38
 
39
- collection = client.get_or_create_collection(name="pics", metadata={"hnsw:space": "cosine"}, embedding_function=MyEmbeddingFunction())
 
 
 
40
 
 
 
 
 
 
 
 
 
 
 
 
41
  def queryimgage(text):
42
  html="<table border='1'>\
43
  <tr>\
44
  <th>img</th>\
45
  <th>score</th>\
46
  </tr>"
47
- atext=[]
48
- atext.append(text)
49
- results = collection.query(
50
- query_texts=atext,
51
- n_results=20,
52
- )
53
- ids=results['ids'][0]
54
- documents=results['documents'][0]
55
- distances=results['distances'][0]
56
- xcount=len(ids)
57
- for i in range(xcount):
58
- #print("id:%s,url:%s,score:%s"%(ids[i],documents[i],distances[i]))
 
 
 
 
 
 
59
  html=html +"<tr>\
60
- <td><img src='"+documents[xcount-1-i]+"' width=640></td>\
61
- <td>"+ str(distances[xcount-1-i])+"</td>"
62
  html=html+"</table>"
63
  return html
64
 
 
 
 
 
 
 
65
  demo = gr.Interface(
66
- queryimgage,
67
- gr.Textbox(placeholder="请输入文本"),
68
- [ "html"]
 
69
  )
70
 
71
  demo.launch()
 
1
+
 
2
  import json
 
 
 
3
  import requests
 
 
4
  import gradio as gr
5
+ import typesense
6
+ from urllib.parse import quote
7
+ import os
8
 
 
 
 
 
 
 
 
 
 
 
 
9
 
 
 
 
 
 
 
 
 
 
10
 
11
+ def getWordVec(content):
12
+ content=quote(content,'utf-8')
13
+ #print(content)
14
+ xurl=os.getenv("emburl")
15
+ print(xurl)
16
+ url=xurl +content
17
+ #print(url)
18
+ response = requests.get(url)
19
+ jsonar=json.loads(response.text).get("embed")
20
+ #print(len(jsonar))
21
+ return jsonar
22
 
23
+ typesenseserver=os.getenv("typesenseserver")
24
+ typesenseport=os.getenv("typesenseport")
25
+ typesensekey=os.getenv("typesensekey")
26
+ typesensecolname=os.getenv("typesensecolname")
27
 
28
+ confignode={}
29
+ confignode['host']=typesenseserver
30
+ confignode['port']=typesenseport
31
+ confignode['protocol']='http'
32
+ nodes=[]
33
+ nodes.append(confignode)
34
+ nodeconfig={}
35
+ nodeconfig["nodes"]=nodes
36
+ nodeconfig["api_key"]=typesensekey
37
+ print(nodeconfig)
38
+ client = typesense.Client(nodeconfig)
39
  def queryimgage(text):
40
  html="<table border='1'>\
41
  <tr>\
42
  <th>img</th>\
43
  <th>score</th>\
44
  </tr>"
45
+ info=getWordVec(text)
46
+ search_requests = {
47
+ 'searches': [
48
+ {
49
+ 'collection': typesensecolname,
50
+ 'q' : '*',
51
+ 'per_page': 20,
52
+ 'exclude_fields' : 'my_vector',
53
+ 'vector_query': 'my_vector:(['+",".join(str(x) for x in info )+'], k:1000)'
54
+ }
55
+ ]
56
+ }
57
+ common_search_params = {}
58
+ res=client.multi_search.perform(search_requests, common_search_params)
59
+ result=res['results'][0]['hits']
60
+ for resultinfo in result:
61
+ documents=resultinfo['document']
62
+ score=(2-resultinfo['vector_distance'])/2
63
  html=html +"<tr>\
64
+ <td><img src='"+documents['imageurl']+"' width=640 height=600></td>\
65
+ <td>"+ str(score)+"</td>"
66
  html=html+"</table>"
67
  return html
68
 
69
+ def getNumtip():
70
+ num=client.collections[typesensecolname].retrieve()["num_documents"]
71
+ numtip="图片数:" + str(num)
72
+ return numtip
73
+
74
+
75
  demo = gr.Interface(
76
+ fn=queryimgage,
77
+ inputs=gr.Textbox(placeholder="请输入文本"),
78
+ outputs=[ "html"],
79
+ article=getNumtip()
80
  )
81
 
82
  demo.launch()