import os
import sys
import inspect
import json
from json import JSONDecodeError
import tiktoken
import random
import google.generativeai as palm

# Make the parent directory importable so the project-local modules below resolve.
currentdir = os.path.dirname(os.path.abspath(
    inspect.getfile(inspect.currentframe())))
parentdir = os.path.dirname(currentdir)
sys.path.append(parentdir)

from prompt_catalog import PromptCatalog
from general_utils import num_tokens_from_string

# DEPRECATED (numeric thresholds): safety settings regularly block a response,
# so the threshold is set to 4 / BLOCK_NONE to disable blocking.
# class HarmBlockThreshold(Enum):
#     HARM_BLOCK_THRESHOLD_UNSPECIFIED = 0
#     BLOCK_LOW_AND_ABOVE = 1
#     BLOCK_MEDIUM_AND_ABOVE = 2
#     BLOCK_ONLY_HIGH = 3
#     BLOCK_NONE = 4

_HARM_CATEGORIES = (
    "HARM_CATEGORY_DEROGATORY",
    "HARM_CATEGORY_TOXICITY",
    "HARM_CATEGORY_VIOLENCE",
    "HARM_CATEGORY_SEXUAL",
    "HARM_CATEGORY_MEDICAL",
    "HARM_CATEGORY_DANGEROUS",
)

# One entry per harm category, all with blocking disabled.
SAFETY_SETTINGS = [
    {"category": category, "threshold": "BLOCK_NONE"}
    for category in _HARM_CATEGORIES
]

# Settings for the first generation call and the first retry.
PALM_SETTINGS = {
    'model': 'models/text-bison-001',
    'temperature': 0,
    'candidate_count': 1,
    'top_k': 40,
    'top_p': 0.95,
    'max_output_tokens': 8000,
    'stop_sequences': [],
    'safety_settings': SAFETY_SETTINGS,
}

# Settings for the final retry: identical except for a slightly higher temperature.
PALM_SETTINGS_REDO = {**PALM_SETTINGS, 'temperature': 0.05}


def OCR_to_dict_PaLM(logger, OCR, prompt_version, VVE):
    """Build a PaLM prompt for the OCR text, call PaLM and parse the JSON reply.

    Args:
        logger: Object exposing ``info(msg)``.
        OCR: Raw OCR text to convert.
        prompt_version: Name of the prompt to use. ``'prompt_v2_palm2'``,
            ``'prompt_v1_palm2'`` and ``'prompt_v1_palm2_noDomainKnowledge'``
            select built-in prompts; any other value is forwarded to
            ``PromptCatalog.prompt_v2_custom``.
        VVE: Object providing ``query_db(OCR, 4)`` and ``get_similarity()``;
            only used for ``'prompt_v1_palm2'``.

    Returns:
        ``(response_valid, nt)`` where ``response_valid`` is the parsed JSON,
        ``{}`` when PaLM returned no usable text, or ``None`` when every
        parse/redo attempt failed, and ``nt`` is the prompt token count.
    """
    try:
        logger.info(f'Length of OCR raw -- {len(OCR)}')
    except Exception:
        # Best-effort fallback when the logger itself is unusable.
        print(f'Length of OCR raw -- {len(OCR)}')

    # prompt = PROMPT_PaLM_UMICH_skeleton_all_asia(OCR, in_list, out_list)
    # must provide examples to PaLM differently than for chatGPT, at least 2 examples
    catalog = PromptCatalog(OCR)
    if prompt_version in ['prompt_v2_palm2']:
        version = 'v2'
        prompt = catalog.prompt_v2_palm2(OCR)
    elif prompt_version in ['prompt_v1_palm2']:
        version = 'v1'
        # create input: output: for PaLM
        # Find a similar example from the domain knowledge
        domain_knowledge_example = VVE.query_db(OCR, 4)
        # Return value is unused here; the call is kept in case it has side
        # effects on VVE -- TODO confirm and remove if it is a pure getter.
        VVE.get_similarity()
        in_list, out_list = create_OCR_analog_for_input(domain_knowledge_example)
        prompt = catalog.prompt_v1_palm2(in_list, out_list, OCR)
    elif prompt_version in ['prompt_v1_palm2_noDomainKnowledge']:
        version = 'v1'
        prompt = catalog.prompt_v1_palm2_noDomainKnowledge(OCR)
    else:
        version = 'custom'
        prompt, _n_fields, _xlsx_headers = catalog.prompt_v2_custom(
            prompt_version, OCR=OCR, is_palm=True)

    nt = num_tokens_from_string(prompt, "cl100k_base")
    logger.info(f'Prompt token length --- {nt}')

    do_use_SOP = False
    if do_use_SOP:
        # TODO: Check back later to see if LangChain will support PaLM
        # logger.info(f'Waiting for PaLM API call --- Using StructuredOutputParser')
        # response = structured_output_parser(OCR, prompt, logger)
        # return response['Dictionary']
        return None

    logger.info('Waiting for PaLM 2 API call')
    response = palm.generate_text(prompt=prompt, **PALM_SETTINGS)

    # ``response`` may be falsy, so never dereference it unguarded.
    result = response.result if response else None
    if result and isinstance(result, (str, bytes)):
        response_valid = check_and_redo_JSON(response, logger, version)
    else:
        response_valid = {}

    logger.info(f'Candidate JSON\n{result}')
    return response_valid, nt


def _strip_code_fence(text):
    """Remove a surrounding Markdown code fence and leading ``json`` tag."""
    cleaned = text.strip().strip('`').strip()
    # Only a *leading* language tag is dropped; a valid JSON document can never
    # start with "json", and occurrences inside the payload must be preserved.
    if cleaned[:4].lower() == 'json':
        cleaned = cleaned[4:]
    return cleaned.strip()


def _loads_lenient(result):
    """Parse ``result`` as JSON, retrying once with code fences removed.

    Returns ``(parsed, fence_removed)``. Raises ``ValueError`` (which includes
    ``JSONDecodeError``) or ``TypeError`` when ``result`` cannot be parsed.
    """
    try:
        return json.loads(result), False
    except ValueError:
        if isinstance(result, (bytes, bytearray)):
            result = result.decode('utf-8', errors='replace')
        return json.loads(_strip_code_fence(result)), True


def check_and_redo_JSON(response, logger, version):
    """Parse ``response.result`` as JSON, asking PaLM to repair it if needed.

    The result is parsed as-is, then with Markdown code fences removed. If both
    fail, a redo prompt chosen by ``version`` (``'v1'``, ``'v2'`` or
    ``'custom'``) is sent with ``PALM_SETTINGS`` and then, as a last attempt,
    with ``PALM_SETTINGS_REDO``.

    Args:
        response: PaLM response object exposing ``result``.
        logger: Object exposing ``info(msg)``.
        version: Prompt family used for the original call.

    Returns:
        The parsed JSON value, or ``None`` when no attempt produced valid JSON.
    """
    try:
        parsed, fence_removed = _loads_lenient(response.result)
    except (ValueError, TypeError):
        logger.info('Response --- First call failed. Redo...')
    else:
        if fence_removed:
            logger.info('Response --- Manual removal of ```json succeeded')
        else:
            logger.info('Response --- First call passed')
        return parsed

    catalog = PromptCatalog()
    if version == 'v1':
        prompt_redo = catalog.prompt_palm_redo_v1(response.result)
    elif version == 'v2':
        prompt_redo = catalog.prompt_palm_redo_v2(response.result)
    elif version == 'custom':
        prompt_redo = catalog.prompt_v2_custom_redo(response.result, is_palm=True)
    else:
        # No redo prompt exists for this version; report failure instead of
        # crashing on an unbound ``prompt_redo``.
        logger.info(f'Response --- Unknown prompt version {version!r}. Redo skipped')
        return None

    # (settings, message on success, message on failure)
    attempts = (
        (PALM_SETTINGS,
         'Response --- Second call passed',
         'Response --- Second call failed. Final redo. Temperature changed to 0.05'),
        (PALM_SETTINGS_REDO,
         'Response --- Third call passed',
         'Response --- Third call failed'),
    )
    for settings, passed_msg, failed_msg in attempts:
        retry = palm.generate_text(prompt=prompt_redo, **settings)
        try:
            # ``result`` can be None (e.g. no candidate returned); that raises
            # TypeError in json.loads and is treated as a failed attempt.
            parsed, _ = _loads_lenient(getattr(retry, 'result', None))
        except (ValueError, TypeError):
            logger.info(failed_msg)
            continue
        logger.info(passed_msg)
        return parsed
    return None


def create_OCR_analog_for_input(domain_knowledge_example):
    """Create pseudo-OCR inputs and JSON outputs from example records.

    Args:
        domain_knowledge_example: Iterable of dictionaries.

    Returns:
        ``(in_list, out_list)``: for each dictionary, ``out_list`` holds its
        JSON string and ``in_list`` holds its values, stringified, randomly
        shuffled and joined with single spaces.
    """
    in_list = []
    out_list = []
    for row_dict in domain_knowledge_example:
        out_list.append(json.dumps(row_dict))
        # Shuffle whole values; a value containing '||' must stay intact.
        parts = [str(value) for value in row_dict.values()]
        random.shuffle(parts)
        in_list.append(' '.join(parts))
    return in_list, out_list


def strip_problematic_chars(s):
    """Return ``s`` with every non-printable character removed."""
    return ''.join(c for c in s if c.isprintable())