amy011872 committed on
Commit
7a1ba7d
1 Parent(s): 64d92c0

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +7 -3
handler.py CHANGED
@@ -101,21 +101,25 @@ class EndpointHandler():
101
  if not inputs.endswith("<cite>"):
102
  inputs += "<cite>"
103
  logger.info(inputs)
104
-
105
  inputs = self.tokenizer(inputs, return_tensors="pt").to("cuda")
106
  with torch.no_grad():
107
  outputs = self.model(**inputs)
108
  outputs_logits = outputs.logits[0, -1, self.law_token_ids]
109
- base_logits = outputs.logits[0, -1, self.law_token_ids]
110
 
 
 
 
 
 
111
  raw_mean = outputs_logits.mean()
112
  outputs_logits = outputs_logits - base_lambda * base_logits
113
  outputs_logits = outputs_logits + (raw_mean - outputs_logits.mean())
114
 
115
  law_token_probs = outputs_logits.softmax(dim=0)
116
  sorted_ids = torch.argsort(law_token_probs, descending=True)[:topk]
117
- print([self.law_token_names[x] for x in sorted_ids])
118
  token_objects = [
119
  self.law_lookup.get_law_from_token(self.law_token_names[x])
120
  for x in sorted_ids.tolist()]
 
121
  return {"tokens": token_objects}
 
101
  if not inputs.endswith("<cite>"):
102
  inputs += "<cite>"
103
  logger.info(inputs)
 
104
  inputs = self.tokenizer(inputs, return_tensors="pt").to("cuda")
105
  with torch.no_grad():
106
  outputs = self.model(**inputs)
107
  outputs_logits = outputs.logits[0, -1, self.law_token_ids]
 
108
 
109
+ base_input = self.tokenizer("<cite>", return_tensors="pt").to("cuda")
110
+ with torch.no_grad():
111
+ base_output = self.model(**base_input)
112
+
113
+ base_logits = base_output.logits[0, -1, self.law_token_ids]
114
  raw_mean = outputs_logits.mean()
115
  outputs_logits = outputs_logits - base_lambda * base_logits
116
  outputs_logits = outputs_logits + (raw_mean - outputs_logits.mean())
117
 
118
  law_token_probs = outputs_logits.softmax(dim=0)
119
  sorted_ids = torch.argsort(law_token_probs, descending=True)[:topk]
120
+ logger.info([self.law_token_names[x] for x in sorted_ids])
121
  token_objects = [
122
  self.law_lookup.get_law_from_token(self.law_token_names[x])
123
  for x in sorted_ids.tolist()]
124
+
125
  return {"tokens": token_objects}