#!/usr/bin/env python3
"""
Entity extraction script using a proper embedding model with correctly shaped embeddings.
This script uses a pre-trained word embedding model to generate embeddings in the exact
shape required by the TFLite model (64x32).
Fixed to handle random seed error.
"""

import os
import traceback
import zlib  # stable 32-bit hashes for the deterministic embedding fallback

import numpy as np
import tensorflow as tf
import nltk
from nltk.tokenize import word_tokenize

# Hardcoded paths - these should match your file locations
MODEL_PATH = "model.tflite"
WORD_EMBEDDINGS_PATH = "word_embeddings"  # Not used for embedding, kept for reference
ENTITIES_METADATA_PATH = "global-entities_metadata"
ENTITIES_NAMES_PATH = "global-entities_names"
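# NOTE: the two entities files above are not read anywhere in this script
# either; they are listed so the expected companion files stay documented.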

# Hardcoded sample text
SAMPLE_TEXT = "Zendesk is a customer service platform used by companies like Shopify, Airbnb, and Slack to manage support tickets, automate workflows, and provide omnichannel communication through email, chat, phone, and social media."

# Constants fixed by the TFLite model's input shapes
MAX_WORDS = 64        # rows of the word-embedding matrix
MAX_CANDIDATES = 32   # maximum entity candidates per inference
EMBEDDING_DIM = 32    # columns of the word-embedding matrix

class EntityExtractor:
    def __init__(self, verbose=True):
        """Initialize the entity extractor with a pre-trained embedding model."""
        self.model_path = MODEL_PATH
        self.verbose = verbose
        
        # Load TFLite model
        self.interpreter = self.load_model()
        
        # Load the embedding lookup (a demo dictionary here)
        self.embedding_model = self.load_embedding_model()
        
        # Get input and output details
        self.input_details = self.interpreter.get_input_details()
        self.output_details = self.interpreter.get_output_details()
        
        if self.verbose:
            print(f"TFLite model loaded with {len(self.input_details)} inputs and {len(self.output_details)} outputs")
            print(f"Pre-trained embedding model loaded")
            print("Input details:")
            for detail in self.input_details:
                print(f"  - {detail['name']} (index: {detail['index']}, shape: {detail['shape']}, dtype: {detail['dtype']})")

    def load_model(self):
        """Load the TFLite model."""
        if not os.path.exists(self.model_path):
            raise FileNotFoundError(f"Model file not found: {self.model_path}")
            
        interpreter = tf.lite.Interpreter(model_path=self.model_path)
        interpreter.allocate_tensors()
        return interpreter

    def load_embedding_model(self):
        """
        Load an embedding model.
        For this demonstration, build a small dictionary of random but
        reproducible vectors that stand in for real pre-trained embeddings.
        """
        try:
            # Download the NLTK tokenizer data if not already present.
            # Newer NLTK releases (>= 3.8.2) ship it as 'punkt_tab'.
            for resource in ('punkt', 'punkt_tab'):
                try:
                    nltk.data.find(f'tokenizers/{resource}')
                except LookupError:
                    nltk.download(resource)
            
            # Create a simple embedding dictionary for demonstration
            embedding_dict = {}
            
            # Add some common words with random embeddings
            common_words = ["google", "is", "a", "search", "engine", "company", "based", "in", "the", "usa", 
                           "and", "of", "to", "for", "with", "on", "by", "at", "from", "as"]
            
            # Create random but consistent embeddings
            np.random.seed(42)  # For reproducibility
            for word in common_words:
                # Create a random embedding vector
                embedding = np.random.rand(EMBEDDING_DIM)
                # Normalize to unit length
                embedding = embedding / np.linalg.norm(embedding)
                # Scale to uint8 range and convert
                embedding = (embedding * 255).astype(np.uint8)
                embedding_dict[word] = embedding
            
            if self.verbose:
                print(f"Created embedding dictionary with {len(embedding_dict)} words")
            
            return embedding_dict
            
        except Exception as e:
            if self.verbose:
                print(f"Error loading embedding model: {str(e)}")
                print("Using fallback embedding approach")
            
            # Fallback: an empty dictionary forces every word through the
            # deterministic hash-based path in get_word_embedding().
            embedding_dict = {}
            return embedding_dict

    def get_word_embedding(self, word):
        """
        Get embedding for a word from the pre-trained model.
        If the word is not in the vocabulary, use a fallback approach.
        """
        word_lower = word.lower()
        
        # Try to get embedding from the model
        if word_lower in self.embedding_model:
            return self.embedding_model[word_lower]
        
        # Fallback: create a deterministic embedding based on the word.
        # Python's built-in hash() is salted per process, so CRC32 is used
        # instead: it is stable across runs and always a valid 32-bit seed.
        hash_value = zlib.crc32(word_lower.encode('utf-8'))
        rng = np.random.RandomState(hash_value)  # local RNG, leaves global state alone
        embedding = rng.rand(EMBEDDING_DIM)
        embedding = embedding / np.linalg.norm(embedding)
        embedding = (embedding * 255).astype(np.uint8)
        
        return embedding

    def tokenize_text(self, text):
        """
        Tokenize text into words using NLTK.
        Returns a list of words and their positions in the original text.
        """
        # Use NLTK for better tokenization
        words = word_tokenize(text)
        
        # Get positions (approximate since NLTK doesn't return positions)
        positions = []
        start_pos = 0
        for word in words:
            # Find the word in the text starting from the current position
            word_pos = text.find(word, start_pos)
            if word_pos != -1:
                positions.append((word_pos, word_pos + len(word)))
                start_pos = word_pos + len(word)
            else:
                # Fallback if the exact word can't be found
                positions.append((start_pos, start_pos + len(word)))
                start_pos += len(word) + 1
        
        if self.verbose:
            print(f"Tokenized text into {len(words)} words: {words}")
            
        return words, positions

    def get_word_embeddings_matrix(self, words):
        """
        Get embeddings for a list of words.
        Returns a matrix of shape (MAX_WORDS, EMBEDDING_DIM) with uint8 values.
        """
        # Initialize the result matrix with zeros
        result = np.zeros((MAX_WORDS, EMBEDDING_DIM), dtype=np.uint8)
        
        # Fill the matrix with embeddings for each word
        for i, word in enumerate(words[:MAX_WORDS]):
            result[i] = self.get_word_embedding(word)
        
        if self.verbose:
            print(f"Created word embeddings matrix with shape {result.shape}")
        
        return result

    def find_entity_candidates(self, words, positions):
        """
        Find potential entity candidates in the text.
        Returns a list of candidate ranges (start_idx, end_idx).
        """
        candidates = []
        
        # Look for capitalized words as potential entities
        for i, word in enumerate(words):
            if word and word[0].isupper():
                # Single word entity
                candidates.append((i, i+1))
                
                # Look for multi-word entities (up to 3 words)
                for j in range(1, min(3, len(words) - i)):
                    candidates.append((i, i+j+1))
        
        # Limit to MAX_CANDIDATES
        candidates = candidates[:MAX_CANDIDATES]
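        # NOTE: truncation is positional, so candidates beyond the model's
        # 32-slot limit are dropped in discovery order rather than ranked.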
        
        if self.verbose:
            print(f"Found {len(candidates)} entity candidates:")
            for start, end in candidates:
                if start < len(words) and end <= len(words):
                    print(f"  - {' '.join(words[start:end])}")
        
        return candidates

    def prepare_model_inputs(self, words, candidates, word_embeddings_matrix):
        """
        Prepare inputs for the model.
        Returns a dictionary of input tensors.
        """
        num_words = min(len(words), MAX_WORDS)
        num_candidates = min(len(candidates), MAX_CANDIDATES)
        
        # Prepare ranges input
        ranges_input = np.zeros((MAX_CANDIDATES, 2), dtype=np.int32)
        for i, (start, end) in enumerate(candidates[:MAX_CANDIDATES]):
            ranges_input[i][0] = start
            ranges_input[i][1] = end
        
        # Prepare capitalization input (1 if capitalized, 0 otherwise)
        capitalization_input = np.zeros(MAX_CANDIDATES, dtype=np.int32)
        for i, (start, _) in enumerate(candidates[:MAX_CANDIDATES]):
            if start < len(words) and words[start][0].isupper():
                capitalization_input[i] = 1
        
        # Prepare priors input (simplified)
        priors_input = np.ones(MAX_CANDIDATES, dtype=np.float32) * 0.5
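        # NOTE: a real pipeline would derive per-candidate priors from the
        # entities metadata file; a uniform 0.5 is a neutral placeholder.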
        
        # Prepare entity embeddings (simplified)
        entity_embeddings_input = np.zeros((MAX_CANDIDATES, EMBEDDING_DIM), dtype=np.uint8)
        
        # Prepare candidate links (simplified)
        candidate_links_input = np.zeros((MAX_CANDIDATES, MAX_CANDIDATES), dtype=np.float32)
        
        # Prepare aggregated entity links (simplified)
        aggregated_entity_links_input = np.zeros(MAX_CANDIDATES, dtype=np.float32)
        
        # Create input dictionary
        inputs = {}
        
        # Map inputs to the correct input tensor indices
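        # NOTE: this relies on substring matches against tensor names, which
        # assumes the model exposes descriptively named inputs (e.g.
        # 'word_embeddings', 'num_candidates'). If your model names its
        # tensors differently, inspect self.input_details and adjust the keys.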
        for detail in self.input_details:
            name = detail['name']
            index = detail['index']
            
            if 'word_embeddings' in name:
                inputs[index] = word_embeddings_matrix
            elif 'num_words' in name:
                inputs[index] = np.array([num_words], dtype=np.int32)
            elif 'num_candidates' in name:
                inputs[index] = np.array([num_candidates], dtype=np.int32)
            elif 'ranges' in name:
                inputs[index] = ranges_input
            elif 'capitalization' in name:
                inputs[index] = capitalization_input
            elif 'priors' in name:
                inputs[index] = priors_input
            elif 'entity_embeddings' in name:
                inputs[index] = entity_embeddings_input
            elif 'candidate_links' in name:
                inputs[index] = candidate_links_input
            elif 'aggregated_entity_links' in name:
                inputs[index] = aggregated_entity_links_input
        
        return inputs

    def run_model(self, inputs):
        """
        Run the model with the prepared inputs.
        Returns the model output (entity scores).
        """
        # Set input tensors
        for index, tensor in inputs.items():
            self.interpreter.set_tensor(index, tensor)
        
        # Run inference
        self.interpreter.invoke()
        
        # Get output tensor
        output_index = self.output_details[0]['index']
        output = self.interpreter.get_tensor(output_index)
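        # NOTE: only the first output tensor is read; a model with several
        # outputs would need to iterate over self.output_details instead.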
        
        if self.verbose:
            print(f"Model output shape: {output.shape}")
            
        return output

    def extract_entities(self, text, threshold=0.5):
        """
        Extract entities from text using the model.
        Returns a list of entity dictionaries with text, score, and position.
        """
        # Tokenize text
        words, positions = self.tokenize_text(text)
        
        # Find entity candidates
        candidates = self.find_entity_candidates(words, positions)
        
        # Get word embeddings matrix with correct shape (64x32)
        word_embeddings_matrix = self.get_word_embeddings_matrix(words)
        
        # Prepare model inputs
        inputs = self.prepare_model_inputs(words, candidates, word_embeddings_matrix)
        
        # Run model
        scores = self.run_model(inputs)
        
        # Process results; flatten in case the model returns a batched
        # (1, MAX_CANDIDATES) output rather than a flat score vector
        scores = np.asarray(scores).reshape(-1)
        entities = []
        for i, (start, end) in enumerate(candidates):
            if i < len(scores) and scores[i] > threshold:
                if start < len(words) and end <= len(words):
                    entity_text = " ".join(words[start:end])
                    entity_pos = (positions[start][0], positions[end-1][1])
                    entities.append({
                        "text": entity_text,
                        "score": float(scores[i]),
                        "position": entity_pos
                    })
        
        return entities


def main():
    print(f"Analyzing text: {SAMPLE_TEXT}")
    
    try:
        # Create entity extractor with verbose output
        extractor = EntityExtractor(verbose=True)
        
        # Extract entities from the sample text
        entities = extractor.extract_entities(SAMPLE_TEXT, threshold=0.5)
        
        print("\nDetected entities:")
        for entity in entities:
            print(f"- {entity['text']} (confidence: {entity['score']:.2f}, position: {entity['position']})")
    
    except Exception as e:
        print(f"Error: {str(e)}")
        traceback.print_exc()
        print("\nTroubleshooting tips:")
        print("1. Make sure all file paths are correct")
        print("2. Check that TensorFlow is installed (pip install tensorflow)")
        print("3. Ensure that NLTK is installed (pip install nltk)")
        print("4. Verify that the model file is a valid TFLite model")


if __name__ == "__main__":
    main()