import json, torch, transformers, gc
from transformers import BitsAndBytesConfig
from langchain.output_parsers import RetryWithErrorOutputParser
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser
from huggingface_hub import hf_hub_download
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
import asyncio
from utils_LLM import validate_and_align_JSON_keys_with_template, count_tokens, validate_taxonomy_WFO, validate_coordinates_here, remove_colons_and_double_apostrophes, SystemLoadMonitor
'''
https://python.langchain.com/docs/integrations/llms/huggingface_pipelines
'''
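# This module provides a handler for running a local Mistral instruct model: the checkpoint is loaded
# in 4-bit via transformers + bitsandbytes, wrapped in a LangChain HuggingFacePipeline, and queried
# through a prompt | model chain whose output is parsed into a JSON dictionary, with retries at
# increasing temperature when parsing fails.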
from torch.utils.data import Dataset, DataLoader
# Dataset for handling prompts
class PromptDataset(Dataset):
    def __init__(self, prompts):
        self.prompts = prompts

    def __len__(self):
        return len(self.prompts)

    def __getitem__(self, idx):
        return self.prompts[idx]
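# Note: when wrapped in a DataLoader, the default collate_fn gathers string items into a plain
# Python list, so each batch yields a list of prompt strings.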
class LocalMistralHandler:
    RETRY_DELAY = 2  # Wait 2 seconds before retrying
    MAX_RETRIES = 5  # Maximum number of retries
    STARTING_TEMP = 0.1
    TOKENIZER_NAME = None
    VENDOR = 'mistral'
    MAX_GPU_MONITORING_INTERVAL = 2  # seconds
    def __init__(self, logger, model_name, JSON_dict_structure):
        self.logger = logger
        self.has_GPU = torch.cuda.is_available()
        self.monitor = SystemLoadMonitor(logger)

        self.model_name = model_name
        self.model_id = f"mistralai/{self.model_name}"
        name_parts = self.model_name.split('-')

        self.model_path = hf_hub_download(repo_id=self.model_id, repo_type="model", filename="config.json")

        self.JSON_dict_structure = JSON_dict_structure
        self.starting_temp = float(self.STARTING_TEMP)
        self.temp_increment = float(0.2)
        self.adjust_temp = self.starting_temp

        system_prompt = "You are a helpful AI assistant who answers queries with a JSON dictionary as specified by the user."
        template = """
<s>[INST]{}[/INST]</s>
[INST]{}[/INST]
""".format(system_prompt, "{query}")
        # Create a prompt from the template so we can use it with LangChain
        self.prompt = PromptTemplate(template=template, input_variables=["query"])

        # Set up a parser
        self.parser = JsonOutputParser()

        self._set_config()
    def _clear_VRAM(self):
        # Clear CUDA cache if it's being used
        if self.has_GPU:
            self.local_model = None
            self.local_model_pipeline = None
            del self.local_model
            del self.local_model_pipeline
            gc.collect()  # Explicitly invoke garbage collector
            torch.cuda.empty_cache()
        else:
            self.local_model_pipeline = None
            self.local_model = None
            del self.local_model_pipeline
            del self.local_model
            gc.collect()  # Explicitly invoke garbage collector
    def _set_config(self):
        self._clear_VRAM()
        self.config = {'max_new_tokens': 1024,
                       'temperature': self.starting_temp,
                       'seed': 2023,
                       'top_p': 1,
                       'top_k': 40,
                       'do_sample': True,
                       'n_ctx': 4096,

                       # Activate 4-bit precision base model loading
                       'use_4bit': True,
                       # Compute dtype for 4-bit base models
                       'bnb_4bit_compute_dtype': "float16",
                       # Quantization type (fp4 or nf4)
                       'bnb_4bit_quant_type': "nf4",
                       # Activate nested quantization for 4-bit base models (double quantization)
                       'use_nested_quant': False,
                       }

        compute_dtype = getattr(torch, self.config.get('bnb_4bit_compute_dtype'))

        self.bnb_config = BitsAndBytesConfig(
            load_in_4bit=self.config.get('use_4bit'),
            bnb_4bit_quant_type=self.config.get('bnb_4bit_quant_type'),
            bnb_4bit_compute_dtype=compute_dtype,
            bnb_4bit_use_double_quant=self.config.get('use_nested_quant'),
        )

        # Check GPU compatibility with bfloat16; fall back to float16 so b_float_opt is always defined
        if self.has_GPU and compute_dtype == torch.float16 and self.config.get('use_4bit'):
            major, _ = torch.cuda.get_device_capability()
            if major >= 8:
                # GPU supports bfloat16 (e.g. Ampere or newer)
                self.b_float_opt = torch.bfloat16
            else:
                self.b_float_opt = torch.float16
        else:
            self.b_float_opt = torch.float16

        self._build_model_chain_parser()
    def _adjust_config(self):
        self.logger.info('Incrementing temperature and reloading model')
        self._clear_VRAM()
        self.adjust_temp += self.temp_increment
        self.config['temperature'] = self.adjust_temp
        self._build_model_chain_parser()
    def _build_model_chain_parser(self):
        self.local_model_pipeline = transformers.pipeline("text-generation",
                                                          model=self.model_id,
                                                          max_new_tokens=self.config.get('max_new_tokens'),
                                                          temperature=self.config.get('temperature'),
                                                          top_k=self.config.get('top_k'),
                                                          top_p=self.config.get('top_p'),
                                                          do_sample=self.config.get('do_sample'),
                                                          # 4-bit loading is implied by the BitsAndBytesConfig;
                                                          # passing load_in_4bit alongside quantization_config
                                                          # can raise an error in recent transformers versions
                                                          model_kwargs={"torch_dtype": self.b_float_opt,
                                                                        "quantization_config": self.bnb_config})
        self.local_model = HuggingFacePipeline(pipeline=self.local_model_pipeline)

        # Set up the retry parser with the runnable
        self.retry_parser = RetryWithErrorOutputParser.from_llm(parser=self.parser, llm=self.local_model, max_retries=self.MAX_RETRIES)

        # Create an LLM chain with LLM and prompt
        self.chain = self.prompt | self.local_model
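    # The chain above is a LangChain Expression Language (LCEL) pipe: the PromptTemplate renders the
    # Mistral-instruct prompt and HuggingFacePipeline generates the raw completion. JSON parsing is not
    # part of the chain itself; it is applied afterwards via retry_parser, which re-queries the LLM with
    # the parse error when the completion is not valid JSON.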
    def call_llm_local_MistralAI(self, prompts, batch_size=2):
        # Wrap the async call with asyncio.run
        dataset = PromptDataset(prompts)
        data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
        all_results = asyncio.run(self._process_all_batches(data_loader))

        # If retries raised the temperature, reset it and reload the model at the starting temperature
        if self.adjust_temp != self.starting_temp:
            self.adjust_temp = self.starting_temp
            self._set_config()

        return all_results
    async def _process_batch(self, batch_prompts):
        # Create and manage async tasks for each prompt in the batch
        tasks = [self._process_single_prompt(prompt) for prompt in batch_prompts]
        return await asyncio.gather(*tasks)

    async def _process_all_batches(self, data_loader):
        # Process all batches asynchronously
        results = []
        for batch_prompts in data_loader:
            batch_results = await self._process_batch(batch_prompts)
            results.extend(batch_results)
        return results
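    # Note: prompts within a batch are scheduled together with asyncio.gather, while the batches
    # themselves are awaited one after another, so results preserve the original prompt order.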
    async def _process_single_prompt(self, prompt_template):
        self.monitor.start_monitoring_usage()

        nt_in = nt_out = 0
        ind = 0
        while ind < self.MAX_RETRIES:
            ind += 1
            results = self.chain.invoke({"query": prompt_template})

            # parse_with_prompt expects a PromptValue; it raises an OutputParserException if the
            # completion still cannot be parsed after the retry LLM calls, which we treat as a failure
            try:
                output = self.retry_parser.parse_with_prompt(
                    results, prompt_value=self.prompt.format_prompt(query=prompt_template))
            except Exception:
                output = None

            if output is None:
                self.logger.error(f'Failed to extract JSON from:\n{results}')
                self._adjust_config()
                del results
            else:
                nt_in = count_tokens(prompt_template, self.VENDOR, self.TOKENIZER_NAME)
                nt_out = count_tokens(results, self.VENDOR, self.TOKENIZER_NAME)

                output = validate_and_align_JSON_keys_with_template(output, self.JSON_dict_structure)
                output, WFO_record = validate_taxonomy_WFO(output, replace_if_success_wfo=False)
                output, GEO_record = validate_coordinates_here(output, replace_if_success_geo=False)

                self.logger.info(f"Formatted JSON:\n{json.dumps(output, indent=4)}")

                del results
                self.monitor.stop_monitoring_report_usage()
                return output, nt_in, nt_out, WFO_record, GEO_record

        self.monitor.stop_monitoring_report_usage()
        self.logger.info(f"Failed to extract valid JSON after [{ind}] attempts")
        return None, nt_in, nt_out, None, None
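# A minimal usage sketch, not part of the original module: the model name "Mistral-7B-Instruct-v0.2"
# and the JSON_dict_structure keys below are illustrative assumptions only, and running it requires a
# GPU with enough VRAM for the 4-bit quantized model.
if __name__ == "__main__":
    import logging

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger("LocalMistralHandler")

    # Hypothetical output template; the real keys are defined by the caller
    JSON_dict_structure = {"scientificName": "", "collector": "", "locality": ""}

    handler = LocalMistralHandler(logger, "Mistral-7B-Instruct-v0.2", JSON_dict_structure)

    prompts = ["Extract the fields from this herbarium label: ..."]
    for output, nt_in, nt_out, WFO_record, GEO_record in handler.call_llm_local_MistralAI(prompts, batch_size=1):
        print(json.dumps(output, indent=4), nt_in, nt_out)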