DavidFM43 commited on
Commit
fe4bf72
1 Parent(s): e8e65ad

Add gen_entities function

Browse files
Files changed (1) hide show
  1. app.py +8 -4
app.py CHANGED
@@ -11,16 +11,20 @@ model = AutoModelForCausalLM.from_pretrained(
11
  load_in_8bit=True,
12
  device_map="auto",
13
  revision="half",
14
- # low_cpu_mem_usage=True
15
  )
16
  tokenizer = AutoTokenizer.from_pretrained(peft_model_id)
17
  # Load the Lora model
18
  model = PeftModel.from_pretrained(model, peft_model_id)
19
 
20
 
21
- def greet(name):
22
- return "Hello " + name + "!!"
 
 
 
23
 
 
24
 
25
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
 
26
  iface.launch()
 
11
  load_in_8bit=True,
12
  device_map="auto",
13
  revision="half",
 
14
  )
15
  tokenizer = AutoTokenizer.from_pretrained(peft_model_id)
16
  # Load the Lora model
17
  model = PeftModel.from_pretrained(model, peft_model_id)
18
 
19
 
20
+ def gen_entities(text):
21
+ text = f"<SP> text: {text}\n\n entities: "
22
+ batch = tokenizer(text, return_tensors="pt")
23
+ with torch.cuda.amp.autocast():
24
+ output_tokens = model.generate(**batch, max_new_tokens=256, eos_token_id=50258)
25
 
26
+ return tokenizer.decode(output_tokens[0], skip_special_tokens=False)
27
 
28
+
29
+ iface = gr.Interface(fn=gen_entities, inputs="text", outputs="text")
30
  iface.launch()