navid72m committed on
Commit c93ff5d
1 Parent(s): 9ae92ce

Update app.py

Files changed (1): app.py (+35 -12)
app.py CHANGED
@@ -8,22 +8,43 @@ import io
 import requests
 import os
 
-from sklearn.feature_extraction.text import TfidfVectorizer
-from sklearn.metrics.pairwise import cosine_similarity
-
-
+import faiss
+import numpy as np
+from transformers import AutoTokenizer, AutoModel
+import torch
 my_token = os.getenv('my_repo_token')
-def find_most_relevant_context(contexts, question, max_features=10000):
-    # Vectorize contexts and question with limited features
-    tfidf_vectorizer = TfidfVectorizer(max_features=max_features)
-    tfidf_matrix = tfidf_vectorizer.fit_transform([question] + contexts)
+# Function to get embeddings using a pre-trained model
+def get_embeddings(texts, model_name='sentence-transformers/all-MiniLM-L6-v2'):
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    model = AutoModel.from_pretrained(model_name)
+
+    inputs = tokenizer(texts, return_tensors='pt', padding=True, truncation=True)
+    with torch.no_grad():
+        outputs = model(**inputs)
+        embeddings = outputs.last_hidden_state.mean(dim=1).cpu().numpy()
+
+    return embeddings
+
+# Function to find the most relevant context using FAISS
+def find_most_relevant_context_faiss(contexts, question, model_name='sentence-transformers/all-MiniLM-L6-v2'):
+    # Get embeddings for contexts and question
+    all_texts = [question] + contexts
+    embeddings = get_embeddings(all_texts, model_name=model_name)
+
+    # Separate the question embedding and context embeddings
+    question_embedding = embeddings[0]
+    context_embeddings = embeddings[1:]
 
-    # Compute cosine similarity between question and contexts
-    similarity_scores = cosine_similarity(tfidf_matrix[0:1], tfidf_matrix[1:]).flatten()
+    # Create a FAISS index and add context embeddings
+    dimension = context_embeddings.shape[1]
+    index = faiss.IndexFlatL2(dimension)
+    index.add(context_embeddings)
 
-    # Get index of context with highest similarity
-    most_relevant_index = similarity_scores.argmax()
+    # Search for the nearest neighbor to the question embedding
+    _, indices = index.search(question_embedding.reshape(1, -1), 1)
 
+    # Get the most relevant context
+    most_relevant_index = indices[0][0]
     return contexts[most_relevant_index]
 
 
@@ -33,6 +54,8 @@ def find_most_relevant_context(contexts, question, max_features=10000):
 
 
 
+
+
 API_URL = "https://api-inference.huggingface.co/models/google/gemma-7b"
 API_URL_2 = "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-v0.1"
 API_URL_LLMA = "https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3-8B-Instruct"
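For readers following along, a minimal usage sketch of the new retrieval path (the contexts and question below are invented for illustration; nothing here comes from the repo):

# Hypothetical smoke test for find_most_relevant_context_faiss (not part of the commit)
contexts = [
    "FAISS is a library for efficient similarity search over dense vectors.",
    "TF-IDF weights terms by frequency against document rarity.",
    "The Hugging Face Inference API serves hosted models over HTTP.",
]
question = "How can I search dense vectors efficiently?"

best = find_most_relevant_context_faiss(contexts, question)
print(best)  # expected: the FAISS sentence, whose embedding should sit nearest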
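One design note, going beyond what the commit itself states: IndexFlatL2 ranks contexts by Euclidean distance, whereas the TF-IDF path it replaces ranked by cosine similarity, and on unnormalized mean-pooled embeddings the two orderings can differ. A minimal sketch of a cosine-equivalent index, assuming one wants to preserve the old ranking criterion (build_cosine_index is a hypothetical helper, not in the commit):

# Hypothetical cosine-similarity variant (not part of the commit):
# L2-normalize the vectors so inner product equals cosine similarity.
import faiss
import numpy as np

def build_cosine_index(context_embeddings):
    vecs = np.ascontiguousarray(context_embeddings, dtype='float32')
    faiss.normalize_L2(vecs)                   # in-place L2 normalization
    index = faiss.IndexFlatIP(vecs.shape[1])   # inner product == cosine on unit vectors
    index.add(vecs)
    return index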