import re
import time
from enum import Enum
from typing import Literal, Optional

import httpx
import requests
from PIL import Image
from pydantic import BaseModel, Field
# NOTE: this import shadows the builtin ConnectionError everywhere in this module.
from requests.exceptions import ConnectionError
from typing_extensions import List

# NOTE(review): not referenced by the generation code visible in this file,
# which passes its own temperature to `generate`; confirm which value is intended.
VLM_TEMPERATURE = 0


class WeightUnit(Enum):
    """Closed set of weight/quantity units (German labels as values)."""

    # The `weight_unit` field of product_promotion_data uses a Literal with
    # these same string values rather than this enum.
    GRAMM = "Gramm"
    KILOGRAM = "Kilogramm"
    MILLILITER = "Milliliter"
    LITER = "Liter"
    WASCHLADUNGEN = "Waschladungen"
    BLATT = "Blatt"
    STUECK = "Stück"


class YesNo(Enum):
    """Two-valued yes/no answer."""

    # `different_types` in product_promotion_data uses Literal["yes", "no"] instead.
    YES = "yes"
    NO = "no"


class product_promotion_data(BaseModel):
    """Collection of product and promotion data of an product advertisement."""

    # Required fields are marked explicitly with `...`; optional ones default to None.
    # Field order is kept stable because it determines the generated schema order.
    brand: str = Field(..., description="The brand associated with the product")
    product_category: list[str] = Field(..., description="List of categories associated with the product.")
    price: float = Field(..., description="The promotional price.")
    regular_price: Optional[float] = Field(default=None, description="The regular price of the promotion.")
    relative_discount: Optional[int] = Field(default=None, description="The relative discount of the promotion.")
    absolute_discount: Optional[float] = Field(default=None, description="The absolute discount of the promotion.")
    GTINs: list[str] = Field(..., description="List of the GTINs for the products.")
    weight_number: float = Field(..., description="Only the numerical weight specication.")
    # Same values as WeightUnit, expressed as a Literal.
    weight_unit: Literal["Gramm", "Kilogramm", "Milliliter", "Liter", "Waschladungen", "Blatt", "Stück"] = Field(
        ..., description="Only the weight unit."
    )
    # Same values as YesNo, expressed as a Literal.
    different_types: Literal["yes", "no"] = Field(..., description="If promotion offer different sorts.")
# Output keys that get_output_results copies into the result dict; any other
# keys produced by the model are ignored.
_RESULT_KEYS = frozenset({
    "brand",
    "product_category",
    "price",
    "regular_price",
    "relative_discount",
    "absolute_discount",
    "GTINs",
    "weight_number",
    "weight_unit",
    "different_types",
})

# Matches a whole Markdown code fence such as ```json\n...\n``` and captures its body.
_CODE_FENCE_RE = re.compile(r"^```[^\n]*\n(.*?)\n?```\s*$", re.DOTALL)

# Loose "key:" pattern used by the fallback parser (unquoted word followed by a colon).
_KEY_RE = re.compile(r'(\b\w+\b)\s*:')


def convert_items_to_strings(prediction):
    """Return *prediction* as a single string.

    Strings are returned unchanged, lists are joined with ", " (each item is
    converted with ``str`` first, so numeric list items are accepted), and
    anything else is passed through ``str``.
    """
    if isinstance(prediction, str):
        return prediction
    if isinstance(prediction, list):
        return ', '.join(map(str, prediction))
    return str(prediction)


def get_output_results(dict_output, dict_result):
    """Copy the known output fields of *dict_output* into *dict_result* as strings.

    Only keys listed in ``_RESULT_KEYS`` are copied; each value is normalised
    with :func:`convert_items_to_strings`. *dict_result* is mutated in place
    and also returned.
    """
    for key, value in dict_output.items():
        if key in _RESULT_KEYS:
            dict_result[key] = convert_items_to_strings(value)
    return dict_result


def prompt(query_image, task, dict_log):
    """Build the chat messages for one image + task request.

    Args:
        query_image: Image object placed in the user message.
        task: Task description inserted into the user prompt.
        dict_log: Dict that receives the ``'system_message'`` and
            ``'human_message_text'`` entries (mutated in place).

    Returns:
        Tuple ``(dict_log, input_messages)``, where ``input_messages`` is a
        system + user message list in the ``{"role", "content": [...]}`` format.
    """
    system_message = "You are an assistant for question-answering tasks."
    dict_log['system_message'] = system_message

    # NOTE(review): if the model was fine-tuned on this exact prompt, confirm the
    # whitespace between sentences matches the prompt used in training.
    human_message_text = (
        "Do the user-provided task on the input image. "
        "The answer must be provided in JSON format. "
        f"The task is: {task}. "
        "If there is no information of a target, return NaN."
    )
    dict_log['human_message_text'] = human_message_text

    input_messages = [
        {
            "role": "system",
            "content": [{"type": "text", "text": system_message}],
        },
        {
            "role": "user",
            "content": [
                {"type": "image", "image": query_image},
                {"type": "text", "text": human_message_text},
            ],
        },
    ]
    return dict_log, input_messages


def process_vision_info(messages: list[dict]) -> "list[Image.Image]":
    """Collect every image found in *messages*, converted to RGB.

    A content element counts as an image when it is a dict that has an
    ``"image"`` key or ``"type" == "image"``. Non-list ``content`` values are
    wrapped in a list before scanning.
    """
    image_inputs = []
    for msg in messages:
        content = msg.get("content", [])
        if not isinstance(content, list):
            content = [content]
        for element in content:
            if isinstance(element, dict) and (
                "image" in element or element.get("type") == "image"
            ):
                # NOTE(review): when only "type" == "image" is present, the dict
                # itself is used and `.convert` will fail; confirm this never happens.
                image = element["image"] if "image" in element else element
                image_inputs.append(image.convert("RGB"))
    return image_inputs


def _strip_code_fence(text):
    """Return *text* stripped of whitespace and of a surrounding ``` code fence, if any."""
    text = text.strip()
    match = _CODE_FENCE_RE.match(text)
    return match.group(1).strip() if match else text


def _parse_key_value_text(text):
    """Fallback parser for loose ``{key: value, ...}`` text with unquoted keys.

    Every value is returned as a stripped string (trailing comma removed).
    """
    trimmed = text.strip('{}').strip()
    matches = list(_KEY_RE.finditer(trimmed))
    data = {}
    for i, match in enumerate(matches):
        start = match.end()  # position right after the colon
        # The value runs until the next key (or the end of the text).
        end = matches[i + 1].start() if i + 1 < len(matches) else len(trimmed)
        data[match.group(1)] = trimmed[start:end].strip().rstrip(',').strip()
    return data


def get_dict_from_output_text(output_text):
    """Parse the first decoded model output into a dict.

    The text is first stripped of a surrounding Markdown code fence. If the
    remainder is a JSON object it is returned as parsed (values keep their
    JSON types). Otherwise the loose ``key: value`` parser is used and all
    values are strings.

    Args:
        output_text: Sequence of decoded strings; only ``output_text[0]`` is used.
    """
    import json

    text = _strip_code_fence(output_text[0])
    try:
        parsed = json.loads(text)
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        parsed = None
    if isinstance(parsed, dict):
        return parsed
    return _parse_key_value_text(text)


def do_request(ft_model, processor, pil_image, task, dict_log, dict_result, dict_result_cost):
    """Run one image + task query through the model and parse its answer.

    Args:
        ft_model: Model exposing ``device`` and ``generate(**inputs, ...)``
            (presumably a fine-tuned Hugging Face vision-language model — confirm).
        processor: Processor exposing ``apply_chat_template``, ``__call__``,
            ``tokenizer`` and ``batch_decode``.
        pil_image: Image for the query; must support ``convert("RGB")``.
        task: Task description inserted into the prompt.
        dict_log: Receives the prompt texts.
        dict_result: Receives the parsed output fields as strings.
        dict_result_cost: Receives ``'elapsed_time_[s]'`` (generation time in seconds).

    Returns:
        Tuple ``(dict_log, dict_result, dict_result_cost)``. If generation fails,
        the error is printed and the tuple is returned without parsed results.
    """
    dict_log, messages = prompt(pil_image, task, dict_log)
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to(ft_model.device)

    # NOTE(review): the token name passed here is an empty string; a special
    # token name (e.g. an end-of-turn marker) may have been lost — confirm.
    # Unknown tokens can map to None, which is dropped so generate() gets valid ids.
    stop_token_ids = [
        token_id
        for token_id in (
            processor.tokenizer.eos_token_id,
            processor.tokenizer.convert_tokens_to_ids(""),
        )
        if token_id is not None
    ]

    try:
        start_time = time.perf_counter()
        try:
            # NOTE(review): samples with temperature=0.8 although the module
            # defines VLM_TEMPERATURE = 0; confirm which setting is intended.
            generated_ids = ft_model.generate(
                **inputs,
                max_new_tokens=256,
                top_p=1.0,
                do_sample=True,
                temperature=0.8,
                eos_token_id=stop_token_ids,
                disable_compile=True,
            )
        except Exception as err:
            # Without generated ids there is nothing to decode; bail out instead
            # of falling through to an unbound `generated_ids`.
            print(f"FAILED: {err}")
            dict_result_cost['elapsed_time_[s]'] = round(time.perf_counter() - start_time, 2)
            return dict_log, dict_result, dict_result_cost

        dict_result_cost['elapsed_time_[s]'] = round(time.perf_counter() - start_time, 2)

        # Drop the prompt tokens so only newly generated tokens are decoded.
        generated_ids_trimmed = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
        if len(output_text) == 1:
            dict_output = get_dict_from_output_text(output_text)
            dict_result = get_output_results(dict_output, dict_result)
            print('dict_result')
            print(dict_result)
    except ConnectionError as e:
        print(f"Connection error occurred: {e}")
    except requests.exceptions.RequestException as e:
        print(f"An error occurred: {e}")
    except ValueError as ve:
        print(f"Validation error: {ve}")
    except httpx.HTTPStatusError as e:
        print(f"HTTPStatusError: {e}")
        time.sleep(60)
    return dict_log, dict_result, dict_result_cost