dejanseo committed
Commit
9a43936
1 Parent(s): f81fa54

Update handler.py

Files changed (1)
handler.py +39 -27
handler.py CHANGED
@@ -7,47 +7,59 @@ class EndpointHandler:
         # Load the configuration from the saved model
         self.config = AutoConfig.from_pretrained(path)
 
-        # Make sure to specify the correct model name for bert-large-cased
-        # Adjust num_labels according to your model's configuration
         self.model = BertForTokenClassification.from_pretrained(
             path,
             config=self.config
         )
         self.model.eval()  # Set model to evaluation mode
 
-        # Load the tokenizer for bert-large-cased
         self.tokenizer = BertTokenizer.from_pretrained("bert-large-cased")
 
+    def split_into_chunks(self, text: str, max_length: int = 510) -> List[str]:
+        """
+        Splits the input text into manageable chunks for the tokenizer.
+        """
+        tokens = self.tokenizer.tokenize(text)
+        chunk_texts = []
+        for i in range(0, len(tokens), max_length):
+            chunk = tokens[i:i+max_length]
+            chunk_texts.append(self.tokenizer.convert_tokens_to_string(chunk))
+        return chunk_texts
+
     def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
-        # Extract input text from the request
         inputs = data.get("inputs", "")
 
-        # Tokenize the inputs
-        inputs_tensor = self.tokenizer(inputs, return_tensors="pt", add_special_tokens=True)
-        input_ids = inputs_tensor["input_ids"]
+        # Split input text into chunks
+        chunks = self.split_into_chunks(inputs)
+
+        all_results = []  # List to store results from each chunk
+
+        for chunk in chunks:
+            inputs_tensor = self.tokenizer(chunk, return_tensors="pt", add_special_tokens=True)
+            input_ids = inputs_tensor["input_ids"]
 
-        # Run the model
-        with torch.no_grad():
-            outputs = self.model(input_ids)
-        predictions = torch.argmax(outputs.logits, dim=-1)
+            with torch.no_grad():
+                outputs = self.model(input_ids)
+            predictions = torch.argmax(outputs.logits, dim=-1)
 
-        # Process the predictions to generate readable output
-        tokens = self.tokenizer.convert_ids_to_tokens(input_ids[0])[1:-1]  # Exclude CLS and SEP tokens
-        predictions = predictions[0][1:-1].tolist()
+            tokens = self.tokenizer.convert_ids_to_tokens(input_ids[0])[1:-1]  # Exclude CLS and SEP tokens
+            predictions = predictions[0][1:-1].tolist()
 
-        # Reconstruct the text with annotations for token classification
-        result = []
-        for token, pred in zip(tokens, predictions):
-            if pred == 1:  # Adjust this based on your classification needs
-                result.append(f"<u>{token}</u>")
-            else:
-                result.append(token)
+            # Improved reconstruction to handle "##" artifacts
+            reconstructed_text = ""
+            for token, pred in zip(tokens, predictions):
+                if not token.startswith("##"):
+                    reconstructed_text += " " + token if reconstructed_text else token
+                else:
+                    reconstructed_text += token[2:]  # Remove "##" and append
+
+                if pred == 1:  # Example condition, adjust as needed
+                    reconstructed_text = reconstructed_text.strip() + "<u>" + token + "</u>"
 
-        reconstructed_text = " ".join(result).replace(" ##", "")
+            all_results.append(reconstructed_text.strip())
+
+        # Join the results from each chunk
+        final_text = " ".join(all_results)
 
         # Return the processed text in a structured format
-        return [{"text": reconstructed_text}]
-
-        # Note: Ensure the path "dejanseo/LinkBERT" is correctly pointing to your model's location
-        # If the model is locally saved, adjust the path accordingly
-
+        return [{"text": final_text}]
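For reference, a minimal local smoke test of the updated handler might look like the sketch below. The hunk does not show handler.py's imports or the __init__ signature, so both are assumed from the usual Hugging Face custom-handler layout, and the "dejanseo/LinkBERT" path is taken from the note this commit removes; adjust both to your setup. Note that with max_length=510, each chunk plus the [CLS] and [SEP] tokens added by the tokenizer stays within BERT's 512-token limit.

# Sketch of a local smoke test (not part of the commit).
# Assumes handler.py carries the usual imports:
#   from typing import Any, Dict, List
#   import torch
#   from transformers import AutoConfig, BertForTokenClassification, BertTokenizer
# and that EndpointHandler.__init__ accepts a model path, per the standard
# Hugging Face inference-endpoint handler convention.

from handler import EndpointHandler

# "dejanseo/LinkBERT" comes from the note removed in this commit; adjust as needed.
handler = EndpointHandler(path="dejanseo/LinkBERT")

# A long input now exercises split_into_chunks(): the text is re-tokenized in
# windows of 510 WordPiece tokens, each window is run through the model
# separately, and the per-chunk reconstructions are joined at the end.
long_text = "Your article text ... " * 200
result = handler({"inputs": long_text})
print(result[0]["text"])  # reconstructed text with predicted tokens wrapped in <u>...</u>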