import os

import torch
from openai import OpenAI
from termcolor import colored
import transformers
# from transformers import BitsAndBytesConfig, AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import login

# environment variables and paths
from .env_utils import get_device, low_vram_mode

device = get_device()


class GPT:
    def __init__(self, model="gpt-4o-mini", api_key=None):
        # [input (prompt) $/token, output (completion) $/token], check at https://openai.com/api/pricing/
        self.prices = {
            "gpt-3.5-turbo-0125": [0.0000005, 0.0000015],
            "gpt-4o-mini": [0.00000015, 0.00000060],
            "gpt-4-1106-preview": [0.00001, 0.00003],
            "gpt-4-0125-preview": [0.00001, 0.00003],
            "gpt-4-turbo": [0.00001, 0.00003],
            "gpt-4o": [0.000005, 0.000015],
        }
        self.cheaper_model = "gpt-4o-mini"
        assert model in self.prices, "Invalid model, please choose from: {}, or add new models in the code.".format(
            list(self.prices)
        )
        self.model = model
        print(f"Using {model}")
        self.client = OpenAI(api_key=api_key)
        self.total_cost = 0.0

    def _update(self, response, price):
        # price[0] is the per-token input (prompt) price, price[1] the output (completion) price
        current_cost = response.usage.prompt_tokens * price[0] + response.usage.completion_tokens * price[1]
        self.total_cost += current_cost
        # print in 4 decimal places
        print(
            colored(
                f"Current Tokens: {response.usage.completion_tokens + response.usage.prompt_tokens:d} "
                f"Current cost: {current_cost:.4f} $, "
                f"Total cost: {self.total_cost:.4f} $",
                "yellow",
            )
        )

    def chat(self, messages, temperature=0.0, max_tokens=200, post=False):
        # set temperature to 0.0 for more deterministic results
        if post:
            # use cheaper model for post-refinement to save costs, since the task is simpler.
            generated_text = self.client.chat.completions.create(
                model=self.cheaper_model, messages=messages, temperature=temperature, max_tokens=max_tokens
            )
            self._update(generated_text, self.prices[self.cheaper_model])
        else:
            generated_text = self.client.chat.completions.create(
                model=self.model, messages=messages, temperature=temperature, max_tokens=max_tokens
            )
            self._update(generated_text, self.prices[self.model])
        generated_text = generated_text.choices[0].message.content
        return generated_text


class Llama3:
    def __init__(self, model="Meta-Llama-3-8B-Instruct"):
        login(token=os.getenv("HF_TOKEN"))
        model = "meta-llama/{}".format(model)  # or replace with your local model path
        print(f"Using {model}")
        # ZeroGPU does not support quantization.
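        # The commented-out block below sketches how 8-bit loading could be re-enabled when
        # low_vram_mode is set (this assumes bitsandbytes is installed); it is kept disabled here.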
        # tokenizer = AutoTokenizer.from_pretrained(model)
        # if low_vram_mode:
        #     model = AutoModelForCausalLM.from_pretrained(
        #         model, quantization_config=BitsAndBytesConfig(load_in_8bit=True), device_map="auto"
        #     ).eval()
        self.pipeline = transformers.pipeline(
            "text-generation",
            model=model,
            # tokenizer=tokenizer,
            model_kwargs={"torch_dtype": torch.bfloat16},
            device_map="auto",
        )
        self.terminators = [
            self.pipeline.tokenizer.eos_token_id,
            self.pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>"),
        ]

    def _update(self):
        print(colored("Using Llama-3, Free", "green"))

    def chat(self, messages, temperature=0.0, max_tokens=200, post=False):
        prompt = self.pipeline.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        generated_text = self.pipeline(
            prompt,
            max_new_tokens=max_tokens,
            eos_token_id=self.terminators,
            pad_token_id=128001,  # <|end_of_text|>
            do_sample=True,
            temperature=max(temperature, 0.01),  # 0.0 is not supported
            top_p=0.9,
        )
        self._update()
        generated_text = generated_text[0]["generated_text"][len(prompt):]
        return generated_text


# Define the timeout handler
def timeout_handler(signum, frame):
    raise TimeoutError()


def init_model(model, api_key=None):
    if "gpt" in model:
        return GPT(model=model, api_key=api_key)
    elif "Llama" in model:
        return Llama3(model=model)
    else:
        raise ValueError("Invalid model")


def _generate_example_prompt(examples, llm=None):
    # system prompt
    system_prompt = """
Task Description:
- you will provide detailed explanations for example inputs and outputs within the context of the task.

Please adhere to the following rules:
- Exclude terms that appear in both lists.
- Detail the relevance of unmatched terms from input to output, focusing on indirect relationships.
- Identify and explain terms common to all output lists but rarely present in input lists; include these at the end of the output labeled 'Recommend Include Labels'.
- Each explanation should be concise, around 50 words.

Output Format:
- '1. Input... Output... Explanation... n. Input... Output... Explanation... \n Recommend Include Labels: label1, labeln, ...'
"""
    messages = [
        {"role": "system", "content": system_prompt},
        {
            "role": "user",
            "content": f"Here are the input and output lists for which you need to provide detailed explanations:{examples.strip()}",
        },
    ]
    generated_example = llm.chat(messages, temperature=0.0, max_tokens=1000)
    return generated_example


def _make_prompt(label_list, example=None):
    Cityscape = "sidewalk" in label_list
    if Cityscape:
        add_text = f'contain at least {len(label_list.split(", "))} labels, '
    else:
        add_text = ""
    # Task description and instructions for processing the input to generate output
    system_prompt = f"""
Task Description:
- You will receive a list of caption tags accompanied by a caption text and must assign appropriate labels from a predefined label list: "{label_list}".

Instructions:
Step 1. Visualize the scene suggested by the input caption tags and text.
Step 2. Analyze each term within the overall scene to predict relevant labels from the predefined list, ensuring no term is overlooked.
Step 3. Now forget the input list and focus on the scene as a whole, expanding upon the labels to include any contextually relevant labels that complete the scene or setting.
Step 4. Compile all identified labels into a comma-separated list, adhering strictly to the specified format.

Contextually Relevant Tips:
- Equivalencies include converting "girl, man" to "person" and "flower, vase" to "potted plant", while "bicycle, motorcycle" suggest "rider".
- An outdoor scene may include labels like "sky", "tree", "clouds", "terrain".
- An urban scene may imply "bus", "bicycle", "road", "sidewalk", "building", "pole", "traffic-light", "traffic-sign".

Output:
- Do not output any explanations other than the final label list.
- The final output should {add_text}strictly adhere to the specified format: label1, label2, ... labeln
""".strip()
    if example:
        system_prompt += f"""

Additional Examples with Detailed Explanations:
{example}
"""
    print("system_prompt: ", system_prompt)
    return system_prompt


# - You will receive a list of terms accompanied by a caption text and must assign appropriate labels from a predefined label list: "{label_list}".
# Instructions:
# Step 1. Visualize the scene suggested by the input list and caption text.


def make_prompt(label_list):
    # Create a new system prompt using the label list and the improved example prompt
    system_prompt = _make_prompt(label_list)
    system_prompt = {"role": "system", "content": system_prompt.strip()}
    print("system_prompt: ", system_prompt)
    return system_prompt


def _call_llm(system_prompt, llm, user_input):
    messages = [system_prompt, {"role": "user", "content": "Here are input caption tags and text: " + user_input}]
    converted_label = llm.chat(messages=messages, temperature=0.0, max_tokens=200)
    return converted_label


def pre_refinement(user_input_list, system_prompt, llm=None):
    llm_outputs = [_call_llm(system_prompt, llm, user_input) for user_input in user_input_list]
    converted_labels = [
        f"{user_input_}, {converted_label}" for user_input_, converted_label in zip(user_input_list, llm_outputs)
    ]
    return converted_labels, llm_outputs


def post_refinement(label_list, detected_label, llm=None):
    if detected_label == "":
        return ""
    system_input = f"""
Task Description:
- You will receive a specific phrase and must assign an appropriate label from the predefined label list: "{label_list}".

Please adhere to the following rules:
- Select and return only one relevant label from the predefined label list that corresponds to the given phrase.
- Do not include any additional information or context beyond the label itself.
- Format is purely the label itself, without any additional punctuation or formatting.
"""
    system_input = {"role": "system", "content": system_input}
    messages = [system_input, {"role": "user", "content": detected_label}]
    generated_label = None
    for count in range(3):
        # retry with a slightly higher temperature if the model returns an empty label
        generated_label = llm.chat(messages=messages, temperature=0.0 if count == 0 else 0.1 * count, post=True)
        if generated_label != "":
            break
    return generated_label


if __name__ == "__main__":
    # test the functions
    llm = Llama3(model="Meta-Llama-3-8B-Instruct")
    # make_prompt only takes the label list; the LLM is passed to the refinement calls below
    system_prompt = make_prompt("person, car, tree, sky, road, building, sidewalk, traffic-light, traffic-sign")
    converted_labels, llm_outputs = pre_refinement(["person, car, road, traffic-light"], system_prompt, llm=llm)
    print("converted_labels: ", converted_labels)
    print("llm_outputs: ", llm_outputs)
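
    # A minimal sketch of the post-refinement step, assuming a detected free-form phrase such as
    # "stoplight" (hypothetical input) that still needs to be mapped onto one predefined label.
    refined_label = post_refinement(
        "person, car, tree, sky, road, building, sidewalk, traffic-light, traffic-sign",
        "stoplight",
        llm=llm,
    )
    print("refined_label: ", refined_label)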