import os
import logging

import cv2
import requests
import torch
from PIL import Image

from langchain_core.tools import Tool
from langchain_community.tools import DuckDuckGoSearchResults
from langchain_groq import ChatGroq

from llm_service import get_llm
from utils import draw_panoptic_segmentation
from tool_utils.clip_segmentation import CLIPSEG
from tool_utils.object_extractor import create_object_extraction_chain
from tool_utils.yolo_world import YoloWorld
from tool_utils.image_metadata import (
    image_brightness,
    variance_of_laplacian,
    get_signal_to_noise_ratio,
)

try:
    from transformers import BlipProcessor, BlipForConditionalGeneration
    from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation
except ImportError as err:
    logging.error("Import error: {}".format(err))

device = 'cuda' if torch.cuda.is_available() else 'cpu'

logging.info("Loading foundation models")
try:
    clipseg_model = CLIPSEG()
except Exception as err:
    logging.error("Unable to load CLIPSeg model: {}".format(err))

try:
    maskformer_processor = AutoImageProcessor.from_pretrained(
        "facebook/mask2former-swin-base-coco-panoptic")
    maskformer_model = Mask2FormerForUniversalSegmentation.from_pretrained(
        "facebook/mask2former-swin-base-coco-panoptic")
except Exception as err:
    logging.error("Unable to load Mask2Former model: {}".format(err))


def get_groq_model(model_name="gemma2-9b-it"):
    # ChatGroq reads the GROQ_API_KEY environment variable at construction time
    if not os.environ.get("GROQ_API_KEY"):
        logging.warning("GROQ_API_KEY is not set; Groq calls will fail")
    return ChatGroq(model=model_name)


# The functions below are plain callables; get_all_tools() wraps them as
# LangChain Tool objects. (Decorating them with @tool and then wrapping them in
# Tool(func=...) would fail, since tool instances are not directly callable.)

def panoptic_image_segmentation(image_path: str) -> str:
    """
    Create a panoptic segmentation mask.
    Uses a Mask2Former network to produce a panoptic segmentation of all
    the objects present in the image. Use this tool when the user asks for
    a panoptic segmentation.
    """
    if image_path.startswith('https'):
        image = Image.open(requests.get(image_path, stream=True).raw).convert('RGB')
    else:
        image = Image.open(image_path).convert('RGB')

    maskformer_model.to(device)
    inputs = maskformer_processor(image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = maskformer_model(**inputs)
    prediction = maskformer_processor.post_process_panoptic_segmentation(
        outputs, target_sizes=[image.size[::-1]])[0]
    save_mask_path = draw_panoptic_segmentation(
        maskformer_model, prediction['segmentation'], prediction['segments_info'])

    labels = []
    for segment in prediction['segments_info']:
        labels.append(maskformer_model.config.id2label[segment['label_id']])
    return 'Panoptic segmentation image {} created with labels {}'.format(save_mask_path, labels)


def image_description(img_path: str) -> str:
    """
    Use this tool to describe the image.
    The tool also helps to identify the weather in the image.
    """
    hf_model = "Salesforce/blip-image-captioning-base"
    if img_path.startswith('https'):
        image = Image.open(requests.get(img_path, stream=True).raw).convert('RGB')
    else:
        image = Image.open(img_path).convert('RGB')
    try:
        processor = BlipProcessor.from_pretrained(hf_model)
        caption_model = BlipForConditionalGeneration.from_pretrained(hf_model).to(device)
    except Exception as err:
        logging.error("Unable to load the BLIP model: {}".format(err))
        return "Unable to load the BLIP captioning model."
    logging.info("Image caption model loaded!")
") # unconditional image captioning inputs = processor(image, return_tensors ='pt').to(device) output = caption_model.generate(**inputs, max_new_tokens=50) caption = processor.decode(output[0], skip_special_tokens=True) # conditional image captioning obj_text = "Total number of objects in image " inputs_2 = processor(image, obj_text ,return_tensors ='pt').to(device) out_2 = caption_model.generate(**inputs_2,max_new_tokens=50) object_caption = processor.decode(out_2[0], skip_special_tokens=True) ## clear the GPU cache with torch.no_grad(): torch.cuda.empty_cache() text = caption + " ."+ object_caption+" ." return text @tool def clipsegmentation_mask(input_data:str)->str: """ The tool helps to extract the object masks from the image. For example : If you want to extract the object masks from the image use this tool. """ data = input_data.split(",") image_path = data[0] object_prompts = data[1:] masks = clipseg_model.get_segmentation_mask(image_path,object_prompts) return masks @tool def generate_bounding_box_tool(input_data:str)->str: "use this tool when its is required to detect object and provide bounding boxes for the given image and list of objects" yolo_world_model= YoloWorld() data = input_data.split(",") image_path = data[0] object_prompts = data[1:] object_data = yolo_world_model.run_inference(image_path,object_prompts) return object_data @tool def object_extraction(img_path:str)->str: "Use this tool to identify the objects within the image" hf_model = "Salesforce/blip-image-captioning-base" if img_path.startswith('https'): image = Image.open(requests.get(img_path, stream=True).raw).convert('RGB') else: image = Image.open(img_path).convert('RGB') try: processor = BlipProcessor.from_pretrained(hf_model) caption_model = BlipForConditionalGeneration.from_pretrained(hf_model).to(device) except: logging.error("unable to load the Blip model ") logging.info("Image Caption model loaded ! ") # unconditional image captioning inputs = processor(image, return_tensors ='pt').to(device) output = caption_model.generate(**inputs, max_new_tokens=50) llm = get_groq_model() getobject_chain = create_object_extraction_chain(llm=llm) extracted_objects = getobject_chain.invoke({ 'context': processor.decode(output[0], skip_special_tokens=True) }).objects print("Extracted objects : ",extracted_objects) ## clear the GPU cache with torch.no_grad(): torch.cuda.empty_cache() return extracted_objects.split(',') @tool def get_image_quality(image_path:str)->str: """ This tool helps to find out the parameters of the image.The tool will determine if image is blurry or not. It will also tell you if image is bright or not. This tool also determines the Signal to Noise Ratio of the image as well . For example Output of the tool will be : example 1 : Image is blurry.Image is not bright.Signal to Noise is less than 1 - More Noise in image example 2 : Image is not blurry . 
    Image is bright. Signal to Noise is greater than 1 - more signal in image.
    """
    image = cv2.imread(image_path)
    if image is None:
        return "Unable to read image at {}".format(image_path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    brightness_text = image_brightness(image)
    blurry_text = variance_of_laplacian(image)
    snr_text = get_signal_to_noise_ratio(image)
    return "Image properties are:\n{}\n{}\n{}".format(blurry_text, brightness_text, snr_text)


def get_all_tools():
    # Bind the functions above as LangChain tools
    image_desc_tool = Tool(
        name='Image_Description_Tool',
        func=image_description,
        description="""The tool describes the image or creates a caption for it.
        If the user asks to describe or generate a caption for the image, use this tool.
        This tool can also be used to identify the weather within the image.
        Example user questions:
        1. Describe the image?
        2. What does the weather look like in the image?"""
    )
    clipseg_tool = Tool(
        name='ClipSegmentation_Tool',
        func=clipsegmentation_mask,
        description="""Use this tool when the user asks to generate segmentation masks for
        the objects they name. The input is the image path followed by the list of objects
        for which segmentation masks are to be generated.
        Example query: "Provide a segmentation mask of all roads, cars and dogs in the image."
        If the user has not named the objects, first use the object extraction tool to
        identify them, then use this tool to generate their segmentation masks."""
    )
    bounding_box_generator = Tool(
        name='Bounding_Box_Generator',
        func=generate_bounding_box_tool,
        description="""The tool provides bounding boxes for the given image and list of
        objects. Use this tool when the user asks for bounding boxes around objects. If the
        user has not specified the object names, first use the object extraction tool to
        identify the objects, then use this tool to generate their bounding boxes. The input
        is the image path followed by the list of objects to detect."""
    )
    object_extractor = Tool(
        name='Object_Extraction_Tool',
        func=object_extraction,
        description="""The tool extracts the objects within the image. Use this tool if the
        user specifically asks which objects can be seen in the image."""
    )
    image_parameters_tool = Tool(
        name='Image_Parameters_Tool',
        func=get_image_quality,
        description="""This tool determines:
        - whether the image is blurry
        - whether the image is bright/sharp
        - the SNR of the image
        Based on the tool output, make a proper decision about the image quality."""
    )
    panoptic_segmentation = Tool(
        name='Panoptic_Segmentation_Tool',
        func=panoptic_image_segmentation,
        description="""The tool creates a panoptic segmentation mask. It uses a Mask2Former
        network to produce a panoptic segmentation of all the objects present in the image.
        Use this tool when the user asks for a panoptic segmentation or to count objects in
        the image. The tool returns the list of segmented objects along with the mask image
        of all segmented objects found in the image."""
    )

    tools = [
        DuckDuckGoSearchResults(),
        image_desc_tool,
        clipseg_tool,
        image_parameters_tool,
        object_extractor,
        bounding_box_generator,
        panoptic_segmentation,
    ]
    return tools
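
# ---------------------------------------------------------------------------
# Usage sketch (hypothetical, not part of the module): one way to wire these
# tools into a ReAct-style agent. Assumes the langchainhub prompt
# "hwchase17/react" is available and GROQ_API_KEY is set; "sample.jpg" is a
# placeholder path.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from langchain import hub
    from langchain.agents import AgentExecutor, create_react_agent

    llm = get_groq_model()
    tools = get_all_tools()
    prompt = hub.pull("hwchase17/react")  # generic ReAct prompt template
    agent = create_react_agent(llm=llm, tools=tools, prompt=prompt)
    executor = AgentExecutor(agent=agent, tools=tools,
                             verbose=True, handle_parsing_errors=True)
    result = executor.invoke(
        {"input": "Describe the image sample.jpg and tell me if it is blurry"})
    print(result["output"])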