VoucherVision / vouchervision /LLM_local_MistralAI_batch_async.py
phyloforfun's picture
Major update. Support for 15 LLMs, World Flora Online taxonomy validation, geolocation, 2 OCR methods, significant UI changes, stability improvements, consistent JSON parsing
e91ac58
raw
history blame
8.42 kB
import json, torch, transformers, gc
from transformers import BitsAndBytesConfig
from langchain.output_parsers import RetryWithErrorOutputParser
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser
from huggingface_hub import hf_hub_download
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
import asyncio
from utils_LLM import validate_and_align_JSON_keys_with_template, count_tokens, validate_taxonomy_WFO, validate_coordinates_here, remove_colons_and_double_apostrophes, SystemLoadMonitor
'''
https://python.langchain.com/docs/integrations/llms/huggingface_pipelines
'''
from torch.utils.data import Dataset, DataLoader
# Dataset for handling prompts
class PromptDataset(Dataset):
    """Map-style dataset that hands out stored prompts by index.

    Wrapping the prompt list lets a ``DataLoader`` group prompts into batches.
    """

    def __init__(self, prompts):
        # Keep the caller's sequence as-is; items are returned unchanged.
        self.prompts = prompts

    def __getitem__(self, idx):
        prompt = self.prompts[idx]
        return prompt

    def __len__(self):
        count = len(self.prompts)
        return count
class LocalMistralHandler:
    """Run a local Hugging Face ``mistralai/*`` model and parse its replies as JSON.

    Each query is wrapped in a Mistral ``[INST]`` prompt and sent through a
    ``transformers`` text-generation pipeline loaded with a 4-bit bitsandbytes
    quantization config. The reply is parsed with a LangChain JSON parser, then
    aligned with ``JSON_dict_structure`` and passed through the WFO-taxonomy and
    coordinate validators. When parsing fails, the sampling temperature is
    raised and the model is reloaded before the next attempt.
    """

    RETRY_DELAY = 2  # seconds; NOTE(review): not referenced by this class
    MAX_RETRIES = 5  # Maximum number of attempts (also passed to the retry parser)
    STARTING_TEMP = 0.1
    TOKENIZER_NAME = None
    VENDOR = 'mistral'
    MAX_GPU_MONITORING_INTERVAL = 2  # seconds; NOTE(review): not referenced by this class

    def __init__(self, logger, model_name, JSON_dict_structure):
        """Fetch the model config, build the prompt/parser, and load the model.

        Args:
            logger: Logger used for progress and error messages.
            model_name: Repository name under the ``mistralai`` Hugging Face org.
            JSON_dict_structure: Template whose keys every parsed output is
                aligned to.
        """
        self.logger = logger
        self.has_GPU = torch.cuda.is_available()
        self.monitor = SystemLoadMonitor(logger)

        self.model_name = model_name
        self.model_id = f"mistralai/{self.model_name}"
        self.model_path = hf_hub_download(repo_id=self.model_id, repo_type="model", filename="config.json")
        self.JSON_dict_structure = JSON_dict_structure

        self.starting_temp = float(self.STARTING_TEMP)
        self.temp_increment = 0.2
        self.adjust_temp = self.starting_temp

        system_prompt = "You are a helpful AI assistant who answers queries a JSON dictionary as specified by the user."
        # The system prompt goes in its own instruction turn. "{query}" is
        # re-inserted literally so PromptTemplate can fill it in later.
        template = (
            "\n"
            "<s>[INST]{}[/INST]</s>\n"
            "[INST]{}[/INST]\n"
        ).format(system_prompt, "{query}")

        # Create a prompt from the template so we can use it with LangChain
        self.prompt = PromptTemplate(template=template, input_variables=["query"])
        # Parser that turns the model's text reply into a JSON object
        self.parser = JsonOutputParser()

        self._set_config()

    def _clear_VRAM(self):
        """Drop every reference to the loaded model and reclaim its memory.

        ``self.chain`` and ``self.retry_parser`` both wrap ``self.local_model``.
        They must be released as well; otherwise the old model stays alive
        until its replacement has already been loaded.
        """
        self.chain = None
        self.retry_parser = None
        self.local_model = None
        self.local_model_pipeline = None
        gc.collect()  # Explicitly invoke garbage collector
        if self.has_GPU:
            # Release unused cached CUDA memory held by PyTorch
            torch.cuda.empty_cache()

    def _set_config(self):
        """(Re)load the model with the default generation settings.

        Resets the sampling temperature, including the running ``adjust_temp``
        counter, to ``starting_temp``. It then builds the 4-bit quantization
        config, picks the weight dtype, and rebuilds the model, chain and parser.
        """
        self._clear_VRAM()
        # Reset the retry temperature so the next retries start from the base value
        self.adjust_temp = self.starting_temp

        self.config = {
            'max_new_tokens': 1024,
            'temperature': self.starting_temp,
            'seed': 2023,
            'top_p': 1,
            'top_k': 40,
            'do_sample': True,
            'n_ctx': 4096,
            # Activate 4-bit precision base model loading
            'use_4bit': True,
            # Compute dtype for 4-bit base models
            'bnb_4bit_compute_dtype': "float16",
            # Quantization type (fp4 or nf4)
            'bnb_4bit_quant_type': "nf4",
            # Activate nested quantization for 4-bit base models (double quantization)
            'use_nested_quant': False,
        }

        compute_dtype = getattr(torch, self.config.get('bnb_4bit_compute_dtype'))
        self.bnb_config = BitsAndBytesConfig(
            load_in_4bit=self.config.get('use_4bit'),
            bnb_4bit_quant_type=self.config.get('bnb_4bit_quant_type'),
            bnb_4bit_compute_dtype=compute_dtype,
            bnb_4bit_use_double_quant=self.config.get('use_nested_quant'),
        )

        # Default to float16. Use bfloat16 only when a GPU with compute
        # capability >= 8 is present. Querying the capability without a GPU
        # raises, so guard on has_GPU.
        self.b_float_opt = torch.float16
        if self.has_GPU and compute_dtype == torch.float16 and self.config.get('use_4bit'):
            major, _ = torch.cuda.get_device_capability()
            if major >= 8:
                self.b_float_opt = torch.bfloat16

        self._build_model_chain_parser()

    def _adjust_config(self):
        """Raise the sampling temperature by ``temp_increment`` and reload the model."""
        self.logger.info('Incrementing temperature and reloading model')
        self._clear_VRAM()
        self.adjust_temp += self.temp_increment
        self.config['temperature'] = self.adjust_temp
        self._build_model_chain_parser()

    def _build_model_chain_parser(self):
        """Load the text-generation pipeline and wire up the chain and retry parser."""
        self.local_model_pipeline = transformers.pipeline(
            "text-generation",
            model=self.model_id,
            max_new_tokens=self.config.get('max_new_tokens'),
            temperature=self.config.get('temperature'),
            top_k=self.config.get('top_k'),
            top_p=self.config.get('top_p'),
            do_sample=self.config.get('do_sample'),
            # 4-bit loading comes solely from quantization_config (its
            # load_in_4bit is already set). Passing load_in_4bit alongside it is
            # rejected by newer transformers releases.
            model_kwargs={
                "torch_dtype": self.b_float_opt,
                "quantization_config": self.bnb_config,
            },
        )
        self.local_model = HuggingFacePipeline(pipeline=self.local_model_pipeline)

        # Set up the retry parser with the same LLM
        self.retry_parser = RetryWithErrorOutputParser.from_llm(parser=self.parser, llm=self.local_model, max_retries=self.MAX_RETRIES)
        # Create an LLM chain with the prompt and LLM
        self.chain = self.prompt | self.local_model

    def call_llm_local_MistralAI(self, prompts, batch_size=2):
        """Run every prompt through the model and collect one result per prompt.

        Args:
            prompts: Queries to fill the template's ``query`` slot. Presumably
                strings, which the default DataLoader collate keeps as lists.
            batch_size: Number of prompts grouped per DataLoader batch.

        Returns:
            List of ``(output, nt_in, nt_out, WFO_record, GEO_record)`` tuples in
            input order. For a prompt that never yields valid JSON the tuple is
            ``(None, 0, 0, None, None)``.

        NOTE: model generation inside each coroutine is blocking, so prompts are
        effectively processed one at a time even within a batch. Because
        ``asyncio.run`` is used, this cannot be called from a running event loop.
        """
        dataset = PromptDataset(prompts)
        data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
        all_results = asyncio.run(self._process_all_batches(data_loader))
        # Retries may have raised the temperature; restore the defaults for the next call
        if self.adjust_temp != self.starting_temp:
            self._set_config()
        return all_results

    async def _process_batch(self, batch_prompts):
        """Run one coroutine per prompt in the batch and return results in order."""
        tasks = [self._process_single_prompt(prompt) for prompt in batch_prompts]
        return await asyncio.gather(*tasks)

    async def _process_all_batches(self, data_loader):
        """Process the batches one after another and flatten their results."""
        results = []
        for batch_prompts in data_loader:
            batch_results = await self._process_batch(batch_prompts)
            results.extend(batch_results)
        return results

    async def _process_single_prompt(self, prompt_template):
        """Generate and parse one reply, retrying with a hotter model on failure.

        Args:
            prompt_template: Query text substituted into the prompt template.

        Returns:
            ``(output, nt_in, nt_out, WFO_record, GEO_record)`` on success. Returns
            ``(None, 0, 0, None, None)`` when all ``MAX_RETRIES`` attempts fail.
        """
        self.monitor.start_monitoring_usage()
        try:
            # The retry parser needs a PromptValue (it calls .to_string() on it),
            # not the raw query. This is the exact prompt the model receives.
            prompt_value = self.prompt.format_prompt(query=prompt_template)

            for _ in range(self.MAX_RETRIES):
                results = self.chain.invoke({"query": prompt_template})
                try:
                    output = self.retry_parser.parse_with_prompt(results, prompt_value=prompt_value)
                except ValueError as err:
                    # LangChain's OutputParserException subclasses ValueError. It is
                    # raised once the retry parser's own retries are exhausted.
                    self.logger.error(f'JSON parsing failed: {err!r}')
                    output = None

                if output is None:
                    self.logger.error(f'Failed to extract JSON from:\n{results}')
                    self._adjust_config()
                    continue

                nt_in = count_tokens(prompt_template, self.VENDOR, self.TOKENIZER_NAME)
                nt_out = count_tokens(results, self.VENDOR, self.TOKENIZER_NAME)

                output = validate_and_align_JSON_keys_with_template(output, self.JSON_dict_structure)
                output, WFO_record = validate_taxonomy_WFO(output, replace_if_success_wfo=False)
                output, GEO_record = validate_coordinates_here(output, replace_if_success_geo=False)

                self.logger.info(f"Formatted JSON:\n{json.dumps(output, indent=4)}")
                return output, nt_in, nt_out, WFO_record, GEO_record

            self.logger.info(f"Failed to extract valid JSON after [{self.MAX_RETRIES}] attempts")
            return None, 0, 0, None, None
        finally:
            # Always stop monitoring, even if generation raised
            self.monitor.stop_monitoring_report_usage()