Spaces:
Sleeping
Sleeping
Update main.py
Browse files
main.py
CHANGED
|
@@ -267,41 +267,18 @@ re_model = BertRE(num_labels=len(rel2id))
|
|
| 267 |
re_model.load_state_dict(torch.load(weights_path, map_location="cpu"))
|
| 268 |
re_model.eval()
|
| 269 |
|
| 270 |
-
'''
|
| 271 |
-
def entities_and_types(sentence):
|
| 272 |
-
ner_output = extract(sentence)
|
| 273 |
-
entities = distill_entities(ner_output)
|
| 274 |
-
|
| 275 |
-
entity_dict = {}
|
| 276 |
-
for name, entity_type, _, _ in entities:
|
| 277 |
-
entity_dict[name] = entity_type
|
| 278 |
|
| 279 |
-
return entity_dict
|
| 280 |
-
'''
|
| 281 |
def convert_ner_format(ner_output):
|
| 282 |
return [[item["token"], item["tags"]] for item in ner_output]
|
| 283 |
|
| 284 |
def entities_and_types(sentence):
|
| 285 |
-
print("\n=== NER DEBUG ===")
|
| 286 |
-
print("INPUT:", sentence)
|
| 287 |
-
|
| 288 |
ner_output = extract(sentence)
|
| 289 |
-
print("RAW NER OUTPUT:", ner_output)
|
| 290 |
-
|
| 291 |
-
# ✅ FIX HERE
|
| 292 |
converted = convert_ner_format(ner_output)
|
| 293 |
-
|
| 294 |
-
print("CONVERTED FORMAT:", converted)
|
| 295 |
-
|
| 296 |
entities = distill_entities(converted)
|
| 297 |
-
|
| 298 |
-
print("DISTILLED ENTITIES:", entities)
|
| 299 |
-
|
| 300 |
entity_dict = {}
|
| 301 |
for name, entity_type, _, _ in entities:
|
| 302 |
entity_dict[name] = entity_type
|
| 303 |
|
| 304 |
-
print("ENTITY DICT:", entity_dict)
|
| 305 |
return entity_dict
|
| 306 |
|
| 307 |
relation_domain_range=[
|
|
@@ -514,8 +491,6 @@ for rel in relation_domain_range:
|
|
| 514 |
for r in rel["range"]:
|
| 515 |
relation_lookup[d][r].append(rel["relation"])
|
| 516 |
|
| 517 |
-
|
| 518 |
-
'''
|
| 519 |
def insert_markers(sentence, ent1, ent2):
|
| 520 |
if ent1 not in sentence or ent2 not in sentence:
|
| 521 |
return None
|
|
@@ -525,23 +500,7 @@ def insert_markers(sentence, ent1, ent2):
|
|
| 525 |
marked = marked.replace(ent2, f"[Obj] {ent2} [/Obj]", 1)
|
| 526 |
|
| 527 |
return marked
|
| 528 |
-
'''
|
| 529 |
-
def insert_markers(sentence, ent1, ent2):
|
| 530 |
-
print("\n--- MARKER DEBUG ---")
|
| 531 |
-
print("Original sentence:", sentence)
|
| 532 |
-
print("Entity1:", ent1)
|
| 533 |
-
print("Entity2:", ent2)
|
| 534 |
-
|
| 535 |
-
if ent1 not in sentence or ent2 not in sentence:
|
| 536 |
-
print("❌ Entity not found in sentence")
|
| 537 |
-
return None
|
| 538 |
|
| 539 |
-
marked = sentence
|
| 540 |
-
marked = marked.replace(ent1, f"[Sub] {ent1} [/Sub]", 1)
|
| 541 |
-
marked = marked.replace(ent2, f"[Obj] {ent2} [/Obj]", 1)
|
| 542 |
-
|
| 543 |
-
print("Marked sentence:", marked)
|
| 544 |
-
return marked
|
| 545 |
def encode(sentence):
|
| 546 |
enc = relation_tokenizer(
|
| 547 |
sentence,
|
|
@@ -603,7 +562,18 @@ def relation_extractor(sentence):
|
|
| 603 |
continue
|
| 604 |
|
| 605 |
if conf > 0.80 and rel != "no_relation" and rel.split(".")[-1] in valid_rels:
|
| 606 |
-
output.append(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 607 |
|
| 608 |
return output
|
| 609 |
|
|
|
|
| 267 |
re_model.load_state_dict(torch.load(weights_path, map_location="cpu"))
|
| 268 |
re_model.eval()
|
| 269 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 270 |
|
|
|
|
|
|
|
| 271 |
def convert_ner_format(ner_output):
    """Re-shape raw NER output into ``[token, tags]`` pairs.

    Each element of *ner_output* is a mapping carrying ``"token"`` and
    ``"tags"`` keys; the result is a list of two-item lists in the same
    order, the format expected downstream by ``distill_entities``.
    """
    pairs = []
    for entry in ner_output:
        pairs.append([entry["token"], entry["tags"]])
    return pairs
|
| 273 |
|
| 274 |
def entities_and_types(sentence):
    """Run NER on *sentence* and map each entity name to its type.

    Pipeline: ``extract`` produces raw NER output, which is converted to
    the ``[token, tags]`` pair format and distilled into entity spans.
    Returns a ``dict`` of ``{entity_name: entity_type}``; if the same
    name is distilled more than once, the last occurrence wins.
    """
    raw = extract(sentence)
    tagged = convert_ner_format(raw)
    spans = distill_entities(tagged)
    # Each span is (name, type, start, end); positions are not needed here.
    return {name: entity_type for name, entity_type, _, _ in spans}
|
| 283 |
|
| 284 |
relation_domain_range=[
|
|
|
|
| 491 |
for r in rel["range"]:
|
| 492 |
relation_lookup[d][r].append(rel["relation"])
|
| 493 |
|
|
|
|
|
|
|
| 494 |
def insert_markers(sentence, ent1, ent2):
|
| 495 |
if ent1 not in sentence or ent2 not in sentence:
|
| 496 |
return None
|
|
|
|
| 500 |
marked = marked.replace(ent2, f"[Obj] {ent2} [/Obj]", 1)
|
| 501 |
|
| 502 |
return marked
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 503 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 504 |
def encode(sentence):
|
| 505 |
enc = relation_tokenizer(
|
| 506 |
sentence,
|
|
|
|
| 562 |
continue
|
| 563 |
|
| 564 |
if conf > 0.80 and rel != "no_relation" and rel.split(".")[-1] in valid_rels:
|
| 565 |
+
output.append({
|
| 566 |
+
"Subject": {
|
| 567 |
+
"Type": type1,
|
| 568 |
+
"Label": ent1
|
| 569 |
+
},
|
| 570 |
+
"Relation": rel,
|
| 571 |
+
"Object": {
|
| 572 |
+
"Type": type2,
|
| 573 |
+
"Label": ent2
|
| 574 |
+
},
|
| 575 |
+
"Confidence": float(round(conf, 4))
|
| 576 |
+
})
|
| 577 |
|
| 578 |
return output
|
| 579 |
|