Gla-AI4BioMed-Lab committed on
Commit f61c828
1 Parent(s): 0312a01

Delete utils/.ipynb_checkpoints

utils/.ipynb_checkpoints/drug_tokenizer-checkpoint.py DELETED
@@ -1,66 +0,0 @@
-import json
-import re
-import torch
-import torch.nn as nn
-from torch.nn import functional as F
-
-class DrugTokenizer:
-    def __init__(self, vocab_path="tokenizer/vocab.json", special_tokens_path="tokenizer/special_tokens_map.json"):
-        self.vocab, self.special_tokens = self.load_vocab_and_special_tokens(vocab_path, special_tokens_path)
-        self.cls_token_id = self.vocab[self.special_tokens['cls_token']]
-        self.sep_token_id = self.vocab[self.special_tokens['sep_token']]
-        self.unk_token_id = self.vocab[self.special_tokens['unk_token']]
-        self.pad_token_id = self.vocab[self.special_tokens['pad_token']]
-        self.id_to_token = {v: k for k, v in self.vocab.items()}
-
-    def load_vocab_and_special_tokens(self, vocab_path, special_tokens_path):
-        with open(vocab_path, 'r', encoding='utf-8') as vocab_file:
-            vocab = json.load(vocab_file)
-        with open(special_tokens_path, 'r', encoding='utf-8') as special_tokens_file:
-            special_tokens_raw = json.load(special_tokens_file)
-
-        special_tokens = {key: value['content'] for key, value in special_tokens_raw.items()}
-        return vocab, special_tokens
-
-    def encode(self, sequence):
-        tokens = re.findall(r'\[([^\[\]]+)\]', sequence)
-        input_ids = [self.cls_token_id] + [self.vocab.get(token, self.unk_token_id) for token in tokens] + [self.sep_token_id]
-        attention_mask = [1] * len(input_ids)
-        return {
-            'input_ids': input_ids,
-            'attention_mask': attention_mask
-        }
-
-    def batch_encode_plus(self, sequences, max_length, padding, truncation, add_special_tokens, return_tensors):
-        input_ids_list = []
-        attention_mask_list = []
-
-        for sequence in sequences:
-            encoded = self.encode(sequence)
-            input_ids = encoded['input_ids']
-            attention_mask = encoded['attention_mask']
-
-            if len(input_ids) > max_length:
-                input_ids = input_ids[:max_length]
-                attention_mask = attention_mask[:max_length]
-            elif len(input_ids) < max_length:
-                pad_length = max_length - len(input_ids)
-                input_ids = input_ids + [self.vocab[self.special_tokens['pad_token']]] * pad_length
-                attention_mask = attention_mask + [0] * pad_length
-
-            input_ids_list.append(input_ids)
-            attention_mask_list.append(attention_mask)
-
-        return {
-            'input_ids': torch.tensor(input_ids_list, dtype=torch.long),
-            'attention_mask': torch.tensor(attention_mask_list, dtype=torch.long)
-        }
-
-    def decode(self, input_ids, skip_special_tokens=False):
-        tokens = []
-        for id in input_ids:
-            if skip_special_tokens and id in [self.cls_token_id, self.sep_token_id, self.pad_token_id]:
-                continue
-            tokens.append(self.id_to_token.get(id, self.special_tokens['unk_token']))
-        sequence = ''.join([f'[{token}]' for token in tokens])
-        return sequence
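
For reference, a minimal usage sketch of the deleted DrugTokenizer as it behaved before this commit. This is illustrative only: the import path and the bracketed token string "[C][O][N]" are assumptions, not verified contents of this repository; what is taken from the code above is that encode() extracts tokens with the regex \[([^\[\]]+)\] and that the constructor loads tokenizer/vocab.json and tokenizer/special_tokens_map.json by default.

# Hedged sketch; "utils.drug_tokenizer" is an assumed module path for the
# non-checkpoint copy of this file, and "[C][O][N]" is a made-up token string.
from utils.drug_tokenizer import DrugTokenizer

tokenizer = DrugTokenizer()  # reads tokenizer/vocab.json and tokenizer/special_tokens_map.json

# encode() only recognizes bracketed tokens and wraps the ids between the
# cls and sep token ids; unknown tokens fall back to unk_token_id
enc = tokenizer.encode("[C][O][N]")
# enc -> {'input_ids': [cls_id, id_C, id_O, id_N, sep_id], 'attention_mask': [1, 1, 1, 1, 1]}

batch = tokenizer.batch_encode_plus(
    ["[C][O]", "[N]"],
    max_length=8,
    padding=True,
    truncation=True,
    add_special_tokens=True,
    return_tensors="pt",
)
# batch['input_ids'] and batch['attention_mask'] are (2, 8) LongTensors,
# right-padded with pad_token_id and 0 respectively

print(tokenizer.decode(enc['input_ids'], skip_special_tokens=True))
# back to "[C][O][N]", provided all tokens are in the vocab

Note that this implementation accepts padding, truncation, add_special_tokens, and return_tensors for interface compatibility with Hugging Face tokenizers but does not consult them: sequences are always truncated or right-padded to max_length, and PyTorch tensors are always returned.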