|
import gradio as gr |
|
import torch |
|
from transformers import RobertaTokenizerFast, RobertaForTokenClassification, AutoTokenizer |
|
from huggingface_hub import hf_hub_download |
|
import json |
|
import os |
|
from datetime import datetime |
|
import shutil |
|
|
|
# Append-only JSON Lines log, plus the copy that is handed to the download widget.
log_path = "files/logs.jsonl"
temp_download_path = "files/download_logs.jsonl"
# Derive the directory from log_path so the two cannot drift apart.
os.makedirs(os.path.dirname(log_path), exist_ok=True)

# Hub repository and sub-directory holding the fine-tuned token classifier.
hf_model_repo = "amirghasemiveisi/surrey-nlp-pg04-best-model"
subfolder = "ner_roberta_lion_model"

model = RobertaForTokenClassification.from_pretrained(hf_model_repo, subfolder=subfolder)

# add_prefix_space=True is required because the tokenizer is later called on
# pre-split words (is_split_into_words=True).
# NOTE(review): the tokenizer is loaded from the stock "roberta-base" checkpoint,
# not from hf_model_repo; presumably the fine-tuned model kept the base
# vocabulary - confirm.
tokenizer = AutoTokenizer.from_pretrained(
    "roberta-base",
    use_fast=True,
    add_prefix_space=True,
)

# Label map shipped next to the model. JSON object keys are strings, which is
# why the rest of the file looks labels up with id2label[str(label_id)].
id2label_path = hf_hub_download(repo_id=hf_model_repo, filename="id2label.json", subfolder=subfolder)
with open(id2label_path, encoding="utf-8") as f:
    id2label = json.load(f)

# Test sentences and their labels, read from the working directory.
# JSON is decoded as UTF-8 explicitly instead of relying on the platform locale.
with open("test_samples.json", encoding="utf-8") as f:
    test_texts = json.load(f)

with open("test_labels.json", encoding="utf-8") as f:
    test_labels = json.load(f)

# Dropdown entries of the form "<index>: <sentence>"; the index prefix is
# parsed back out when a sentence is selected.
options = [f"{i}: {' '.join(tokens)}" for i, tokens in enumerate(test_texts)]
|
|
|
def log_results(sentence_idx, sentence, words, pred_ids, true_token_labels, word_ids, feedback=None):
    """Append one prediction record to the JSONL log and refresh the download copy.

    Args:
        sentence_idx: Index of the sentence within the test set.
        sentence: The sentence as a single string.
        words: The sentence as a list of words.
        pred_ids: Predicted label id for each token position.
        true_token_labels: Reference label ids indexed by token position;
            -100 marks positions that are skipped.
        word_ids: For each token position, the index of the word it belongs
            to, or None for tokens that belong to no word.
        feedback: Optional feedback value stored under "feedback" when truthy.

    The record's "prediction" list gets one entry per token (not per word):
    the predicted label is the token's own prediction, while the reference
    label is taken from the first token of the word the token belongs to.
    Tokens whose label id is missing from ``id2label`` are skipped.
    """
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    log = {
        "timestamp": timestamp,
        "sentence_idx": sentence_idx,
        "sentence": sentence,
        "input": words,
        "prediction": [],
    }

    # Position of the first token of every word, built in a single pass
    # instead of rescanning word_ids for each token.
    first_token_of_word = {}
    for position, word_id in enumerate(word_ids):
        if word_id is not None:
            first_token_of_word.setdefault(word_id, position)

    for position, word_id in enumerate(word_ids):
        if word_id is None or word_id >= len(words):
            continue

        label_position = first_token_of_word[word_id]
        if label_position >= len(true_token_labels):
            # No reference label available for this word.
            continue

        true_id = true_token_labels[label_position]
        if true_id == -100:
            continue

        try:
            pred_label = id2label[str(pred_ids[position])] if position < len(pred_ids) else "N/A"
            true_label = id2label[str(true_id)]
        except KeyError:
            # Unknown label id: leave this token out rather than abort the log.
            continue

        log["prediction"].append({
            "token": words[word_id],
            "predicted": pred_label,
            "true": true_label,
        })

    if feedback:
        log["feedback"] = feedback

    with open(log_path, "a") as f:
        f.write(json.dumps(log) + "\n")

    # Keep the downloadable copy in sync with the log.
    shutil.copy(log_path, temp_download_path)
|
|
|
def ner_predict_from_selection(selected_display_text):
    """Run the NER model on the selected test sentence and compare with the reference.

    Args:
        selected_display_text: Dropdown entry of the form "<index>: <sentence>".

    Returns:
        A 3-tuple for the UI outputs:
        * the sentence text (or a prompt / error message),
        * a list of ``(word, label)`` pairs, where the label is the predicted
          tag when it matches the reference and
          ``"WRONG: <predicted> → <reference>"`` otherwise,
        * the path of the downloadable log, or ``gr.update(value=None)`` when
          nothing was predicted.

    Any exception raised while predicting is printed with its traceback and
    reported through the first output instead of propagating.
    """
    if not selected_display_text:
        return "Please select a test sentence", [], gr.update(value=None)

    try:
        index = int(selected_display_text.split(":")[0])
        words = test_texts[index]
        true_token_labels = test_labels[index]

        inputs = tokenizer(words, return_tensors="pt", is_split_into_words=True, truncation=True)
        word_ids = inputs.word_ids()
        with torch.no_grad():
            logits = model(**inputs).logits

        # Select the single batch row explicitly so the result is always a
        # list, even for a one-token sequence.
        pred_ids = torch.argmax(logits, dim=-1)[0].tolist()

        seen = set()
        highlighted_tokens = []

        for i, word_id in enumerate(word_ids):
            if word_id is None or word_id in seen:
                continue
            seen.add(word_id)

            # Because repeated word ids are filtered above, i is the first
            # token of this word, i.e. the position whose reference label
            # represents the word.
            if i >= len(true_token_labels):
                continue
            true_id = true_token_labels[i]
            if true_id == -100:
                continue

            word = words[word_id]
            true_label = id2label[str(true_id)]
            pred_label = id2label[str(pred_ids[i])]

            if pred_label == true_label:
                highlighted_tokens.append((word, pred_label))
            else:
                highlighted_tokens.append((word, f"WRONG: {pred_label} → {true_label}"))

        sentence = " ".join(words)
        log_results(index, sentence, words, pred_ids, true_token_labels, word_ids)
        return sentence, highlighted_tokens, temp_download_path

    except Exception as e:
        # UI boundary: report the failure to the user instead of crashing the app.
        import traceback

        error_details = traceback.format_exc()
        print(f"Error in prediction: {error_details}")
        return f"Error processing selection: {str(e)}", [], gr.update(value=None)
|
|
|
def log_feedback(feedback_type, sentence_text):
    """Append a feedback record to the JSONL log and refresh the download copy.

    Args:
        feedback_type: Feedback value stored under "feedback".
        sentence_text: Sentence the feedback refers to, stored under "sentence".

    Returns:
        None.
    """
    record = {
        "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        "feedback": feedback_type,
        "sentence": sentence_text,
    }
    with open(log_path, "a") as log_file:
        log_file.write(json.dumps(record) + "\n")
    # Mirror the log to the path exposed for download.
    shutil.copy(log_path, temp_download_path)
|
|
|
# UI layout: sentence picker, prediction display, log download and feedback buttons.
with gr.Blocks(title="RoBERTa NER with LION Optimizer - PG04 Group") as demo:
    gr.Markdown("<h1 style='text-align: center; color: #3b3b3b;'>RoBERTa NER with LION Optimizer</h1>")
    gr.Markdown("<p style='text-align: center;'>Select a test sentence and see token-level NER predictions. Feedback is logged and downloadable.</p>")

    dropdown = gr.Dropdown(choices=options, label="Select Test Sentence")
    submit_btn = gr.Button("Submit", variant="primary", elem_classes="orange-button")

    sentence_output = gr.Text(label="Original Sentence")
    prediction_output = gr.HighlightedText(label="Predicted Tags (WRONG if misclassified)")
    download_button = gr.File(label="Download Logs", interactive=True)

    with gr.Row():
        like_btn = gr.Button("👍 Like", elem_classes="green-button")
        dislike_btn = gr.Button("👎 Dislike", elem_classes="red-button")

    submit_btn.click(ner_predict_from_selection, inputs=dropdown, outputs=[sentence_output, prediction_output, download_button])

    # The hidden textboxes only exist to pass a constant feedback type
    # ("like" / "dislike") as the first argument of log_feedback.
    like_btn.click(fn=log_feedback, inputs=[gr.Textbox(visible=False, value="like"), sentence_output], outputs=[])
    dislike_btn.click(fn=log_feedback, inputs=[gr.Textbox(visible=False, value="dislike"), sentence_output], outputs=[])

# Button colours, matched through the elem_classes given above.
# NOTE(review): the stylesheet is attached by assigning demo.css after the
# Blocks object is built; whether that is honoured depends on the installed
# Gradio version - confirm, or pass it through the documented css argument.
demo.css = """
.orange-button {
    background-color: #FFA500 !important;
    color: white !important;
    font-weight: bold;
}
.green-button {
    background-color: #4CAF50 !important;
    color: white !important;
    font-weight: bold;
}
.red-button {
    background-color: #f44336 !important;
    color: white !important;
    font-weight: bold;
}
"""

# Start the local server only when run as a script; no public share link.
if __name__ == "__main__":
    demo.launch(share=False)