antenmanuuel committed on
Commit 742ad17 · verified · 1 Parent(s): 03de5c7

Upload folder using huggingface_hub

Files changed (6):
  1. classes.py +4 -1
  2. helpers.py +12 -3
  3. models.py +20 -1
  4. routes/attention.py +31 -12
  5. routes/mask_prediction.py +3 -2
  6. routes/tokenize.py +3 -2
classes.py CHANGED
@@ -1,9 +1,10 @@
-from typing import List
+from typing import List, Optional
 from pydantic import BaseModel

 class TokenizeRequest(BaseModel):
     text: str
     model_name: str = "bert-base-uncased"
+    debug: Optional[bool] = False

 class Token(BaseModel):
     text: str
@@ -25,6 +26,7 @@ class MaskPredictionRequest(BaseModel):
     mask_index: int
     model_name: str = "bert-base-uncased"
     top_k: int = 10
+    debug: Optional[bool] = False

 class MaskPredictionResponse(BaseModel):
     predictions: List[WordPrediction]
@@ -33,6 +35,7 @@ class AttentionRequest(BaseModel):
     text: str
     model_name: str = "bert-base-uncased"
     visualization_method: str = "raw" # Options: "raw", "rollout", "flow"
+    debug: Optional[bool] = False

 class AttentionHead(BaseModel):
     headIndex: int
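Not part of the commit: a minimal usage sketch of the new field, assuming classes.py is importable as `classes` and Pydantic v2 (use `.dict()` instead of `.model_dump()` on v1):

    from classes import AttentionRequest, TokenizeRequest

    # debug is Optional[bool] and defaults to False, so clients that omit it are unaffected
    req = TokenizeRequest(text="hello world")
    assert req.debug is False

    dbg = AttentionRequest(text="hello world", visualization_method="rollout", debug=True)
    print(dbg.model_dump())  # includes the default model_name "bert-base-uncased"
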
helpers.py CHANGED
@@ -202,17 +202,26 @@ def map_bert_tokens_to_words(tokens, original_text):
     return token_to_word_map

 # Helper function to load models on demand
-def get_model_and_tokenizer(model_name):
+def get_model_and_tokenizer(model_name, debug=False):
     if model_name not in MODEL_CONFIGS:
         raise HTTPException(status_code=400, detail=f"Model {model_name} not supported")

     if model_name not in models:
         print(f"Loading {model_name}...")
         config = MODEL_CONFIGS[model_name]
-        models[model_name] = config["model_class"].from_pretrained(model_name)
-        tokenizers[model_name] = config["tokenizer_class"].from_pretrained(model_name)
+
+        # Check if this is a custom model that requires special loading
+        if config["model_class"] == "custom" or model_name == "EdwinXhen/TinyBert_6Layer_MLM":
+            # Use the custom model loading function
+            tokenizers[model_name], models[model_name] = load_model(model_name, debug)
+        else:
+            # Standard model loading
+            models[model_name] = config["model_class"].from_pretrained(model_name)
+            tokenizers[model_name] = config["tokenizer_class"].from_pretrained(model_name)
+
         if torch.cuda.is_available():
            models[model_name] = models[model_name].cuda()
+
         models[model_name].eval()
         print(f"Model {model_name} loaded")
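Not part of the commit: a hedged sketch of how the routes exercise the updated loader; the (model, tokenizer) return order is inferred from the call sites in routes/mask_prediction.py and routes/tokenize.py:

    from helpers import get_model_and_tokenizer

    # The TinyBERT repo id is dispatched to load_model(); every other entry keeps the
    # config["model_class"].from_pretrained() path. Results are cached in the module-level dicts.
    model, tokenizer = get_model_and_tokenizer("EdwinXhen/TinyBert_6Layer_MLM", debug=True)
    model, tokenizer = get_model_and_tokenizer("bert-base-uncased")  # debug defaults to False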
 
models.py CHANGED
@@ -1,4 +1,4 @@
-from transformers import BertForMaskedLM, RobertaForMaskedLM, AutoTokenizer, BertModel, RobertaModel, DistilBertForMaskedLM, DistilBertModel
+from transformers import BertForMaskedLM, RobertaForMaskedLM, AutoTokenizer, BertModel, RobertaModel, DistilBertForMaskedLM, DistilBertModel, AutoModelForMaskedLM
 import nltk


@@ -30,8 +30,27 @@ MODEL_CONFIGS = {
         "model_class": DistilBertForMaskedLM,
         "tokenizer_class": AutoTokenizer,
         "base_model_class": DistilBertModel
+    },
+    "EdwinXhen/TinyBert_6Layer_MLM": {
+        "name": "TinyBERT 6 Layer",
+        "model_class": "custom",
+        "tokenizer_class": AutoTokenizer,
+        "base_model_class": BertModel
     }
 }

 models = {}
 tokenizers = {}
+
+def load_model(model_type, debug=False):
+    if model_type.lower() == "custom" or model_type == "EdwinXhen/TinyBert_6Layer_MLM":
+        # Load custom model from Hugging Face repository
+        custom_repo = "EdwinXhen/TinyBert_6Layer_MLM"
+        if debug:
+            print(f"[DEBUG] Loading custom model from HuggingFace repository: {custom_repo}")
+        tokenizer = AutoTokenizer.from_pretrained(custom_repo)
+        model = AutoModelForMaskedLM.from_pretrained(custom_repo, output_attentions=True)
+        return tokenizer, model
+    # Handle other models with existing logic
+    # This is a placeholder for the existing model loading logic
+    return None, None
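Not part of the commit: a short sketch exercising the new loader directly (network access to the Hugging Face Hub assumed; the layer count is what the repo name suggests, not verified here):

    from models import MODEL_CONFIGS, load_model

    assert MODEL_CONFIGS["EdwinXhen/TinyBert_6Layer_MLM"]["model_class"] == "custom"

    tokenizer, model = load_model("EdwinXhen/TinyBert_6Layer_MLM", debug=True)
    print(type(model).__name__)            # the class AutoModelForMaskedLM resolves to, e.g. BertForMaskedLM
    print(model.config.num_hidden_layers)  # expected 6 for this checkpoint
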
routes/attention.py CHANGED
@@ -9,11 +9,12 @@ router = APIRouter()
 async def get_attention_matrices(request: AttentionRequest):
     """Get attention matrices for the input text using the specified model"""
     try:
-        print(f"Processing attention request: text='{request.text}', model={request.model_name}, method={request.visualization_method}")
+        debug = request.debug if hasattr(request, 'debug') else False
+        print(f"Processing attention request: text='{request.text}', model={request.model_name}, method={request.visualization_method}, debug={debug}")

         # First tokenize the text using the same function that the /tokenize endpoint uses
         # to ensure consistency
-        tokenizer_response = await tokenize_text(TokenizeRequest(text=request.text, model_name=request.model_name))
+        tokenizer_response = await tokenize_text(TokenizeRequest(text=request.text, model_name=request.model_name, debug=debug))
         tokens = tokenizer_response["tokens"]
         print(f"Tokenized into {len(tokens)} tokens")

@@ -24,17 +25,35 @@ async def get_attention_matrices(request: AttentionRequest):
             raise HTTPException(status_code=400, detail=f"Model {model_name} not supported")

         config = MODEL_CONFIGS[model_name]
-        base_model_class = config["base_model_class"]

-        # Check if we already have a base model cached
-        base_model_key = f"{model_name}_base"
-        if base_model_key not in models:
-            print(f"Loading base model {model_name}...")
-            models[base_model_key] = base_model_class.from_pretrained(model_name, attn_implementation="eager")
-            if torch.cuda.is_available():
-                models[base_model_key] = models[base_model_key].cuda()
-            models[base_model_key].eval()
-            print(f"Base model {model_name} loaded")
+        # Handle custom model differently if needed
+        if config["model_class"] == "custom":
+            # For custom models, we need special handling
+            base_model_key = f"{model_name}_base"
+            if base_model_key not in models:
+                # For TinyBERT, we use the same model with different configuration
+                _, tokenizer = get_model_and_tokenizer(model_name, debug)
+                custom_repo = "EdwinXhen/TinyBert_6Layer_MLM"
+                print(f"Loading base model from {custom_repo} for attention visualization...")
+                from transformers import AutoModel
+                models[base_model_key] = AutoModel.from_pretrained(custom_repo, attn_implementation="eager", output_attentions=True)
+                if torch.cuda.is_available():
+                    models[base_model_key] = models[base_model_key].cuda()
+                models[base_model_key].eval()
+                print(f"Base model {model_name} loaded")
+        else:
+            # Standard model loading
+            base_model_class = config["base_model_class"]
+
+            # Check if we already have a base model cached
+            base_model_key = f"{model_name}_base"
+            if base_model_key not in models:
+                print(f"Loading base model {model_name}...")
+                models[base_model_key] = base_model_class.from_pretrained(model_name, attn_implementation="eager")
+                if torch.cuda.is_available():
+                    models[base_model_key] = models[base_model_key].cuda()
+                models[base_model_key].eval()
+                print(f"Base model {model_name} loaded")

         model = models[base_model_key]
         tokenizer = tokenizers[request.model_name]
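Not part of the commit: a hedged client-side sketch for this route; the route path and port are assumptions (the @router decorator is outside this hunk), and the payload fields come from AttentionRequest:

    import requests

    resp = requests.post(
        "http://localhost:8000/attention",  # assumed mount point for this router
        json={
            "text": "The quick brown fox jumps over the lazy dog",
            "model_name": "EdwinXhen/TinyBert_6Layer_MLM",
            "visualization_method": "raw",
            "debug": True,
        },
    )
    print(resp.status_code)
    print(list(resp.json()))  # response schema is not shown in this diff
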
routes/mask_prediction.py CHANGED
@@ -18,7 +18,8 @@ async def predict_masked_token(request: MaskPredictionRequest, x_token_to_mask:
         print(f"Token to mask header: '{x_token_to_mask}'")
         print(f"Explicit masked text header: '{x_explicit_masked_text}'")

-        model, tokenizer = get_model_and_tokenizer(request.model_name)
+        debug = request.debug if hasattr(request, 'debug') else False
+        model, tokenizer = get_model_and_tokenizer(request.model_name, debug)

         # For RoBERTa, use explicit masked text if provided
         if "roberta" in request.model_name and x_explicit_masked_text:
@@ -78,7 +79,7 @@ async def predict_masked_token(request: MaskPredictionRequest, x_token_to_mask:
             return MaskPredictionResponse(predictions=predictions_list)

         # Get tokens from the original text using the tokenize endpoint for consistency
-        tokenizer_response = await tokenize_text(TokenizeRequest(text=request.text, model_name=request.model_name))
+        tokenizer_response = await tokenize_text(TokenizeRequest(text=request.text, model_name=request.model_name, debug=debug))
        tokens = tokenizer_response["tokens"]

         print(f"Tokenizer response: {len(tokens)} tokens")
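Not part of the commit: a similar client sketch for the mask-prediction route; the path is an assumption, and the X-Token-To-Mask header name is inferred from the x_token_to_mask parameter via FastAPI's default underscore-to-hyphen conversion:

    import requests

    resp = requests.post(
        "http://localhost:8000/mask_prediction",  # assumed path
        headers={"X-Token-To-Mask": "fox"},        # optional header read by the route
        json={
            "text": "The quick brown fox",
            "mask_index": 4,
            "model_name": "EdwinXhen/TinyBert_6Layer_MLM",
            "top_k": 5,
            "debug": True,
        },
    )
    print(resp.status_code)
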
routes/tokenize.py CHANGED
@@ -8,7 +8,8 @@ router = APIRouter()
 async def tokenize_text(request: TokenizeRequest):
     """Tokenize input text using the specified model's tokenizer"""
     try:
-        _, tokenizer = get_model_and_tokenizer(request.model_name)
+        debug = request.debug if hasattr(request, 'debug') else False
+        _, tokenizer = get_model_and_tokenizer(request.model_name, debug)

         # The text might include punctuation - let the tokenizer handle it properly
         if "roberta" in request.model_name:
@@ -27,7 +28,7 @@ async def tokenize_text(request: TokenizeRequest):
             # Clean the tokens to remove the leading 'Ġ' character from RoBERTa tokens
             tokens = [clean_roberta_token(token) for token in tokens]
         else:
-            # For BERT and DistilBERT, add special tokens and tokenize
+            # For BERT, DistilBERT, and TinyBERT, add special tokens and tokenize
             text = f"[CLS] {request.text} [SEP]"
             tokens = tokenizer.tokenize(text)

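Not part of the commit: an illustrative sketch of the non-RoBERTa branch above, run outside the route with bert-base-uncased (the TinyBERT repo is expected to behave the same since it uses a BERT-style tokenizer, though that is an assumption):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    text = "[CLS] Hello world [SEP]"   # mirrors the wrapping done in the route
    print(tokenizer.tokenize(text))
    # expected: ['[CLS]', 'hello', 'world', '[SEP]'] -- special tokens kept whole, text lowercased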