cleopatro committed on
Commit
b1f80ab
1 Parent(s): 7e70798

added main and stm

Files changed (8)
  1. config.json +34 -0
  2. main.py +60 -0
  3. model.safetensors +3 -0
  4. special_tokens_map.json +7 -0
  5. stm.py +49 -0
  6. tokenizer.json +0 -0
  7. tokenizer_config.json +55 -0
  8. vocab.txt +0 -0
config.json ADDED
@@ -0,0 +1,34 @@
+ {
+   "_name_or_path": "distilbert-base-uncased",
+   "activation": "gelu",
+   "architectures": [
+     "DistilBertForTokenClassification"
+   ],
+   "attention_dropout": 0.1,
+   "dim": 768,
+   "dropout": 0.1,
+   "hidden_dim": 3072,
+   "id2label": {
+     "0": "O",
+     "1": "B-ABS",
+     "2": "I-ABS"
+   },
+   "initializer_range": 0.02,
+   "label2id": {
+     "B-ABS": 1,
+     "I-ABS": 2,
+     "O": 0
+   },
+   "max_position_embeddings": 512,
+   "model_type": "distilbert",
+   "n_heads": 12,
+   "n_layers": 6,
+   "pad_token_id": 0,
+   "qa_dropout": 0.1,
+   "seq_classif_dropout": 0.2,
+   "sinusoidal_pos_embds": false,
+   "tie_weights_": true,
+   "torch_dtype": "float32",
+   "transformers_version": "4.38.2",
+   "vocab_size": 30522
+ }
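The config registers a token-classification head on distilbert-base-uncased with a three-tag BIO scheme (O, B-ABS, I-ABS) for abstract entities. A minimal sketch of loading the checkpoint and mapping predictions back through id2label (assumes the transformers and torch packages are available and that the repo id cleopatro/Entity_Rec, used by main.py below, serves these files):

import torch
from transformers import AutoModelForTokenClassification, AutoTokenizer

model = AutoModelForTokenClassification.from_pretrained("cleopatro/Entity_Rec")
tokenizer = AutoTokenizer.from_pretrained("cleopatro/Entity_Rec")

inputs = tokenizer("today stock prices and home loans are a pain", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits

# Map each sub-token's argmax class id back to its BIO label via config.id2label
pred_ids = logits.argmax(dim=-1)[0].tolist()
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0].tolist())
for token, pred_id in zip(tokens, pred_ids):
    print(token, model.config.id2label[pred_id])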
main.py ADDED
@@ -0,0 +1,60 @@
+ import os
+
+ import requests
+ import spacy
+ from dotenv import load_dotenv
+
+ from stm import ShortTermMemory
+
+ load_dotenv()
+ api_key = os.getenv("API_KEY")
+
+ API_URL = "https://api-inference.huggingface.co/models/cleopatro/Entity_Rec"
+ headers = {"Authorization": f"Bearer {api_key}"}
+ NER = spacy.load("en_core_web_sm")
+
+
+ def extract_word_and_entity_group(predictions):
+     # Collect the surface form of each prediction returned by the inference API.
+     words = []
+     for item in predictions:
+         words.append(item['word'])
+     return words
+
+
+ def get_abs(payload):
+     # Query the hosted token-classification model for abstract-entity (ABS) tags.
+     response = requests.post(API_URL, headers=headers, json=payload)
+     return response.json()
+
+
+ def get_loc_time(sentence):
+     # Use spaCy's pretrained pipeline for locations (GPE/LOC) and times (TIME/DATE).
+     doc = NER(sentence)
+     locations = []
+     times = []
+     for ent in doc.ents:
+         if ent.label_ in ("GPE", "LOC"):
+             locations.append(ent.text)
+         elif ent.label_ in ("TIME", "DATE"):
+             times.append(ent.text)
+     return locations, times
+
+
+ def get_ent(sentence):
+     # sentence is an inference payload such as {"inputs": "..."}.
+     abs_dict = get_abs(sentence)
+     abs_tags = extract_word_and_entity_group(abs_dict)
+     loc_tags, time_tags = get_loc_time(sentence["inputs"])
+     return abs_tags, loc_tags, time_tags
+
+
+ # output = get_ent({
+ #     "inputs": "today stock prices and home loans are a pain in san fransisco.",
+ # })
+ # print(output)
+
+ # stm = ShortTermMemory(window_size=5, decay_rate=0.8)
+ # stm.update('abstract', 'credit-card')
+ # print(stm.get_memory())  # Output: {'abstract_entities': {'credit-card': 1}, 'locations': {}, 'times': {}}
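main.py extracts the three entity groups but never feeds them into the short-term memory; a small sketch of that wiring, using a hypothetical track_sentence helper built on the functions above (the type names 'abstract', 'location', and 'time' are the ones stm.py expects):

from main import get_ent
from stm import ShortTermMemory

stm = ShortTermMemory(window_size=5, decay_rate=0.8)

def track_sentence(text):
    # Run both extractors, then fold every hit into the short-term memory.
    abs_tags, loc_tags, time_tags = get_ent({"inputs": text})
    for word in abs_tags:
        stm.update('abstract', word)
    for loc in loc_tags:
        stm.update('location', loc)
    for t in time_tags:
        stm.update('time', t)
    return stm.get_memory()

# print(track_sentence("today stock prices and home loans are a pain in san francisco."))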
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:16d964bd2589ac2a8af7b5b2afe2e38a8c09346ace42dbe036e8315d9dcd59e1
+ size 265473092
special_tokens_map.json ADDED
@@ -0,0 +1,7 @@
+ {
+   "cls_token": "[CLS]",
+   "mask_token": "[MASK]",
+   "pad_token": "[PAD]",
+   "sep_token": "[SEP]",
+   "unk_token": "[UNK]"
+ }
stm.py ADDED
@@ -0,0 +1,49 @@
+ from collections import defaultdict
+
+
+ class ShortTermMemory:
+     def __init__(self, window_size=10, decay_rate=0.5):
+         self.abstract_entities = defaultdict(int)
+         self.locations = defaultdict(int)
+         self.times = defaultdict(int)
+         self.window_size = window_size
+         self.decay_rate = decay_rate
+
+     def update(self, entity_type, entity):
+         # Determine the appropriate dictionary based on the entity type
+         if entity_type == 'abstract':
+             entity_dict = self.abstract_entities
+         elif entity_type == 'location':
+             entity_dict = self.locations
+         elif entity_type == 'time':
+             entity_dict = self.times
+         else:
+             raise ValueError(f'Invalid entity type: {entity_type}')
+
+         # Increment the count for the given entity
+         entity_dict[entity] = entity_dict.get(entity, 0) + 1
+
+         # Decay the counts of the other entities in the same dictionary
+         for e, count in list(entity_dict.items()):
+             if e != entity:
+                 entity_dict[e] = int(count * self.decay_rate)
+
+         # Drop entities whose count has decayed to zero
+         entity_dict = {e: count for e, count in entity_dict.items() if count > 0}
+
+         # Trim the dictionary to the window size, keeping the highest counts
+         entity_dict = dict(sorted(entity_dict.items(), key=lambda x: x[1], reverse=True)[:self.window_size])
+
+         # Store the trimmed version back on the appropriate attribute
+         if entity_type == 'abstract':
+             self.abstract_entities = entity_dict
+         elif entity_type == 'location':
+             self.locations = entity_dict
+         elif entity_type == 'time':
+             self.times = entity_dict
+
+     def get_memory(self):
+         return {
+             'abstract_entities': dict(self.abstract_entities),
+             'locations': dict(self.locations),
+             'times': dict(self.times)
+         }
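Because every update decays all other counts and drops anything that reaches zero, repeated entities dominate the window while one-off mentions fade. A short sketch of that behaviour (the entity names are illustrative only):

from stm import ShortTermMemory

stm = ShortTermMemory(window_size=3, decay_rate=0.5)

# 'credit-card' is reinforced, 'mortgage' appears once and decays away.
stm.update('abstract', 'credit-card')
stm.update('abstract', 'mortgage')
stm.update('abstract', 'credit-card')
stm.update('abstract', 'credit-card')

print(stm.get_memory()['abstract_entities'])  # {'credit-card': 2}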
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,55 @@
+ {
+   "added_tokens_decoder": {
+     "0": {
+       "content": "[PAD]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "100": {
+       "content": "[UNK]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "101": {
+       "content": "[CLS]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "102": {
+       "content": "[SEP]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "103": {
+       "content": "[MASK]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     }
+   },
+   "clean_up_tokenization_spaces": true,
+   "cls_token": "[CLS]",
+   "do_lower_case": true,
+   "mask_token": "[MASK]",
+   "model_max_length": 512,
+   "pad_token": "[PAD]",
+   "sep_token": "[SEP]",
+   "strip_accents": null,
+   "tokenize_chinese_chars": true,
+   "tokenizer_class": "DistilBertTokenizer",
+   "unk_token": "[UNK]"
+ }
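The added_tokens_decoder block pins the five BERT special tokens to the standard WordPiece ids. A quick check that a loaded tokenizer reports the same ids (assumes the repo id cleopatro/Entity_Rec serves these tokenizer files):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("cleopatro/Entity_Rec")

# Ids declared above: [PAD]=0, [UNK]=100, [CLS]=101, [SEP]=102, [MASK]=103
print(tokenizer.pad_token_id, tokenizer.unk_token_id, tokenizer.cls_token_id,
      tokenizer.sep_token_id, tokenizer.mask_token_id)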
vocab.txt ADDED
The diff for this file is too large to render. See raw diff