Spaces:
Running
Running
import os, time, json, typing | |
# import vertexai | |
from vertexai.language_models import TextGenerationModel | |
from vertexai.generative_models._generative_models import HarmCategory, HarmBlockThreshold | |
from vertexai.language_models import TextGenerationModel | |
# from vertexai.preview.generative_models import GenerativeModel | |
from langchain.output_parsers import RetryWithErrorOutputParser | |
# from langchain.schema import HumanMessage | |
from langchain.prompts import PromptTemplate | |
from langchain_core.output_parsers import JsonOutputParser | |
# from langchain_google_genai import ChatGoogleGenerativeAI | |
from langchain_google_vertexai import VertexAI | |
from langchain_core.messages import BaseMessage, HumanMessage | |
from langchain_core.prompt_values import PromptValue as BasePromptValue | |
from vouchervision.utils_LLM import SystemLoadMonitor, run_tools, count_tokens, save_individual_prompt, sanitize_prompt | |
from vouchervision.utils_LLM_JSON_validation import validate_and_align_JSON_keys_with_template | |
# Vertex AI Python SDK docs: https://cloud.google.com/vertex-ai/docs/python-sdk/use-vertex-ai-python-sdk
# Install with: pip install --upgrade google-cloud-aiplatform
# from google.cloud import aiplatform
# NOTE: gcloud must be authenticated before this module can call Vertex AI:
#   gcloud auth login
#   gcloud config set project XXXXXXXXX
# See https://cloud.google.com/docs/authentication
class GooglePalm2Handler:
    """Query a Vertex AI text-generation model and return template-aligned JSON.

    The handler builds a ``prompt | call_google_palm2`` chain, parses the raw
    completion as JSON (with a retrying parser first and ``json.loads`` as the
    last resort), aligns the keys with ``JSON_dict_structure`` and then runs
    the configured post-processing tools. When an attempt fails, the sampling
    temperature is raised by ``temp_increment`` before the next attempt and
    restored to ``starting_temp`` once the call finishes.
    """

    RETRY_DELAY = 10  # Wait 10 seconds before retrying
    MAX_RETRIES = 3  # Maximum number of retries
    TOKENIZER_NAME = 'gpt-4'
    VENDOR = 'google'
    STARTING_TEMP = 0.5

    # Progress reporter for the request in flight. It is replaced per call in
    # call_llm_api_GooglePalm2; the class-level default keeps _adjust_config()
    # and _reset_config() usable before the first request has been made.
    json_report = None

    def __init__(self, cfg, logger, model_name, JSON_dict_structure, config_vals_for_permutation):
        """Store the settings, build the prompt template and configure the model.

        Args:
            cfg: Project configuration; the ``tool_WFO``, ``tool_GEO`` and
                ``tool_wikipedia`` values are read from
                ``cfg['leafmachine']['project']``.
            logger: Logger exposing ``info`` and ``error``.
            model_name: Name of the Vertex AI model to call.
            JSON_dict_structure: Template the parsed output is aligned with.
            config_vals_for_permutation: Optional mapping whose ``'google'``
                entry overrides the default generation settings; a falsy
                value selects the defaults.
        """
        self.cfg = cfg
        project_cfg = self.cfg['leafmachine']['project']
        self.tool_WFO = project_cfg['tool_WFO']
        self.tool_GEO = project_cfg['tool_GEO']
        self.tool_wikipedia = project_cfg['tool_wikipedia']

        self.logger = logger
        self.model_name = model_name
        self.JSON_dict_structure = JSON_dict_structure
        self.config_vals_for_permutation = config_vals_for_permutation

        self.monitor = SystemLoadMonitor(logger)
        self.parser = JsonOutputParser()

        # Define the prompt template
        self.prompt = PromptTemplate(
            template="Answer the user query.\n{format_instructions}\n{query}\n",
            input_variables=["query"],
            partial_variables={"format_instructions": self.parser.get_format_instructions()},
        )
        self._set_config()

    def _report(self, text):
        """Forward ``text`` to the active ``json_report``, if one is attached."""
        if self.json_report:
            self.json_report.set_text(text_main=text)

    def _set_config(self):
        """Initialise generation settings, temperature state and the model chain."""
        # vertexai.init(project=os.environ['PALM_PROJECT_ID'], location=os.environ['PALM_LOCATION'])
        if self.config_vals_for_permutation:
            google_cfg = self.config_vals_for_permutation.get('google')
            self.starting_temp = float(google_cfg.get('temperature'))
            self.config = {
                'max_output_tokens': google_cfg.get('max_output_tokens'),
                'temperature': self.starting_temp,
                'top_k': google_cfg.get('top_k'),
                'top_p': google_cfg.get('top_p'),
            }
        else:
            self.starting_temp = float(self.STARTING_TEMP)
            self.config = {
                "max_output_tokens": 1024,
                "temperature": self.starting_temp,
                "top_k": 1,
                "top_p": 1.0,
            }

        self.temp_increment = float(0.2)
        self.adjust_temp = self.starting_temp

        # Not passed to TextGenerationModel.predict below; retained for the
        # GenerativeModel.generate_content path that is currently disabled.
        self.safety_settings = {
            HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
            HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
            HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
            HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
        }
        self._build_model_chain_parser()

    def _adjust_config(self):
        """Raise the temperature by ``temp_increment`` for the next attempt."""
        new_temp = self.adjust_temp + self.temp_increment
        message = f'Incrementing temperature from {self.adjust_temp} to {new_temp}'
        self._report(message)
        self.logger.info(message)
        self.adjust_temp = new_temp
        self.config['temperature'] = new_temp

    def _reset_config(self):
        """Restore the temperature to ``starting_temp``."""
        message = f'Resetting temperature from {self.adjust_temp} to {self.starting_temp}'
        self._report(message)
        self.logger.info(message)
        self.adjust_temp = self.starting_temp
        self.config['temperature'] = self.starting_temp

    def _build_model_chain_parser(self):
        """Create the LLM used by the retry parser, the retry parser and the chain."""
        # self.llm_model = ChatGoogleGenerativeAI(model=self.model_name)
        self.llm_model = VertexAI(
            model=self.model_name,
            max_output_tokens=self.config.get('max_output_tokens'),
            temperature=self.config.get('temperature'),
            top_k=self.config.get('top_k'),
            top_p=self.config.get('top_p'),
        )
        self.retry_parser = RetryWithErrorOutputParser.from_llm(
            parser=self.parser,
            llm=self.llm_model,
            max_retries=self.MAX_RETRIES,
        )
        # The chain formats the prompt, then hands the prompt value to
        # call_google_palm2, which performs the actual model request.
        self.chain = self.prompt | self.call_google_palm2

    # Formats the input for the Google PaLM call.
    # Migration notes: https://cloud.google.com/vertex-ai/docs/generative-ai/migrate/migrate-palm-to-gemini
    def call_google_palm2(self, prompt_text):
        """Send ``prompt_text.text`` to the model and return the completion text.

        The generation settings are read from ``self.config`` at call time, so
        temperature changes made by ``_adjust_config`` apply to the next call.
        """
        model = TextGenerationModel.from_pretrained(self.model_name)
        response = model.predict(
            prompt_text.text,
            max_output_tokens=self.config.get('max_output_tokens'),
            temperature=self.config.get('temperature'),
            top_k=self.config.get('top_k'),
            top_p=self.config.get('top_p'),
        )
        # model = GenerativeModel(self.model_name)
        # response = model.generate_content(prompt_text.text, generation_config=self.config, safety_settings=self.safety_settings, stream=False)
        return response.text

    def _parse_response(self, response, prompt_template):
        """Return the JSON parsed from ``response``, or ``None`` if every strategy fails.

        Strategies, in order: the retry parser with a ``PromptValue``; the
        retry parser with the raw prompt string; plain ``json.loads``.
        """
        try:
            # PromptValue is a pydantic model: its field must be passed by
            # keyword, a positional argument is rejected at construction.
            prompt_value = PromptValue(prompt_str=prompt_template)
            return self.retry_parser.parse_with_prompt(response, prompt_value=prompt_value)
        except Exception as e:
            self.logger.info(f'Retry parser with PromptValue failed ({e}); trying the raw prompt string')

        try:
            return self.retry_parser.parse_with_prompt(response, prompt_value=prompt_template)
        except Exception as e:
            self.logger.info(f'Retry parser with raw prompt failed ({e}); trying json.loads')

        try:
            return json.loads(response)
        except Exception as e:
            self.logger.error(f'{e}')
            return None

    def call_llm_api_GooglePalm2(self, prompt_template, json_report, paths):
        """Run the prompt through the model, retrying up to ``MAX_RETRIES`` times.

        Args:
            prompt_template: Prompt text substituted into the chain's ``query``.
            json_report: Optional progress reporter exposing
                ``set_text(text_main=...)``; a falsy value disables reporting.
            paths: Sequence of exactly seven items. Only the last two are
                used: the path handed to ``run_tools`` and the path the
                sanitized prompt is saved to.

        Returns:
            ``(output, nt_in, nt_out, WFO_record, GEO_record, usage_report)``
            on success. If every attempt fails, ``output``, ``WFO_record`` and
            ``GEO_record`` are ``None``; the token counts are 0 unless some
            attempt parsed JSON that then failed key alignment.
        """
        _, _, _, _, _, json_file_path_wiki, txt_file_path_ind_prompt = paths
        self.json_report = json_report
        self._report(f'Sending request to {self.model_name}')
        self.monitor.start_monitoring_usage()

        nt_in = 0
        nt_out = 0
        ind = 0  # stays 0 in the failure message if MAX_RETRIES is 0
        for ind in range(1, self.MAX_RETRIES + 1):
            try:
                response = self.chain.invoke({"query": prompt_template})
                output = self._parse_response(response, prompt_template)

                if output is not None:
                    nt_in = count_tokens(prompt_template, self.VENDOR, self.TOKENIZER_NAME)
                    nt_out = count_tokens(response, self.VENDOR, self.TOKENIZER_NAME)
                    output = validate_and_align_JSON_keys_with_template(output, self.JSON_dict_structure)

                if output is None:
                    # Either parsing or key alignment failed; retry right away
                    # at a higher temperature.
                    self.logger.error(f'[Attempt {ind}] Failed to extract JSON from:\n{response}')
                    self._adjust_config()
                    continue

                self.monitor.stop_inference_timer()  # Starts tool timer too
                self._report('Working on WFO, Geolocation, Links')
                _, WFO_record, _, GEO_record = run_tools(
                    output, self.tool_WFO, self.tool_GEO, self.tool_wikipedia, json_file_path_wiki)
                save_individual_prompt(sanitize_prompt(prompt_template), txt_file_path_ind_prompt)
                self.logger.info(f"Formatted JSON:\n{json.dumps(output,indent=4)}")

                usage_report = self.monitor.stop_monitoring_report_usage()
                if self.adjust_temp != self.starting_temp:
                    self._reset_config()
                self._report('LLM call successful')
                return output, nt_in, nt_out, WFO_record, GEO_record, usage_report

            except Exception as e:
                self.logger.error(f'{e}')
                self._adjust_config()
                # Waiting only helps when another attempt follows.
                if ind < self.MAX_RETRIES:
                    time.sleep(self.RETRY_DELAY)

        failure_message = f'Failed to extract valid JSON after [{ind}] attempts'
        self.logger.info(f"Failed to extract valid JSON after [{ind}] attempts")
        self._report(failure_message)

        self.monitor.stop_inference_timer()  # Starts tool timer too
        usage_report = self.monitor.stop_monitoring_report_usage()
        self._reset_config()
        self._report('LLM call failed')
        return None, nt_in, nt_out, None, None, usage_report
class PromptValue(BasePromptValue):
    """String-backed prompt value for APIs that expect a prompt-value object.

    The base class declares both ``to_string`` and ``to_messages`` as abstract,
    so both must be implemented here for the class to be instantiable. As a
    pydantic model it takes its field by keyword:
    ``PromptValue(prompt_str="...")``.
    """

    # Raw prompt text returned by to_string() and wrapped by to_messages().
    prompt_str: str

    def to_string(self) -> str:
        """Return the prompt text unchanged."""
        return self.prompt_str

    def to_messages(self) -> typing.List[BaseMessage]:
        """Return the prompt text as a single human message."""
        return [HumanMessage(content=self.prompt_str)]