Chris4K committed
Commit d3b0430
1 Parent(s): aa72e55

Update app.py

Files changed (1):
  1. app.py (+58, -70)
app.py CHANGED
@@ -105,7 +105,7 @@ from huggingface_hub import InferenceClient
 
 # NLTK Resource Download
 def download_nltk_resources():
-    resources = ['punkt', 'stopwords', 'snowball_data']
+    resources = ['punkt', 'stopwords', 'snowball_data', 'wordnet']
     for resource in resources:
         try:
             nltk.download(resource, quiet=False)
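
The new 'wordnet' entry is needed because optimize_query (further down in this diff) expands query tokens via WordNet synsets. A minimal sketch of that lookup, assuming nltk is installed and the resources listed above have already been downloaded; the example query string is illustrative, not from the commit:

    from nltk.corpus import wordnet
    from nltk.tokenize import word_tokenize

    query = "fast vector search"                  # illustrative input
    tokens = word_tokenize(query.lower())         # requires the 'punkt' resource
    expanded_terms = set()
    for token in tokens:
        for syn in wordnet.synsets(token)[:2]:    # first two synsets per token, as in optimize_query
            expanded_terms.update(lemma.name() for lemma in syn.lemmas()[:2])
    print(expanded_terms)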
@@ -337,7 +337,7 @@ def optimize_query(
     vector_store_type: str, # Added to match your signature
     search_type: str, # Added to match your signature
     top_k: int = 3,
-    use_gpu: bool = True
+    use_gpu: bool = False
 ) -> str:
     """
     CPU-optimized version of query expansion using a small language model.
@@ -354,7 +354,7 @@ def optimize_query(
 
     Returns:
         Expanded query string
-    """
+    """
     try:
         # Set device
        device = "cuda" if use_gpu and torch.cuda.is_available() else "cpu"
@@ -372,74 +372,60 @@ def optimize_query(
                 expanded_terms.update([lemma.name() for lemma in syn.lemmas()[:2]])
 
         # 3. Use provided model with reduced complexity
-        try:
-            # Load model with reduced memory footprint
-            tokenizer = AutoTokenizer.from_pretrained(
-                query_optimization_model,  # Use the provided model name
-                model_max_length=128,
-                cache_dir="./model_cache"
-            )
-            model = AutoModelForSeq2Gen.from_pretrained(
-                query_optimization_model,  # Use the provided model name
-                low_cpu_mem_usage=True,
-                device_map="cpu"
-            )
-
-            # Move model to CPU and eval mode
-            model = model.to(device)
-            model.eval()
-
-            # Prepare input with reduced length
-            prompt = f"Enhance this search query with relevant terms: {query}"
-            inputs = tokenizer(
-                prompt,
-                return_tensors="pt",
-                max_length=64,
-                truncation=True,
-                padding=True
-            )
-
-            # Generate with minimal parameters
-            with torch.no_grad():
-                outputs = model.generate(
-                    inputs.input_ids.to(device),
-                    max_length=32,
+        try:
+            # Initialize the pipeline with the chosen model
+            llm_pipeline = pipeline(model="meta-llama/Llama-3.2-1B-Instruct", device='cpu')
+
+            # Define prompt for the assistant, making it context-specific
+            prompt = f'''
+            <|start_header_id|>system<|end_header_id|>
+            You are an expert in enhancing user input for vector store retrieval.
+            Enhance the following search query with relevant terms.
+
+            Show me just the new terms. You SHOULD NOT include any other text in the response.
+
+            <|eot_id|><|start_header_id|>user<|end_header_id|>
+            {query}
+            <|eot_id|><|start_header_id|>assistant<|end_header_id|>
+            '''
+
+            # Get expansion terms from the LLM
+            suggested_settings = llm_pipeline(
+                prompt,
+                do_sample=True,
+                top_k=10,
                 num_return_sequences=1,
-                    temperature=0.7,
-                    do_sample=False,
-                    early_stopping=True
+                return_full_text=False,
+                max_new_tokens=1900,  # Control the length of the output
+                truncation=True  # Enable truncation
             )
 
-            enhanced_query = tokenizer.decode(outputs[0], skip_special_tokens=True)
+            # Extract the expansion text from the generated response
+            generated_text = suggested_settings[0].get('generated_text', '')
+            print(generated_text)  # For debugging, ensure text output is as expected
 
-            # Clear CUDA cache if GPU was used
-            if device == "cuda":
-                torch.cuda.empty_cache()
+        except Exception as model_error:
+            print(f"LLM-based expansion failed: {str(model_error)}")
+            generated_text = query  # Fall back to the original query
 
-        except Exception as model_error:
-            print(f"Model-based expansion failed: {str(model_error)}")
-            enhanced_query = query
 
-        # 4. Combine original and expanded terms
-        final_terms = set(tokens)
-        final_terms.update(expanded_terms)
-        if enhanced_query != query:
-            final_terms.update(word_tokenize(enhanced_query.lower()))
+        # 4. Combine original and expanded terms
+        final_terms = set(tokens)
+        final_terms.update(expanded_terms)
+        if generated_text != query:
+            final_terms.update(word_tokenize(generated_text.lower()))
 
         # 5. Remove stopwords and select top_k most relevant terms
         stopwords = set(['the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to'])
         final_terms = [term for term in final_terms if term not in stopwords]
 
         # Combine with original query
-        expanded_query = f"{query} {' '.join(list(final_terms)[:top_k])}"
-
+        generated_text = f"{query} {' '.join(list(final_terms)[:top_k])}"
+        print(generated_text)
         # Clean up
-        del model
-        del tokenizer
-        if device == "cuda":
-            torch.cuda.empty_cache()
+        # llm_pipeline = None
 
-        return expanded_query.strip()  #[Document(page_content=expanded_query.strip())]
+        return generated_text.strip()  #[Document(page_content=generated_text.strip())]
 
     except Exception as e:
         print(f"Query optimization failed: {str(e)}")
@@ -1073,6 +1059,7 @@ def analyze_results(stats_df):
     return recommendations
 
 ####
+import ast
 
 def get_llm_suggested_settings(file, num_chunks=1):
     if not file:
@@ -1092,7 +1079,7 @@ def get_llm_suggested_settings(file, num_chunks=1):
     sample_chunks = random.sample(chunks, min(num_chunks, len(chunks)))
 
 
-    llm_pipeline = pipeline(model="meta-llama/Llama-3.2-1B-Instruct", device='cuda')
+    llm_pipeline = pipeline(model="meta-llama/Llama-3.2-1B-Instruct", device='cpu')
 
 
     prompt=f'''
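
Hard-coding device='cpu' avoids the crash on CPU-only hardware, but it also ignores a GPU when one is present. A hedged alternative, not in this commit, is to pick the device at runtime:

    import torch
    from transformers import pipeline

    # -1 selects CPU; 0 selects the first CUDA device when one is available.
    device = 0 if torch.cuda.is_available() else -1
    llm_pipeline = pipeline(
        "text-generation",
        model="meta-llama/Llama-3.2-1B-Instruct",
        device=device,
    )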
@@ -1155,17 +1142,16 @@ def get_llm_suggested_settings(file, num_chunks=1):
         max_new_tokens=1900, # Control the length of the output,
         truncation=True, # Enable truncation
     )
-
-
-    #suggested_settings = llm.invoke(prompt)
-    print("setting suggested")
-    print(suggested_settings)
-    # Parse the generated text to extract the dictionary
+
+    print(suggested_settings[0]['generated_text'])
+    # Safely parse the generated text to extract the dictionary
     try:
-        settings_dict = eval(suggested_settings)
+        # Using ast.literal_eval for safe parsing
+        settings_dict = ast.literal_eval(suggested_settings[0]['generated_text'])
+
         # Convert the settings to match the interface inputs
         return {
-            "embedding_models": f"{settings_dict['embedding_model_type']}:{settings_dict['embedding_model_name']}",
+            "embedding_models": settings_dict["embedding_models"],
             "split_strategy": settings_dict["split_strategy"],
             "chunk_size": settings_dict["chunk_size"],
             "overlap_size": settings_dict["overlap_size"],
@@ -1173,13 +1159,15 @@ def get_llm_suggested_settings(file, num_chunks=1):
             "search_type": settings_dict["search_type"],
             "top_k": settings_dict["top_k"],
             "apply_preprocessing": settings_dict["apply_preprocessing"],
-            "optimize_vocab": settings_dict["optimize_vocabulary"],
-            "apply_phonetic": settings_dict["apply_phonetic_matching"],
-            "phonetic_weight": 0.3 # Default value, as it's not in the LLM suggestions
+            "optimize_vocab": settings_dict["optimize_vocab"],
+            "apply_phonetic": settings_dict["apply_phonetic"],
+            "phonetic_weight": settings_dict.get("phonetic_weight", 0.3) # Set default if not provided
         }
-    except:
+    except Exception as e:
+        print(f"Error parsing LLM suggestions: {e}")
         return {"error": "Failed to parse LLM suggestions"}
 
+
 def update_inputs_with_llm_suggestions(suggestions):
     if suggestions is None or "error" in suggestions:
         return [gr.update() for _ in range(11)] # Return no updates if there's an error or None
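
The switch from eval to ast.literal_eval matters here because the parsed string comes straight from model output: literal_eval only accepts Python literals (dicts, lists, strings, numbers, booleans, None) and raises instead of executing anything else. A small self-contained sketch with a made-up response string:

    import ast

    generated = "{'split_strategy': 'recursive', 'chunk_size': 500, 'top_k': 5}"  # made-up LLM output
    settings = ast.literal_eval(generated)   # parses the dict literal without executing code
    print(settings["chunk_size"])            # 500

    try:
        ast.literal_eval("__import__('os').getcwd()")  # rejected: not a literal
    except (ValueError, SyntaxError) as err:
        print(f"rejected unsafe content: {err}")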
 