zwgao's picture
add file
3fdcc70
raw
history blame
5.93 kB
import codecs
import io
import os
import pickle
from pathlib import Path
from PIL import Image
import requests
import os
import sys
sys.path.append(os.getcwd())
from cllm.services.utils import get_bytes_value
from cllm.services.nlp.api import openai_chat_model
__ALL__ = [
"object_detection",
"image_classification",
"ocr",
"image_to_text",
"segment_objects",
]
HOST = os.environ.get("CLLM_SERVICES_HOST", "localhost")
PORT = os.environ.get("CLLM_SERVICES_PORT", 10056)
def setup(host="localhost", port=10049):
    """Set the default service endpoint for all subsequent client calls.

    Mutates the module-level ``HOST``/``PORT`` defaults; individual calls
    may still override them with ``host=``/``port=`` keyword arguments.
    """
    global HOST, PORT
    HOST, PORT = host, port
def object_detection(image, **kwargs):
    """Run object detection on *image* via the remote CLLM service.

    *image* is anything ``get_bytes_value`` accepts (path/identifier).
    Returns the service's decoded JSON payload.
    """
    endpoint = "http://{}:{}/object_detection".format(
        kwargs.get("host", HOST), kwargs.get("port", PORT)
    )
    payload = {"image": (image, get_bytes_value(image))}
    resp = requests.post(endpoint, files=payload)
    return resp.json()
def image_classification(image, **kwargs):
    """Classify *image* via the remote CLLM service; returns decoded JSON."""
    endpoint = "http://{}:{}/image_classification".format(
        kwargs.get("host", HOST), kwargs.get("port", PORT)
    )
    resp = requests.post(
        endpoint, files={"image": (image, get_bytes_value(image))}
    )
    return resp.json()
def image_to_text(image, **kwargs):
    """Generate a textual description of *image*; returns decoded JSON."""
    target_host = kwargs.get("host", HOST)
    target_port = kwargs.get("port", PORT)
    resp = requests.post(
        f"http://{target_host}:{target_port}/image_to_text",
        files={"image": (image, get_bytes_value(image))},
    )
    return resp.json()
def ocr(image, **kwargs):
    """Run OCR on *image* via the remote service; returns decoded JSON."""
    target_host = kwargs.get("host", HOST)
    target_port = kwargs.get("port", PORT)
    resp = requests.post(
        f"http://{target_host}:{target_port}/ocr",
        files={"image": (image, get_bytes_value(image))},
    )
    return resp.json()
def segment_objects(image, **kwargs):
    """Segment objects in *image*; returns a list of result dicts.

    The service replies with a base64-encoded pickle.  Each entry's PIL
    ``mask`` image is re-encoded in place to raw PNG bytes before the list
    is returned, so callers never handle PIL objects.

    NOTE(security): ``pickle.loads`` on data received over the network is
    unsafe unless the service endpoint is fully trusted.
    """
    endpoint = "http://{}:{}/segment_objects".format(
        kwargs.get("host", HOST), kwargs.get("port", PORT)
    )
    resp = requests.post(
        endpoint, files={"image": (image, get_bytes_value(image))}
    )
    encoded = resp.json()["data"]
    results = pickle.loads(codecs.decode(encoded.encode(), "base64"))
    for item in results:
        buffer = io.BytesIO()
        item["mask"].save(buffer, format="png")
        buffer.seek(0)
        item["mask"] = buffer.getvalue()
    return results
def visual_grounding(image, query, **kwargs):
    """Locate the image region described by *query* via the remote service.

    The free-form user *query* is first distilled to a bare referring
    expression by the chat model (few-shot prompt below), then posted with
    the image bytes to the ``/visual_grounding`` endpoint.  Returns the
    service's decoded JSON result.
    """
    host = kwargs.get("host", HOST)
    port = kwargs.get("port", PORT)
    url = rf"http://{host}:{port}/visual_grounding"
    # Few-shot prompt: extract only the referring expression from ``query``.
    human_msg = f"""Your task is to extract the prompt from input. Here is examples:
Input:
find the regin of interest in the da9619_image.png: \"An elephant in right corner\"
Answer:
An elephant in right corner
Input:
locate \"A maintenance vehicle on a railway\" in the image
Answer:
A maintenance vehicle on a railway
Input:
use visual grounding method to detect the regin of interest in the 1ba6e2_image.png: The motorcycle with the rainbow flag"
Answer:
The motorcycle with the rainbow flag
Input:
for given image, find A little baby girl with brunette hair, a pink and white dress, and is being fed frosting from her mom."
Answer:
A little baby girl with brunette hair, a pink and white dress, and is being fed frosting from her mom
Input:
find the policeman on the motorcycle in the 851522_image.png"
Answer:
the policeman on the motorcycle
Input:
The legs of a zebra shown under the neck of another zebra.
Answer:
The legs of a zebra shown under the neck of another zebra.
Input:
{query}
Answer:
"""
    extracted_prompt = openai_chat_model(human_msg)
    # Fix: use the (filename, bytes) tuple form for the multipart upload,
    # consistent with every other endpoint in this module (the bare-bytes
    # form loses the filename hint on the server side).
    files = {"image": (image, get_bytes_value(image))}
    data = {"query": extracted_prompt}
    response = requests.post(url, data=data, files=files)
    return response.json()
def image_captioning(image, endpoint="llava", **kwargs):
    """Caption *image* using the given service *endpoint* (default "llava").

    The llava endpoint additionally receives a fixed instruction prompt.
    Returns the caption decoded as a UTF-8 string.
    """
    target = "http://{}:{}/{}".format(
        kwargs.get("host", HOST), kwargs.get("port", PORT), endpoint
    )
    extra = (
        {"text": "Please describe the image in details."}
        if endpoint == "llava"
        else None
    )
    resp = requests.post(
        target,
        files={"image": (image, get_bytes_value(image))},
        data=extra,
    )
    return resp.content.decode("utf-8")
def segment_all(image: str | Path, **kwargs):
    """Run whole-image segmentation on *image*; returns raw response bytes."""
    target_host = kwargs.get("host", HOST)
    target_port = kwargs.get("port", PORT)
    resp = requests.post(
        f"http://{target_host}:{target_port}/segment_all",
        files={"image": (image, get_bytes_value(image))},
    )
    return resp.content
def set_image(image: str | Path, **kwargs):
    """Register *image* with the service; returns its server-side id string.

    The returned id is consumed by ``segment_by_mask``/``segment_by_points``.
    """
    target_host = kwargs.get("host", HOST)
    target_port = kwargs.get("port", PORT)
    resp = requests.post(
        f"http://{target_host}:{target_port}/set_image",
        files={"image": (image, get_bytes_value(image))},
    )
    return resp.content.decode()
def segment_by_mask(mask: str | Path, image_id: str, **kwargs):
    """Segment the registered image *image_id* guided by a *mask* prompt.

    Returns the raw response bytes (the resulting mask).
    """
    target_host = kwargs.get("host", HOST)
    target_port = kwargs.get("port", PORT)
    resp = requests.post(
        f"http://{target_host}:{target_port}/segment_by_mask",
        files={"mask": (mask, get_bytes_value(mask))},
        data={"image_id": image_id},
    )
    return resp.content
def segment_by_points(points: list | tuple | str, image_id: str, **kwargs):
    """Segment the registered image *image_id* guided by point prompts.

    Returns the raw response bytes (the resulting mask).
    """
    target_host = kwargs.get("host", HOST)
    target_port = kwargs.get("port", PORT)
    resp = requests.post(
        f"http://{target_host}:{target_port}/segment_by_points",
        data={"points": points, "image_id": image_id},
    )
    return resp.content
def seg_by_mask(image, prompt_mask, **kwargs):
    """Segment *image* guided by a mask prompt; returns raw mask bytes.

    Two-step flow: register the image with the service (``set_image``),
    then segment it using *prompt_mask*.

    Fix: ``**kwargs`` (e.g. host/port overrides) were previously accepted
    but silently dropped; they are now forwarded to both service calls.
    """
    image_id = set_image(image, **kwargs)
    return segment_by_mask(mask=prompt_mask, image_id=image_id, **kwargs)
def seg_by_points(image, prompt_points, **kwargs):
    """Segment *image* guided by point prompts; returns raw mask bytes.

    Two-step flow: register the image with the service (``set_image``),
    then segment it using *prompt_points*.

    Fix: ``**kwargs`` (e.g. host/port overrides) were previously accepted
    but silently dropped; they are now forwarded to both service calls.
    """
    image_id = set_image(image, **kwargs)
    return segment_by_points(points=prompt_points, image_id=image_id, **kwargs)