# EasyDetect/pipeline/tool_execute.py
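"""Tool execution step of the EasyDetect pipeline.

Given the per-claim lists produced by the extraction/query stages, this module
gathers evidence for each claim type: object claims via GroundingDINO,
scene-text claims via text recognition, attribute claims via a GPT-4V query,
and fact claims via the Google Serper search wrapper.
"""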
import base64
import json
import os

from openai import OpenAI
from tqdm import tqdm

from pipeline.tool.object_detetction_model import *
from pipeline.tool.google_serper import *

# Read the API key from the environment (variable name is a convention here)
# rather than hard-coding a secret in the source.
client = OpenAI(
    base_url="https://oneapi.xty.app/v1",
    api_key=os.environ["OPENAI_API_KEY"],
)
def get_openai_reply(image_path, text):
    """Send an image plus a batch of questions to GPT-4V and return its answer."""
    def encode_image(image_path):
        with open(image_path, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode("utf-8")

    img = encode_image(image_path)
    content = [
        {"type": "text", "text": text},
        # The vision API expects image_url to be an object with a "url" field.
        {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{img}"}},
    ]
    messages = [
        {
            "role": "user",
            "content": content,
        }
    ]
    resp = client.chat.completions.create(
        model="gpt-4-vision-preview",
        messages=messages,
        max_tokens=1024,
    )
    return resp.choices[0].message.content
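# Usage sketch (hypothetical file and questions): one GPT-4V call answers a
# numbered batch of attribute questions about a single image.
#   reply = get_openai_reply("sample.jpg", "1.Is the cat black?\n2.Is the cat sitting?\n")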
class Tool:
    def __init__(self):
        self.detector = GroundingDINO()
        self.search = GoogleSerperAPIWrapper()

    def execute(self, image_path, new_path, objects, attribute_list, scenetext_list, fact_list):
        # A list whose first entry is "none" means no claim of that type exists.
        use_text_rec = any(scenetext_list[key][0] != "none" for key in scenetext_list)
        use_attribute = any(attribute_list[key][0] != "none" for key in attribute_list)

        # Scene-text recognition, only when some claim mentions text in the image.
        text_res = None
        if use_text_rec:
            text_res = self.detector.execute(image_path=image_path, content="word.number",
                                             new_path=new_path, use_text_rec=True)

        # Object detection always runs on the extracted object prompt.
        object_res = self.detector.execute(image_path=image_path, content=objects,
                                           new_path=new_path, use_text_rec=False)

        # Collect attribute questions into one numbered batch and answer them
        # with a single GPT-4V call.
        queries = ""
        if use_attribute:
            cnt = 1
            for key in attribute_list:
                if attribute_list[key][0] != "none":
                    for query in attribute_list[key]:
                        queries += str(cnt) + "." + query + "\n"
                        cnt += 1
        if queries == "":
            attribute_res = "none information"
        else:
            attribute_res = get_openai_reply(image_path, queries)

        # Gather external evidence for fact claims via Google Serper.
        fact_res = ""
        cnt = 1
        for key in fact_list:
            if fact_list[key][0] != "none":
                evidences = self.search.execute(input="", content=str(fact_list[key]))
                for evidence in evidences:
                    fact_res += str(cnt) + "." + evidence + "\n"
                    cnt += 1
        if fact_res == "":
            fact_res = "none information"

        return object_res, attribute_res, text_res, fact_res
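# Minimal sketch of the inputs Tool.execute expects, inferred from the access
# patterns above: each mapping goes claim key -> list of strings, with ["none"]
# marking "no claim of this type". All paths and values here are hypothetical.
#
#   tool = Tool()
#   object_res, attribute_res, text_res, fact_res = tool.execute(
#       image_path="sample.jpg",                        # image under test
#       new_path="./out/",                              # where detector output goes
#       objects="cat.dog",                              # detector prompt string
#       attribute_list={"claim1": ["Is the cat black?"]},
#       scenetext_list={"claim2": ["none"]},            # skips text recognition
#       fact_list={"claim3": ["none"]},                 # skips web search
#   )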
# Example driver, kept for reference. Extractor and Query come from the
# extraction/query stages of the pipeline.
# if __name__ == '__main__':
#     tool = Tool()
#     extractor = Extractor(model="gpt-4-1106-preview",
#                           config_path="/home/wcx/wcx/GroundingDINO/LVLM/prompt-v2-multi-claim/object_extract.yaml",
#                           type="image-to-text")
#     query = Query(config_path="/home/wcx/wcx/GroundingDINO/LVLM/prompt-v2-multi-claim/query.yaml",
#                   type="image-to-text")
#     path = "/home/wcx/wcx/LVLMHall-test/MSCOCO/caption/labeled/minigpt4-100-cx-revise-v1.json"
#     with open(path, "r", encoding="utf-8") as f:
#         for idx, line in tqdm(enumerate(f.readlines()), total=250):
#             data = json.loads(line)
#             image_path = data["image_path"]
#             # Flatten every claim into a numbered list for the extractor.
#             claim_list = ""
#             cnt = 1
#             for seg in data["segments"]:
#                 for cla in seg["claims"]:
#                     claim_list += "claim" + str(cnt) + ": " + cla["claim"] + "\n"
#                     cnt += 1
#             object_list, objects = extractor.get_response(claims=claim_list)
#             attribute_list, scenetext_list, fact_list, objects = query.get_response(claim_list, objects, object_list)
#             object_res, attribute_res, text_res, fact_res = tool.execute(
#                 image_path=image_path,
#                 new_path="/newdisk3/wcx/MLLM/image-to-text/minigpt4/",
#                 attribute_list=attribute_list,
#                 scenetext_list=scenetext_list,
#                 fact_list=fact_list,
#                 objects=objects)
#             break