"""
Entity extraction script using a word-embedding model with correctly shaped
embeddings.

This script generates word embeddings in the exact shape required by the
TFLite model (64x32 uint8). Out-of-vocabulary words receive deterministic
hash-seeded embeddings, with the hash reduced to the range that
np.random.seed accepts (fixing the earlier random seed error).
"""

import hashlib
import os
import traceback

import numpy as np
import tensorflow as tf
import nltk
from nltk.tokenize import word_tokenize


# File paths. Only MODEL_PATH is read by this script; the remaining paths
# are defined for the companion data files but are not used here.
MODEL_PATH = "model.tflite"
WORD_EMBEDDINGS_PATH = "word_embeddings"
ENTITIES_METADATA_PATH = "global-entities_metadata"
ENTITIES_NAMES_PATH = "global-entities_names"

SAMPLE_TEXT = (
    "Zendesk is a customer service platform used by companies like Shopify, "
    "Airbnb, and Slack to manage support tickets, automate workflows, and "
    "provide omnichannel communication through email, chat, phone, and "
    "social media."
)

# Fixed tensor geometry expected by the TFLite model.
MAX_WORDS = 64        # rows in the word-embedding input matrix
MAX_CANDIDATES = 32   # maximum entity candidate spans per call
EMBEDDING_DIM = 32    # columns (uint8 values) per word embedding
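
# For example, a 10-word sentence still produces a full (64, 32) embeddings
# matrix: rows 0-9 hold word vectors, rows 10-63 stay zero, and the separate
# num_words input tells the model how many rows are meaningful.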


class EntityExtractor:
    def __init__(self, verbose=True):
        """Initialize the entity extractor: load the TFLite model and embeddings."""
        self.model_path = MODEL_PATH
        self.verbose = verbose

        self.interpreter = self.load_model()
        self.embedding_model = self.load_embedding_model()

        self.input_details = self.interpreter.get_input_details()
        self.output_details = self.interpreter.get_output_details()

        if self.verbose:
            print(f"TFLite model loaded with {len(self.input_details)} inputs "
                  f"and {len(self.output_details)} outputs")
            print("Pre-trained embedding model loaded")
            print("Input details:")
            for detail in self.input_details:
                print(f"  - {detail['name']} (index: {detail['index']}, "
                      f"shape: {detail['shape']}, dtype: {detail['dtype']})")

    def load_model(self):
        """Load the TFLite model."""
        if not os.path.exists(self.model_path):
            raise FileNotFoundError(f"Model file not found: {self.model_path}")

        interpreter = tf.lite.Interpreter(model_path=self.model_path)
        interpreter.allocate_tensors()
        return interpreter
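
    # Note: to inspect this model's tensor names and shapes outside the
    # script, recent TensorFlow releases (2.9+) offer
    # tf.lite.experimental.Analyzer.analyze(model_path=...), which prints a
    # structural summary of a .tflite file.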

    def load_embedding_model(self):
        """
        Load the word-embedding lookup table.

        For this implementation, a small seeded-random dictionary stands in
        for a real pre-trained model; unknown words are handled later by
        get_word_embedding's hashing fallback.
        """
        try:
            # Make sure the NLTK tokenizer data is available.
            try:
                nltk.data.find('tokenizers/punkt')
            except LookupError:
                nltk.download('punkt')

            embedding_dict = {}

            common_words = ["google", "is", "a", "search", "engine", "company",
                            "based", "in", "the", "usa", "and", "of", "to",
                            "for", "with", "on", "by", "at", "from", "as"]

            # Deterministic embeddings: fix the seed, draw a random vector,
            # L2-normalize it, then quantize to uint8 as the model expects.
            np.random.seed(42)
            for word in common_words:
                embedding = np.random.rand(EMBEDDING_DIM)
                embedding = embedding / np.linalg.norm(embedding)
                embedding = (embedding * 255).astype(np.uint8)
                embedding_dict[word] = embedding

            if self.verbose:
                print(f"Created embedding dictionary with {len(embedding_dict)} words")

            return embedding_dict

        except Exception as e:
            if self.verbose:
                print(f"Error loading embedding model: {str(e)}")
                print("Using fallback embedding approach")
            # Fall back to an empty dictionary; every word will then go
            # through the hashing fallback in get_word_embedding.
            return {}

    def get_word_embedding(self, word):
        """
        Get the embedding for a word.

        Known words come from the pre-built dictionary; out-of-vocabulary
        words get a deterministic hash-seeded random embedding.
        """
        word_lower = word.lower()

        if word_lower in self.embedding_model:
            return self.embedding_model[word_lower]

        # Fallback for out-of-vocabulary words. Python's built-in hash() is
        # salted per process for strings, which would make embeddings differ
        # between runs, so derive a stable 32-bit seed from an MD5 digest
        # instead. np.random.seed only accepts values in [0, 2**32 - 1].
        digest = hashlib.md5(word_lower.encode("utf-8")).hexdigest()
        hash_value = int(digest, 16) % (2**32 - 1)
        np.random.seed(hash_value)
        embedding = np.random.rand(EMBEDDING_DIM)
        embedding = embedding / np.linalg.norm(embedding)
        embedding = (embedding * 255).astype(np.uint8)

        return embedding
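
    # For example (illustrative, not model output): get_word_embedding("Zendesk")
    # returns the same uint8 vector of length EMBEDDING_DIM on every call and
    # every run, since the seed is derived from the word itself.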

    def tokenize_text(self, text):
        """
        Tokenize text into words using NLTK.

        Returns a list of words and their (start, end) character positions
        in the original text.
        """
        words = word_tokenize(text)

        # Recover character offsets by scanning forward through the text.
        positions = []
        start_pos = 0
        for word in words:
            word_pos = text.find(word, start_pos)
            if word_pos != -1:
                positions.append((word_pos, word_pos + len(word)))
                start_pos = word_pos + len(word)
            else:
                # The tokenizer can rewrite tokens (e.g. quote characters),
                # so fall back to an approximate position.
                positions.append((start_pos, start_pos + len(word)))
                start_pos += len(word) + 1

        if self.verbose:
            print(f"Tokenized text into {len(words)} words: {words}")

        return words, positions
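
    # For example, word_tokenize splits punctuation into separate tokens:
    # "Zendesk is a platform." -> ['Zendesk', 'is', 'a', 'platform', '.'],
    # with positions [(0, 7), (8, 10), (11, 12), (13, 21), (21, 22)].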

    def get_word_embeddings_matrix(self, words):
        """
        Get embeddings for a list of words.

        Returns a matrix of shape (MAX_WORDS, EMBEDDING_DIM) with uint8
        values; rows beyond the last word remain zero-padded.
        """
        result = np.zeros((MAX_WORDS, EMBEDDING_DIM), dtype=np.uint8)

        # Texts longer than MAX_WORDS are truncated.
        for i, word in enumerate(words[:MAX_WORDS]):
            result[i] = self.get_word_embedding(word)

        if self.verbose:
            print(f"Created word embeddings matrix with shape {result.shape}")

        return result

    def find_entity_candidates(self, words, positions):
        """
        Find potential entity candidates in the text.

        Returns a list of candidate (start_idx, end_idx) ranges, using
        capitalization as a simple heuristic: every capitalized word starts
        candidate spans of one to three words.
        """
        candidates = []

        for i, word in enumerate(words):
            if word[0].isupper():
                # Single-word candidate.
                candidates.append((i, i + 1))
                # Two- and three-word spans starting at this word.
                for j in range(1, min(3, len(words) - i)):
                    candidates.append((i, i + j + 1))

        # The model accepts at most MAX_CANDIDATES spans.
        candidates = candidates[:MAX_CANDIDATES]

        if self.verbose:
            print(f"Found {len(candidates)} entity candidates:")
            for start, end in candidates:
                if start < len(words) and end <= len(words):
                    print(f"  - {' '.join(words[start:end])}")

        return candidates
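
    # For example, with words = ['Zendesk', 'is', 'a', ...], the capitalized
    # 'Zendesk' at index 0 yields the spans (0, 1), (0, 2), and (0, 3),
    # i.e. "Zendesk", "Zendesk is", and "Zendesk is a".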

    def prepare_model_inputs(self, words, candidates, word_embeddings_matrix):
        """
        Prepare inputs for the model.

        Returns a dictionary mapping input tensor indices to tensors.
        """
        num_words = min(len(words), MAX_WORDS)
        num_candidates = min(len(candidates), MAX_CANDIDATES)

        # Candidate spans as (start, end) word indices.
        ranges_input = np.zeros((MAX_CANDIDATES, 2), dtype=np.int32)
        for i, (start, end) in enumerate(candidates[:MAX_CANDIDATES]):
            ranges_input[i][0] = start
            ranges_input[i][1] = end

        # 1 where the candidate's first word is capitalized, else 0.
        capitalization_input = np.zeros(MAX_CANDIDATES, dtype=np.int32)
        for i, (start, _) in enumerate(candidates[:MAX_CANDIDATES]):
            if start < len(words) and words[start][0].isupper():
                capitalization_input[i] = 1

        # Neutral 0.5 prior for every candidate.
        priors_input = np.ones(MAX_CANDIDATES, dtype=np.float32) * 0.5

        # Entity embeddings and link features are zeroed placeholders here.
        entity_embeddings_input = np.zeros((MAX_CANDIDATES, EMBEDDING_DIM), dtype=np.uint8)
        candidate_links_input = np.zeros((MAX_CANDIDATES, MAX_CANDIDATES), dtype=np.float32)
        aggregated_entity_links_input = np.zeros(MAX_CANDIDATES, dtype=np.float32)

        # Match tensors to inputs by substring of the input tensor name.
        inputs = {}
        for detail in self.input_details:
            name = detail['name']
            index = detail['index']

            if 'word_embeddings' in name:
                inputs[index] = word_embeddings_matrix
            elif 'num_words' in name:
                inputs[index] = np.array([num_words], dtype=np.int32)
            elif 'num_candidates' in name:
                inputs[index] = np.array([num_candidates], dtype=np.int32)
            elif 'ranges' in name:
                inputs[index] = ranges_input
            elif 'capitalization' in name:
                inputs[index] = capitalization_input
            elif 'priors' in name:
                inputs[index] = priors_input
            elif 'entity_embeddings' in name:
                inputs[index] = entity_embeddings_input
            elif 'candidate_links' in name:
                inputs[index] = candidate_links_input
            elif 'aggregated_entity_links' in name:
                inputs[index] = aggregated_entity_links_input

        return inputs
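
    # Note: matching by name substring assumes the model's input tensors
    # carry descriptive names; if a model used opaque names, inputs would
    # need to be matched by shape and dtype from get_input_details() instead.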

    def run_model(self, inputs):
        """
        Run the model with the prepared inputs.

        Returns the model output (entity scores).
        """
        # Standard TFLite interpreter cycle: set inputs, invoke, read output.
        for index, tensor in inputs.items():
            self.interpreter.set_tensor(index, tensor)

        self.interpreter.invoke()

        output_index = self.output_details[0]['index']
        output = self.interpreter.get_tensor(output_index)

        if self.verbose:
            print(f"Model output shape: {output.shape}")

        return output
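
    # Aside: set_tensor copies the array into the interpreter's buffer on
    # each call, which is fine at this scale; interpreter.tensor() offers a
    # zero-copy alternative for hot loops.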

    def extract_entities(self, text, threshold=0.5):
        """
        Extract entities from text using the model.

        Returns a list of entity dictionaries with text, score, and position.
        """
        words, positions = self.tokenize_text(text)
        candidates = self.find_entity_candidates(words, positions)
        word_embeddings_matrix = self.get_word_embeddings_matrix(words)
        inputs = self.prepare_model_inputs(words, candidates, word_embeddings_matrix)
        scores = self.run_model(inputs)

        # Flatten in case the model emits a leading batch dimension, so that
        # scores[i] is the scalar score for candidate i.
        scores = np.asarray(scores).reshape(-1)

        entities = []
        for i, (start, end) in enumerate(candidates):
            if i < len(scores) and scores[i] > threshold:
                if start < len(words) and end <= len(words):
                    entity_text = " ".join(words[start:end])
                    entity_pos = (positions[start][0], positions[end - 1][1])
                    entities.append({
                        "text": entity_text,
                        "score": float(scores[i]),
                        "position": entity_pos,
                    })

        return entities
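
# Each entity returned by extract_entities is a dict like (illustrative
# values only): {"text": "Zendesk", "score": 0.87, "position": (0, 7)}.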


def main():
    print(f"Analyzing text: {SAMPLE_TEXT}")

    try:
        extractor = EntityExtractor(verbose=True)
        entities = extractor.extract_entities(SAMPLE_TEXT, threshold=0.5)

        print("\nDetected entities:")
        for entity in entities:
            print(f"- {entity['text']} (confidence: {entity['score']:.2f}, "
                  f"position: {entity['position']})")

    except Exception as e:
        print(f"Error: {str(e)}")
        traceback.print_exc()
        print("\nTroubleshooting tips:")
        print("1. Make sure all file paths are correct")
        print("2. Check that TensorFlow is installed (pip install tensorflow)")
        print("3. Ensure that NLTK is installed (pip install nltk)")
        print("4. Verify that the model file is a valid TFLite model")


if __name__ == "__main__":
    main()