bhlewis committed on
Commit
86b10f3
1 Parent(s): f05d8cd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +80 -92
app.py CHANGED
@@ -3,80 +3,33 @@ import numpy as np
3
  import h5py
4
  import faiss
5
  import json
 
 
 
6
  import re
7
  from collections import Counter
 
8
  import torch
9
- from nltk.corpus import stopwords
10
- from nltk.tokenize import word_tokenize
11
  import nltk
12
- from sentence_transformers import SentenceTransformer
13
- from sklearn.feature_extraction.text import TfidfVectorizer
14
- from sklearn.metrics.pairwise import cosine_similarity
15
-
16
- # Download necessary NLTK data
17
- nltk.download('stopwords', quiet=True)
18
- nltk.download('punkt', quiet=True)
19
-
20
- # Load SentenceTransformer model
21
- model = SentenceTransformer('anferico/bert-for-patents')
22
-
23
- def preprocess_query(text):
24
- # Remove "[EN]" label and claim numbers
25
- text = re.sub(r'\[EN\]\s*', '', text)
26
- text = re.sub(r'^\d+\.\s*', '', text, flags=re.MULTILINE)
27
-
28
- # Convert to lowercase while preserving acronyms and units
29
- words = text.split()
30
- text = ' '.join(word if word.isupper() or re.match(r'^\d+(\.\d+)?[a-zA-Z]+$', word) else word.lower() for word in words)
31
-
32
- # Remove special characters except hyphens and periods in numbers
33
- text = re.sub(r'[^\w\s\-.]', ' ', text)
34
- text = re.sub(r'(?<!\d)\.(?!\d)', ' ', text) # Remove periods not in numbers
35
-
36
- # Normalize spaces
37
- text = re.sub(r'\s+', ' ', text).strip()
38
-
39
- # Tokenize
40
- tokens = word_tokenize(text)
41
-
42
- # Remove stopwords
43
- stop_words = set(stopwords.words('english'))
44
- tokens = [word for word in tokens if word.lower() not in stop_words]
45
-
46
- # Join tokens back into text
47
- text = ' '.join(tokens)
48
-
49
- # Preserve numerical values with units
50
- text = re.sub(r'(\d+(\.\d+)?)([a-zA-Z]+)', r'\1_\3', text)
51
-
52
- # Handle ranges and measurements
53
- text = re.sub(r'(\d+(\.\d+)?)(\s*to\s*)(\d+(\.\d+)?)(\s*[a-zA-Z]+)', r'\1_to_\4_\6', text)
54
- text = re.sub(r'between\s*(\d+(\.\d+)?)(\s*and\s*)(\d+(\.\d+)?)\s*([a-zA-Z]+)', r'between_\1_and_\4_\5', text)
55
-
56
- # Preserve chemical formulas
57
- text = re.sub(r'\b([A-Z][a-z]?\d*)+\b', lambda m: m.group().replace(' ', ''), text)
58
-
59
- return text
60
 
61
- def extract_key_features(text):
62
- # For queries, we'll just preprocess and return all non-stopword terms
63
- processed_text = preprocess_query(text)
64
- # Split the processed text into individual terms
65
- features = processed_text.split()
66
- # Remove duplicates while preserving order
67
- features = list(dict.fromkeys(features))
68
- return features
69
 
70
- def encode_texts(texts):
71
- embeddings = model.encode(texts, show_progress_bar=True)
72
- return embeddings
 
 
 
 
73
 
74
  def load_data():
75
  try:
76
  with h5py.File('patent_embeddings.h5', 'r') as f:
77
  embeddings = f['embeddings'][:]
78
  patent_numbers = f['patent_numbers'][:]
79
-
80
  metadata = {}
81
  texts = []
82
  with open('patent_metadata.jsonl', 'r') as f:
@@ -84,15 +37,63 @@ def load_data():
84
  data = json.loads(line)
85
  metadata[data['patent_number']] = data
86
  texts.append(data['text'])
87
-
88
  print(f"Embedding shape: {embeddings.shape}")
89
  print(f"Number of patent numbers: {len(patent_numbers)}")
90
  print(f"Number of metadata entries: {len(metadata)}")
 
91
  return embeddings, patent_numbers, metadata, texts
 
 
 
92
  except Exception as e:
93
- print(f"An error occurred while loading data: {e}")
94
  raise
95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  def compare_features(query_features, patent_features):
97
  common_features = set(query_features) & set(patent_features)
98
  similarity_score = len(common_features) / max(len(query_features), len(patent_features))
@@ -100,21 +101,21 @@ def compare_features(query_features, patent_features):
100
 
101
  def hybrid_search(query, top_k=5):
102
  print(f"Original query: {query}")
103
- processed_query = preprocess_query(query)
104
- query_features = extract_key_features(processed_query)
105
-
106
- # Encode the processed query using the SentenceTransformer model
107
- query_embedding = encode_texts([processed_query])[0]
108
  query_embedding = query_embedding / np.linalg.norm(query_embedding)
109
-
110
  # Perform semantic similarity search
111
  semantic_distances, semantic_indices = index.search(np.array([query_embedding]).astype('float32'), top_k * 2)
112
-
113
  # Perform TF-IDF based search
114
- query_tfidf = tfidf_vectorizer.transform([processed_query])
115
  tfidf_similarities = cosine_similarity(query_tfidf, tfidf_matrix).flatten()
116
  tfidf_indices = tfidf_similarities.argsort()[-top_k * 2:][::-1]
117
-
118
  # Combine and rank results
119
  combined_results = {}
120
  for i, idx in enumerate(semantic_indices[0]):
@@ -127,7 +128,7 @@ def hybrid_search(query, top_k=5):
127
  'common_features': common_features,
128
  'text': text
129
  }
130
-
131
  for idx in tfidf_indices:
132
  patent_number = patent_numbers[idx].decode('utf-8')
133
  if patent_number not in combined_results:
@@ -139,9 +140,10 @@ def hybrid_search(query, top_k=5):
139
  'common_features': common_features,
140
  'text': text
141
  }
142
-
143
  # Sort and get top results
144
  top_results = sorted(combined_results.items(), key=lambda x: x[1]['score'], reverse=True)[:top_k]
 
145
  results = []
146
  for patent_number, data in top_results:
147
  result = f"Patent Number: {patent_number}\n"
@@ -149,24 +151,10 @@ def hybrid_search(query, top_k=5):
149
  result += f"Combined Score: {data['score']:.4f}\n"
150
  result += f"Common Key Features: {', '.join(data['common_features'])}\n\n"
151
  results.append(result)
152
-
153
  return "\n".join(results)
154
 
155
- # Load data and prepare the FAISS index
156
- embeddings, patent_numbers, metadata, texts = load_data()
157
-
158
- # Normalize embeddings for cosine similarity
159
- embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
160
-
161
- # Create FAISS index for cosine similarity
162
- index = faiss.IndexFlatIP(embeddings.shape[1])
163
- index.add(embeddings)
164
-
165
- # Create TF-IDF vectorizer
166
- tfidf_vectorizer = TfidfVectorizer(stop_words='english')
167
- tfidf_matrix = tfidf_vectorizer.fit_transform(texts)
168
-
169
- # Create Gradio interface
170
  iface = gr.Interface(
171
  fn=hybrid_search,
172
  inputs=[
@@ -179,4 +167,4 @@ iface = gr.Interface(
179
  )
180
 
181
  if __name__ == "__main__":
182
- iface.launch()
 
3
  import h5py
4
  import faiss
5
  import json
6
+ from transformers import AutoTokenizer, AutoModel
7
+ from sklearn.feature_extraction.text import TfidfVectorizer
8
+ from sklearn.metrics.pairwise import cosine_similarity
9
  import re
10
  from collections import Counter
11
+ import spacy
12
  import torch
13
+ from nltk.corpus import wordnet
 
14
  import nltk
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
+ # Download WordNet data
17
+ nltk.download('wordnet')
 
 
 
 
 
 
18
 
19
+ # Load Spacy model for advanced NLP
20
+ try:
21
+ nlp = spacy.load("en_core_web_sm")
22
+ except IOError:
23
+ print("Downloading spacy model...")
24
+ spacy.cli.download("en_core_web_sm")
25
+ nlp = spacy.load("en_core_web_sm")
26
 
27
  def load_data():
28
  try:
29
  with h5py.File('patent_embeddings.h5', 'r') as f:
30
  embeddings = f['embeddings'][:]
31
  patent_numbers = f['patent_numbers'][:]
32
+
33
  metadata = {}
34
  texts = []
35
  with open('patent_metadata.jsonl', 'r') as f:
 
37
  data = json.loads(line)
38
  metadata[data['patent_number']] = data
39
  texts.append(data['text'])
40
+
41
  print(f"Embedding shape: {embeddings.shape}")
42
  print(f"Number of patent numbers: {len(patent_numbers)}")
43
  print(f"Number of metadata entries: {len(metadata)}")
44
+
45
  return embeddings, patent_numbers, metadata, texts
46
+ except FileNotFoundError as e:
47
+ print(f"Error: Could not find file. {e}")
48
+ raise
49
  except Exception as e:
50
+ print(f"An unexpected error occurred while loading data: {e}")
51
  raise
52
 
53
+ embeddings, patent_numbers, metadata, texts = load_data()
54
+
55
+ # Load BERT model for encoding search queries
56
+ tokenizer = AutoTokenizer.from_pretrained('anferico/bert-for-patents')
57
+ bert_model = AutoModel.from_pretrained('anferico/bert-for-patents')
58
+
59
def encode_texts(texts, max_length=512, batch_size=32):
    """Embed texts with the patent BERT model using mask-aware mean pooling.

    Padding positions are excluded from the average, so the embedding of a
    text does not depend on which other texts share its batch.

    Args:
        texts: A single string or a sequence of strings to embed.
        max_length: Maximum number of tokens per text; longer inputs are
            truncated by the tokenizer.
        batch_size: Number of texts sent through the model per forward pass.

    Returns:
        A NumPy array with one row per input text.
    """
    if isinstance(texts, str):
        texts = [texts]
    texts = list(texts)
    if not texts:
        # Nothing to encode: return an empty matrix with the model's width.
        return np.empty((0, bert_model.config.hidden_size), dtype=np.float32)

    pooled_batches = []
    for start in range(0, len(texts), batch_size):
        batch = texts[start:start + batch_size]
        inputs = tokenizer(batch, padding=True, truncation=True, max_length=max_length, return_tensors='pt')
        with torch.no_grad():
            outputs = bert_model(**inputs)
        hidden = outputs.last_hidden_state
        # Zero out padded positions and divide by the real token count,
        # instead of averaging over the padded sequence length.
        mask = inputs['attention_mask'].unsqueeze(-1).to(hidden.dtype)
        summed = (hidden * mask).sum(dim=1)
        token_counts = mask.sum(dim=1).clamp(min=1.0)
        pooled_batches.append((summed / token_counts).numpy())
    return np.concatenate(pooled_batches, axis=0)
65
+
66
# Check that the stored embeddings have the same width as query embeddings
query_dim = encode_texts(["test"]).shape[1]
if embeddings.shape[1] != query_dim:
    print("Embedding dimensions do not match. Rebuilding FAISS index.")
    # Rebuild embeddings using the new model
    # NOTE(review): rows now follow the order of `texts`; this assumes the
    # metadata file lists patents in the same order as `patent_numbers`.
    embeddings = encode_texts(texts)

# Normalize embeddings for cosine similarity (done once, for both the loaded
# and the rebuilt case); zero-norm rows are left as zeros instead of NaN.
embeddings = np.asarray(embeddings, dtype='float32')
norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
embeddings = embeddings / np.maximum(norms, 1e-12)

# Create FAISS index for cosine similarity (inner product on unit vectors)
index = faiss.IndexFlatIP(embeddings.shape[1])
index.add(embeddings)

# Create TF-IDF vectorizer
tfidf_vectorizer = TfidfVectorizer(stop_words='english')
tfidf_matrix = tfidf_vectorizer.fit_transform(texts)
83
+
84
def extract_key_features(text):
    """Collect lower-cased key terms and phrases from a text using spaCy.

    Three kinds of features are gathered from the parsed document:
    tokens that are adjectival modifiers, compounds, or part of a
    PRODUCT/ORG/GPE/NORP entity; noun chunks; and whole sentences that
    contain one of a fixed set of claim keywords.

    Args:
        text: The text to analyse.

    Returns:
        A list of unique feature strings in first-seen order.
    """
    doc = nlp(text)

    technical_terms = [
        token.text.lower()
        for token in doc
        if token.dep_ in ('amod', 'compound') or token.ent_type_ in ('PRODUCT', 'ORG', 'GPE', 'NORP')
    ]
    noun_phrases = [chunk.text.lower() for chunk in doc.noun_chunks]

    # Compile once rather than once per sentence.
    keyword_pattern = re.compile(
        r'(comprising|including|consisting of|deformable|insulation|heat-resistant|memory foam|high-temperature)',
        re.IGNORECASE,
    )
    feature_phrases = [sent.text.lower() for sent in doc.sents if keyword_pattern.search(sent.text)]

    all_features = technical_terms + noun_phrases + feature_phrases
    # dict.fromkeys de-duplicates while keeping a stable order; a plain set
    # would return the features in a different order from run to run.
    return list(dict.fromkeys(all_features))
96
+
97
  def compare_features(query_features, patent_features):
98
  common_features = set(query_features) & set(patent_features)
99
  similarity_score = len(common_features) / max(len(query_features), len(patent_features))
 
101
 
102
  def hybrid_search(query, top_k=5):
103
  print(f"Original query: {query}")
104
+
105
+ query_features = extract_key_features(query)
106
+
107
+ # Encode the query using the transformer model
108
+ query_embedding = encode_texts([query])[0]
109
  query_embedding = query_embedding / np.linalg.norm(query_embedding)
110
+
111
  # Perform semantic similarity search
112
  semantic_distances, semantic_indices = index.search(np.array([query_embedding]).astype('float32'), top_k * 2)
113
+
114
  # Perform TF-IDF based search
115
+ query_tfidf = tfidf_vectorizer.transform([query])
116
  tfidf_similarities = cosine_similarity(query_tfidf, tfidf_matrix).flatten()
117
  tfidf_indices = tfidf_similarities.argsort()[-top_k * 2:][::-1]
118
+
119
  # Combine and rank results
120
  combined_results = {}
121
  for i, idx in enumerate(semantic_indices[0]):
 
128
  'common_features': common_features,
129
  'text': text
130
  }
131
+
132
  for idx in tfidf_indices:
133
  patent_number = patent_numbers[idx].decode('utf-8')
134
  if patent_number not in combined_results:
 
140
  'common_features': common_features,
141
  'text': text
142
  }
143
+
144
  # Sort and get top results
145
  top_results = sorted(combined_results.items(), key=lambda x: x[1]['score'], reverse=True)[:top_k]
146
+
147
  results = []
148
  for patent_number, data in top_results:
149
  result = f"Patent Number: {patent_number}\n"
 
151
  result += f"Combined Score: {data['score']:.4f}\n"
152
  result += f"Common Key Features: {', '.join(data['common_features'])}\n\n"
153
  results.append(result)
154
+
155
  return "\n".join(results)
156
 
157
+ # Create Gradio interface with additional input fields
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
  iface = gr.Interface(
159
  fn=hybrid_search,
160
  inputs=[
 
167
  )
168
 
169
  if __name__ == "__main__":
170
+ iface.launch()