DavidFM43 commited on
Commit
678fb1e
1 Parent(s): 0e95eb5

Change gen_entities function

Browse files
Files changed (1) hide show
  1. app.py +10 -4
app.py CHANGED
@@ -18,14 +18,20 @@ model = PeftModel.from_pretrained(model, peft_model_id)
18
 
19
  model.eval()
20
 
21
- def gen_entities(text):
22
- text = f"<SP> text: {text}\n\n entities: "
 
 
23
  batch = tokenizer(text, return_tensors="pt")
 
24
  with torch.cuda.amp.autocast():
25
  output_tokens = model.generate(**batch, max_new_tokens=256, eos_token_id=50258)
26
 
27
- # return tokenizer.decode(output_tokens, skip_special_tokens=False)
28
- return tokenizer.batch_decode(output_tokens.detach().cpu().numpy(), skip_special_tokens=True)
 
 
 
29
 
30
 
31
  iface = gr.Interface(fn=gen_entities, inputs="text", outputs="text")
 
18
 
19
  model.eval()
20
 
21
+
22
+ def gen_entities(in_text):
23
+ """Does Named Entity Recognition in the given text."""
24
+ text = f"<SP> text: {in_text}\n\n entities:"
25
  batch = tokenizer(text, return_tensors="pt")
26
+ batch["input_ids"] = batch["input_ids"].to("cuda")
27
  with torch.cuda.amp.autocast():
28
  output_tokens = model.generate(**batch, max_new_tokens=256, eos_token_id=50258)
29
 
30
+ response = tokenizer.batch_decode(
31
+ output_tokens.detach().cpu().numpy(), skip_special_tokens=False
32
+ )[0]
33
+
34
+ return response[response.find("entities") : response.find("<EP>")]
35
 
36
 
37
  iface = gr.Interface(fn=gen_entities, inputs="text", outputs="text")