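"""Gradio demo that runs a DeBERTa-v2 token-classification model over raw text
and highlights each word with its predicted segment label (title, author,
table, footnote, ...)."""
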
import gradio as gr
import torch
from transformers import DebertaV2Tokenizer, DebertaV2ForTokenClassification

from globe import title, description, joinus, model_name, placeholder, modelinfor1, modelinfor2, id2label

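# Load the tokenizer and the fine-tuned token-classification model;
# model_name comes from globe.py.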
tokenizer = DebertaV2Tokenizer.from_pretrained(model_name)
model = DebertaV2ForTokenClassification.from_pretrained(model_name)

# id2label (imported from globe) maps class indices to segment labels,
# as defined in the model's config.json:
#   0: author, 1: bibliography, 2: caption, 3: contact, 4: date, 5: dialog,
#   6: footnote, 7: keywords, 8: math, 9: paratext, 10: separator, 11: table,
#   12: text, 13: title

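# Display colors for each segment label in the highlighted output.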
color_map = {
    "author": "blue", "bibliography": "purple", "caption": "orange", 
    "contact": "cyan", "date": "green", "dialog": "yellow", 
    "footnote": "pink", "keywords": "lightblue", "math": "red", 
    "paratext": "lightgreen", "separator": "gray", "table": "brown", 
    "text": "lightgray", "title": "gold"
}


def segment_text(input_text):
    """Classify every token of the input and merge subword pieces back into
    (word, label) pairs that gr.HighlightedText can render."""
    tokens = tokenizer(input_text, return_tensors="pt", truncation=True, padding=True)

    with torch.no_grad():
        outputs = model(**tokens)

    predictions = torch.argmax(outputs.logits, dim=-1).squeeze().tolist()
    tokens_decoded = tokenizer.convert_ids_to_tokens(tokens["input_ids"].squeeze())

    segments = []
    current_word = ""
    current_label = None
    for token, label_id in zip(tokens_decoded, predictions):
        if token in tokenizer.all_special_tokens:  # skip [CLS], [SEP], [PAD]
            continue
        if token.startswith("▁"):  # SentencePiece marks the start of a new word
            if current_word:
                segments.append((current_word, current_label))
            current_word = token.replace("▁", "")
            current_label = id2label[label_id]  # label the word by its first subword
        else:
            current_word += token  # glue subword continuations onto the current word

    if current_word:
        segments.append((current_word, current_label))

    return segments

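# UI layout: title on top, description and model info side by side,
# then the input textbox, highlighted output, and a submit button.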
with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown(title)
    with gr.Row():
        with gr.Column(scale=1):
            with gr.Group():                    
                gr.Markdown(description)
            with gr.Accordion(label="Join Us", open=False):        
                gr.Markdown(joinus)
        with gr.Column(scale=1):
            with gr.Row():
                with gr.Group():
                    gr.Markdown(modelinfor1)
                with gr.Group():
                    gr.Markdown(modelinfor2)
    
    with gr.Row():
        input_text = gr.Textbox(label="Enter your text here👇🏻", lines=5, placeholder=placeholder)
        output_text = gr.HighlightedText(
            label="PLeIAs/✂️📜 Segment Text",
            color_map=color_map,
            combine_adjacent=True,
            show_inline_category=True,
            show_legend=True,
        )

    submit_button = gr.Button("Segment Text")
    submit_button.click(fn=segment_text, inputs=input_text, outputs=output_text)

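# Launch the app when this file is run directly (e.g. locally or on Spaces).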
if __name__ == "__main__":
    demo.launch()