"""Gradio demo that classifies the emotion of a sentence with a fine-tuned BERT model.

Every prediction is appended to ``log.csv`` inside a local clone of the
``jeevavijay10/senti-pred-gradio`` Hugging Face dataset repo, which is then
pushed back to the Hub.
"""

import csv
import logging
import os
import threading
from pathlib import Path

import gradio
import huggingface_hub
from gradio import utils

from src.models.bert import BERTClassifier
from src.utils.utilities import Utility

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

model = BERTClassifier(model_name='jeevavijay10/nlp-goemotions-bert')
# Label names, indexed by the class ids returned from ``model.evaluate``.
classes = Utility().read_emotion_list()

hf_token = os.getenv("HF_TOKEN")
dataset_dir = "logs"
headers = ["input", "output"]

# Local clone of the dataset repo that receives the prediction log.
repo = huggingface_hub.Repository(
    local_dir=dataset_dir,
    clone_from="https://huggingface.co/datasets/jeevavijay10/senti-pred-gradio",
    use_auth_token=hf_token,
)

# Gradio executes sync handlers in a threadpool, so several requests can reach
# log_record at once; serialize pull -> append -> push on the shared clone.
_log_lock = threading.Lock()


def log_record(input, output):
    """Append one ``(input, output)`` row to ``log.csv`` and push it to the Hub.

    Pulls the latest dataset state first, writes the header row if the file
    does not exist yet, then pushes a commit whose message contains the number
    of data rows now in the file.

    Args:
        input: The sentence the user submitted.
        output: The predicted label returned to the user.
    """
    with _log_lock:
        repo.git_pull(lfs=True)
        log_file = Path(dataset_dir) / "log.csv"
        is_new = not log_file.exists()
        with open(log_file, "a", newline="", encoding="utf-8") as csvfile:
            writer = csv.writer(csvfile)
            if is_new:
                writer.writerow(utils.sanitize_list_for_csv(headers))
            writer.writerow(utils.sanitize_list_for_csv([input, output]))
        # Count data rows only (minus the header) for the commit message.
        with open(log_file, "r", newline="", encoding="utf-8") as csvfile:
            line_count = sum(1 for _ in csv.reader(csvfile)) - 1
        repo.push_to_hub(commit_message=f"Logged sample #{line_count}")


def predict(sentence):
    """Return the predicted emotion label for ``sentence`` and log the result.

    Logging to the Hub is best-effort: if it fails (network, auth, git error)
    the failure is recorded, but the prediction is still returned to the user.
    """
    logger.info("Input: %s", sentence)
    predictions = model.evaluate([sentence])
    logger.info("Predictions: %s", predictions)
    output = classes[predictions[0]]
    try:
        log_record(sentence, output)
    except Exception:
        # The prediction itself succeeded; don't fail the request over logging.
        logger.exception("Failed to log prediction to the Hub")
    return output


# log_record() already persists every prediction. Gradio's own flagging is
# disabled because allow_flagging='auto' with SimpleCSVLogger writing into
# flagging_dir='logs' (the dataset clone) would record each sample a second time.
demo = gradio.Interface(
    fn=predict,
    inputs="text",
    outputs="text",
    allow_flagging='never',
)

if __name__ == "__main__":
    demo.launch()