# amirghasemiveisi's picture
# Upload app.py
# b64bc74 verified
import gradio as gr
import torch
from transformers import RobertaTokenizerFast, RobertaForTokenClassification, AutoTokenizer
from huggingface_hub import hf_hub_download
import json
import os
from datetime import datetime
import shutil
# Log files: `log_path` is the append-only JSONL log; `temp_download_path` is
# the copy that gets handed to the UI's file component.
log_path = "files/logs.jsonl"
temp_download_path = "files/download_logs.jsonl"
os.makedirs("files", exist_ok=True)

# The fine-tuned token-classification weights live in a subfolder of the Hub repo.
hf_model_repo = "amirghasemiveisi/surrey-nlp-pg04-best-model"
subfolder = "ner_roberta_lion_model"
model = RobertaForTokenClassification.from_pretrained(hf_model_repo, subfolder=subfolder)

# The tokenizer comes from the base checkpoint. add_prefix_space=True is
# needed because the app tokenizes pre-split words (is_split_into_words=True).
tokenizer = AutoTokenizer.from_pretrained(
    "roberta-base",
    use_fast=True,
    add_prefix_space=True,
)

# Label-id -> tag-name mapping; loaded from JSON, so it is looked up with
# stringified ids elsewhere in this file.
id2label_path = hf_hub_download(repo_id=hf_model_repo, filename="id2label.json", subfolder=subfolder)
# JSON is UTF-8 by specification; do not depend on the platform default encoding.
with open(id2label_path, encoding="utf-8") as f:
    id2label = json.load(f)

# Test fixtures: `test_texts[i]` is a list of words and `test_labels[i]` the
# matching label ids for sentence i.
with open("test_samples.json", encoding="utf-8") as f:
    test_texts = json.load(f)
with open("test_labels.json", encoding="utf-8") as f:
    test_labels = json.load(f)

# Dropdown entries have the form "<index>: <sentence>"; the index is parsed
# back out when a sentence is selected.
options = [f"{i}: {' '.join(tokens)}" for i, tokens in enumerate(test_texts)]
def log_results(sentence_idx, sentence, words, pred_ids, true_token_labels, word_ids, feedback=None):
    """Append one prediction record to the JSONL log and refresh the download copy.

    Every word contributes at most one entry, taken from its first sub-token,
    which mirrors how predictions are displayed in the UI. Words whose gold
    label is -100, or whose label id is missing from the module-level
    ``id2label`` mapping, are left out of the record.

    Args:
        sentence_idx: Index of the sentence in the test set.
        sentence: The sentence as a single string.
        words: The sentence as a list of words.
        pred_ids: Predicted label id for each token position.
        true_token_labels: Gold label id for each token position; -100 marks
            positions that must be ignored.
        word_ids: Word index for each token position (None for positions that
            do not belong to a word).
        feedback: Optional value stored under the "feedback" key when truthy.
    """
    log = {
        "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        "sentence_idx": sentence_idx,
        "sentence": sentence,
        "input": words,
        "prediction": [],
    }

    logged_words = set()
    for i, word_id in enumerate(word_ids):
        if word_id is None or word_id >= len(words) or word_id in logged_words:
            continue
        # First time this word id shows up, so position i is the word's first
        # sub-token. Later sub-tokens are skipped: logging them would repeat
        # the word and pair their predictions with the first sub-token's label.
        logged_words.add(word_id)
        if i >= len(true_token_labels) or true_token_labels[i] == -100:
            continue
        try:
            pred_label = id2label[str(pred_ids[i])] if i < len(pred_ids) else "N/A"
            true_label = id2label[str(true_token_labels[i])]
        except KeyError:
            # Unknown label id: drop this word instead of losing the whole record.
            continue
        log["prediction"].append({
            "token": words[word_id],
            "predicted": pred_label,
            "true": true_label,
        })

    if feedback:
        log["feedback"] = feedback

    with open(log_path, "a", encoding="utf-8") as f:
        f.write(json.dumps(log) + "\n")
    # Keep the downloadable copy in sync with the log that was just extended.
    shutil.copy(log_path, temp_download_path)
def ner_predict_from_selection(selected_display_text):
    """Run NER on the selected test sentence and compare it with the gold labels.

    Args:
        selected_display_text: Dropdown entry of the form "<index>: <sentence>".

    Returns:
        A 3-tuple matching the Gradio outputs: the sentence text (or a
        message when nothing was selected or an error occurred), a list of
        (word, label) pairs for the highlighted view, and the path of the
        downloadable log (or an update that clears the file component).
    """
    if not selected_display_text:
        return "Please select a test sentence", [], gr.update(value=None)
    try:
        index = int(selected_display_text.split(":")[0])
        words = test_texts[index]
        true_token_labels = test_labels[index]

        inputs = tokenizer(words, return_tensors="pt", is_split_into_words=True, truncation=True)
        word_ids = inputs.word_ids()
        with torch.no_grad():
            logits = model(**inputs).logits
        # Take the single batch row explicitly; unlike squeeze(), this still
        # yields a list when the sequence has only one position.
        pred_ids = torch.argmax(logits, dim=2)[0].tolist()

        seen = set()
        highlighted_tokens = []
        for i, word_id in enumerate(word_ids):
            if word_id is None or word_id in seen:
                continue
            seen.add(word_id)
            word = words[word_id]
            # Position i is the first sub-token of this word (its id was not
            # seen before), so it is also where the word's gold label sits.
            if i >= len(true_token_labels):
                continue
            true_id = true_token_labels[i]
            if true_id == -100:
                continue
            pred_id = pred_ids[i] if i < len(pred_ids) else pred_ids[-1]
            true_label = id2label[str(true_id)]
            pred_label = id2label[str(pred_id)]
            if pred_label == true_label:
                highlighted_tokens.append((word, pred_label))
            else:
                highlighted_tokens.append((word, f"WRONG: {pred_label} ≠ {true_label}"))

        sentence = " ".join(words)
        log_results(index, sentence, words, pred_ids, true_token_labels, word_ids)
        return sentence, highlighted_tokens, temp_download_path
    except Exception as e:
        # UI boundary: report the failure in the output box instead of
        # crashing the handler, and keep the full traceback in the console.
        import traceback
        error_details = traceback.format_exc()
        print(f"Error in prediction: {error_details}")
        return f"Error processing selection: {str(e)}", [], gr.update(value=None)
def log_feedback(feedback_type, sentence_text):
    """Record a like/dislike for a sentence in the JSONL log.

    Appends one JSON line holding the current timestamp, the feedback value
    and the sentence text, then refreshes the downloadable copy of the log.
    """
    with open(log_path, "a") as log_file:
        entry = {
            "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            "feedback": feedback_type,
            "sentence": sentence_text,
        }
        log_file.write(json.dumps(entry) + "\n")
    # Mirror the updated log into the file served for download.
    shutil.copy(log_path, temp_download_path)
    return None
# UI layout: sentence picker -> submit -> sentence text, highlighted
# predictions and the log file, followed by like/dislike feedback buttons.
with gr.Blocks(title="RoBERTa NER with LION Optimizer - PG04 Group") as demo:
    gr.Markdown("<h1 style='text-align: center; color: #3b3b3b;'>RoBERTa NER with LION Optimizer</h1>")
    gr.Markdown("<p style='text-align: center;'>Select a test sentence and see token-level NER predictions. Feedback is logged and downloadable.</p>")
    dropdown = gr.Dropdown(choices=options, label="Select Test Sentence")
    submit_btn = gr.Button("Submit", variant="primary", elem_classes="orange-button")
    sentence_output = gr.Text(label="Original Sentence")
    prediction_output = gr.HighlightedText(label="Predicted Tags (WRONG if misclassified)")
    download_button = gr.File(label="Download Logs", interactive=True)
    with gr.Row():
        # Button labels use the real emoji; the previous text was mis-decoded UTF-8.
        like_btn = gr.Button("👍 Like", elem_classes="green-button")
        dislike_btn = gr.Button("👎 Dislike", elem_classes="red-button")
    submit_btn.click(ner_predict_from_selection, inputs=dropdown, outputs=[sentence_output, prediction_output, download_button])
    # The hidden textboxes only carry the constant feedback value into log_feedback.
    like_btn.click(fn=log_feedback, inputs=[gr.Textbox(visible=False, value="like"), sentence_output], outputs=[])
    dislike_btn.click(fn=log_feedback, inputs=[gr.Textbox(visible=False, value="dislike"), sentence_output], outputs=[])

# Styles for the elem_classes used on the buttons above.
demo.css = """
.orange-button {
background-color: #FFA500 !important;
color: white !important;
font-weight: bold;
}
.green-button {
background-color: #4CAF50 !important;
color: white !important;
font-weight: bold;
}
.red-button {
background-color: #f44336 !important;
color: white !important;
font-weight: bold;
}
"""

if __name__ == "__main__":
    demo.launch(share=False)