Chris4K committed on
Commit 075fdaa
1 Parent(s): 83c4a82

Update app.py

Files changed (1)
  1. app.py +30 -23
app.py CHANGED
@@ -96,9 +96,9 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
 #####
 from huggingface_hub import InferenceClient
 
-repo_id = "meta-llama/Llama-3.2-1B-Instruct"
+#repo_id = "meta-llama/Llama-3.2-1B-Instruct"
 
-llm = InferenceClient(model=repo_id, timeout=120)
+#llm = InferenceClient(model=repo_id, timeout=120)
 
 # Test your LLM client
 #llm_client.text_generation(prompt="How are you today?", max_new_tokens=20)
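
Note: this hunk retires the hosted InferenceClient in favor of the locally loaded model used in optimize_query below. For reference, a minimal sketch of the remote call being commented out (repo_id and timeout come from the removed lines; the prompt text is illustrative):

    from huggingface_hub import InferenceClient

    repo_id = "meta-llama/Llama-3.2-1B-Instruct"
    llm = InferenceClient(model=repo_id, timeout=120)
    # Plain text completion against the hosted endpoint
    print(llm.text_generation(prompt="How are you today?", max_new_tokens=20))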
@@ -108,7 +108,7 @@ def download_nltk_resources():
     resources = ['punkt', 'stopwords', 'snowball_data']
     for resource in resources:
         try:
-            nltk.download(resource, quiet=True)
+            nltk.download(resource, quiet=False)
         except Exception as e:
             print(f"Failed to download {resource}: {str(e)}")
 
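
Note: switching quiet=True to quiet=False only changes logging; nltk.download prints per-resource progress, so a failed fetch is visible in the Space build logs instead of being swallowed. The helper after this hunk reads (a sketch; only the loop body appears in the diff):

    import nltk

    def download_nltk_resources():
        resources = ['punkt', 'stopwords', 'snowball_data']
        for resource in resources:
            try:
                nltk.download(resource, quiet=False)  # verbose so failures show in logs
            except Exception as e:
                print(f"Failed to download {resource}: {str(e)}")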
 
@@ -331,21 +331,25 @@ import nltk
 
 def optimize_query(
     query: str,
+    query_optimization_model: str,  # Added to match your signature = "google/flan-t5-small"
     chunks: List[str],
     embedding_model: str,
+    vector_store_type: str,  # Added to match your signature
+    search_type: str,  # Added to match your signature
     top_k: int = 3,
-    model_name: str = "google/flan-t5-small",  # Small model (only 80M parameters)
-    use_gpu: bool = False  # Default to CPU
+    use_gpu: bool = False
 ) -> str:
     """
     CPU-optimized version of query expansion using a small language model.
 
     Args:
         query: Original search query
+        query_optimization_model: Name or path of the model to use for optimization
         chunks: List of text chunks to search through
         embedding_model: Name of the embedding model being used
+        vector_store_type: Type of vector store being used
+        search_type: Type of search being performed
         top_k: Number of expansion terms to add
-        model_name: Name of the small language model to use
         use_gpu: Whether to use GPU if available (defaults to False for CPU)
 
     Returns:
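
Note: every call site must be updated for the widened signature, since query_optimization_model, vector_store_type, and search_type are required parameters with no defaults. A hypothetical call under the new signature (the query, embedding model, store, and search-type values are illustrative, not from the commit):

    expanded = optimize_query(
        query="sample search query",
        query_optimization_model="google/flan-t5-base",  # default used elsewhere in this commit
        chunks=["sample text chunk 1", "sample text chunk 2"],
        embedding_model="all-MiniLM-L6-v2",  # assumed embedding model name
        vector_store_type="faiss",           # assumed store type
        search_type="similarity",            # assumed search type
        top_k=3,
        use_gpu=False,
    )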
@@ -367,42 +371,42 @@ def optimize_query(
             # Limit number of lemmas
             expanded_terms.update([lemma.name() for lemma in syn.lemmas()[:2]])
 
-    # 3. Use small T5 model with reduced complexity
+    # 3. Use provided model with reduced complexity
     try:
         # Load model with reduced memory footprint
         tokenizer = AutoTokenizer.from_pretrained(
-            model_name,
-            model_max_length=128,  # Limit maximum sequence length
-            cache_dir="./model_cache"  # Cache models locally
+            query_optimization_model,  # Use the provided model name
+            model_max_length=128,
+            cache_dir="./model_cache"
         )
-        model = AutoModel.from_pretrained(
-            model_name,
-            low_cpu_mem_usage=True,  # Enable low memory usage
-            device_map="cpu"  # Explicitly set to CPU
+        model = AutoModelForSeq2Gen.from_pretrained(
+            query_optimization_model,  # Use the provided model name
+            low_cpu_mem_usage=True,
+            device_map="cpu"
         )
 
         # Move model to CPU and eval mode
         model = model.to(device)
-        model.eval()  # Set to evaluation mode to reduce memory usage
+        model.eval()
 
         # Prepare input with reduced length
         prompt = f"Enhance this search query with relevant terms: {query}"
         inputs = tokenizer(
             prompt,
             return_tensors="pt",
-            max_length=64,  # Reduced from 128
+            max_length=64,
             truncation=True,
             padding=True
         )
 
         # Generate with minimal parameters
-        with torch.no_grad():  # Disable gradient calculation
+        with torch.no_grad():
             outputs = model.generate(
                 inputs.input_ids.to(device),
-                max_length=32,  # Reduced from 64
+                max_length=32,
                 num_return_sequences=1,
                 temperature=0.7,
-                do_sample=False,  # Disable sampling for faster generation
+                do_sample=False,
                 early_stopping=True
             )
 
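
Note: AutoModelForSeq2Gen is not a class in transformers, so this hunk fails at lookup time as committed; for a flan-t5-style checkpoint the generate-capable auto class is AutoModelForSeq2SeqLM (the AutoModel it replaces has no .generate() at all, so the intent of the change is right). Also, temperature=0.7 is ignored once do_sample=False. A corrected load, as a sketch:

    from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

    tokenizer = AutoTokenizer.from_pretrained(
        query_optimization_model,  # e.g. "google/flan-t5-base"
        model_max_length=128,
        cache_dir="./model_cache",
    )
    model = AutoModelForSeq2SeqLM.from_pretrained(  # correct auto class for T5-style seq2seq
        query_optimization_model,
        low_cpu_mem_usage=True,
        device_map="cpu",
    )
    model.eval()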
 
@@ -414,12 +418,12 @@ def optimize_query(
 
     except Exception as model_error:
         print(f"Model-based expansion failed: {str(model_error)}")
-        enhanced_query = query  # Fallback to original query
+        enhanced_query = query
 
     # 4. Combine original and expanded terms
     final_terms = set(tokens)
     final_terms.update(expanded_terms)
-    if enhanced_query != query:  # Only add if model expansion worked
+    if enhanced_query != query:
        final_terms.update(word_tokenize(enhanced_query.lower()))
 
     # 5. Remove stopwords and select top_k most relevant terms
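
Note: only the inline comments are trimmed here; the logic stands. enhanced_query falls back to the original query when the model path fails, and the follow-up guard keeps the failed (identical) expansion from being re-tokenized. The pattern in isolation (run_model is a hypothetical stand-in for the generate-and-decode steps above):

    enhanced_query = query                    # assume failure until the model succeeds
    try:
        enhanced_query = run_model(query)     # hypothetical stand-in, not in the commit
    except Exception:
        pass
    if enhanced_query != query:               # merge terms only when expansion produced something new
        final_terms.update(word_tokenize(enhanced_query.lower()))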
@@ -434,13 +438,15 @@ def optimize_query(
         del tokenizer
         if device == "cuda":
             torch.cuda.empty_cache()
-        print(expanded_query.strip())
+
         return expanded_query.strip()
 
     except Exception as e:
         print(f"Query optimization failed: {str(e)}")
         return query  # Return original query if optimization fails
 
+
+
 # Example usage
 """
 chunks = ["sample text chunk 1", "sample text chunk 2"]
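
Note: dropping the debug print leaves the cleanup block as the last step before returning; the expanded query is now only returned, never echoed. The surrounding pattern, for reference (a sketch; the del model line and the gc call are assumptions, not shown in this hunk):

    import gc

    del model
    del tokenizer
    gc.collect()                      # assumption: nudge CPython to release the freed weights
    if device == "cuda":
        torch.cuda.empty_cache()      # only meaningful when a GPU was actually used
    return expanded_query.strip()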
@@ -843,6 +849,7 @@ def compare_embeddings(file, query, embedding_models, custom_embedding_model, sp
         "apply_phonetic": apply_phonetic,
         "phonetic_weight": phonetic_weight,
         "use_query_optimization": use_query_optimization,
+        "query_optimization_model": query_optimization_model
         "use_reranking": use_reranking
     }
 
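
Note: the added entry is missing its trailing comma, so this hunk is a SyntaxError as committed ("use_reranking" immediately follows the new value inside the dict literal). It should read:

    "query_optimization_model": query_optimization_model,
    "use_reranking": use_reranking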
 
@@ -1337,7 +1344,7 @@ def launch_interface(share=True):
         'apply_phonetic': [False],  # Default phonetic settings
         'phonetic_weight': [0.5],
         'custom_separators': [None],
-        'query_optimization_model': ['gpt-3.5-turbo']  # Default query optimization model
+        'query_optimization_model': ['google/flan-t5-base']  # Default query optimization model
     }
 
     # Run automated tests
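
Note: 'gpt-3.5-turbo' is an OpenAI API model name and cannot be loaded through from_pretrained, so the old default could never have worked with the local pipeline; 'google/flan-t5-base' is a Hugging Face checkpoint compatible with the seq2seq loading path in optimize_query above.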
 