Spaces:
Runtime error
Runtime error
File size: 1,833 Bytes
46e6114 a5cbba4 46e6114 a5cbba4 01f65eb 46e6114 01f65eb ce6c474 01f65eb 46e6114 a5cbba4 46e6114 5ee5c89 46e6114 4e52c2b 46e6114 a5cbba4 70c0af2 4e52c2b 46e6114 65376e9 a5cbba4 46e6114 4e52c2b 46e6114 760ecf9 46e6114 760ecf9 46e6114 a5cbba4 760ecf9 46e6114 a5cbba4 01f65eb 46e6114 a5cbba4 46e6114 01f65eb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 |
import os
import time
import csv
import datetime
import gradio
from gradio import utils
import huggingface_hub
from pathlib import Path
from src.models.bert import BERTClassifier
from src.utils.utilities import Utility
# --- Classifier ------------------------------------------------------------
# Fine-tuned BERT model and the list of emotion labels used to turn the
# model's predicted index into text.
# NOTE(review): presumably classes[i] names model class i -- confirm against
# Utility.read_emotion_list.
model = BERTClassifier(model_name='jeevavijay10/nlp-goemotions-bert')
classes = Utility().read_emotion_list()

# --- Prediction log storage ----------------------------------------------
# Records are appended to a CSV inside a local clone of a Hub dataset repo
# (log_record pushes the clone after each record).
hf_token = os.environ.get("HF_TOKEN")  # None when the secret is not set
dataset_dir = "logs"  # local clone directory of the dataset repo
headers = ["input", "output", "timestamp", "elapsed"]  # CSV header row

# NOTE(review): huggingface_hub.Repository is the legacy git-based API and is
# deprecated upstream in favor of HfApi / CommitScheduler; confirm the pinned
# huggingface_hub version still provides it.
repo = huggingface_hub.Repository(
    clone_from="https://huggingface.co/datasets/jeevavijay10/senti-pred-gradio",
    local_dir=dataset_dir,
    token=hf_token,
)
# Bring the local clone up to date with the remote before logging starts
# (lfs=False: LFS objects are not fetched).
repo.git_pull(lfs=False)
def log_record(vals):
    """Append one prediction record to the dataset CSV and push it to the Hub.

    The CSV lives at ``<dataset_dir>/data.csv`` inside the cloned dataset
    repository. A header row (``headers``) is written first whenever the file
    is missing or empty. Both the header and the record are passed through
    Gradio's ``utils.sanitize_list_for_csv`` before being written. After the
    append, the number of data rows is counted and used in the commit message
    of ``repo.push_to_hub``.

    Args:
        vals: Row values in the same order as ``headers``
            (input, output, timestamp, elapsed).
    """
    log_file = Path(dataset_dir) / "data.csv"
    # Treat a zero-byte file like a missing one so it still gets a header row
    # (e.g. if an earlier run created the file but never wrote to it).
    is_new = not log_file.exists() or log_file.stat().st_size == 0
    with log_file.open("a", newline="", encoding="utf-8") as csvfile:
        writer = csv.writer(csvfile)
        if is_new:
            writer.writerow(utils.sanitize_list_for_csv(headers))
        writer.writerow(utils.sanitize_list_for_csv(vals))
    # Count records with the csv reader (not raw lines) so quoted fields that
    # contain newlines count once; newline="" is required for that per the
    # csv module docs. Subtract 1 for the header row.
    with log_file.open("r", newline="", encoding="utf-8") as csvfile:
        line_count = sum(1 for _ in csv.reader(csvfile)) - 1
    repo.push_to_hub(commit_message=f"Logged sample #{line_count}")
def predict(sentence):
    """Classify *sentence* into an emotion label and log the result.

    The model is called on a one-element batch; the first prediction is used
    as an index into ``classes``. The input, label, timestamp and inference
    time are then recorded via ``log_record``.

    Args:
        sentence: Input text from the Gradio textbox.

    Returns:
        The label ``classes[predictions[0]]``.
    """
    print(sentence)
    # NOTE(review): naive local time -- no UTC offset is recorded in the log.
    timestamp = datetime.datetime.now().isoformat()
    # perf_counter is monotonic and high-resolution; time.time() follows the
    # wall clock, which can jump (e.g. NTP adjustments) and skew durations.
    start_time = time.perf_counter()
    predictions = model.evaluate([sentence])
    elapsed_time = time.perf_counter() - start_time
    print(f"Predictions: {predictions}")
    output = classes[predictions[0]]
    # Logging is a side channel: a failed CSV write or Hub push (network
    # error, missing/invalid token, ...) should not hide a successful
    # prediction from the user. Report the failure and still return the label.
    try:
        log_record([sentence, output, timestamp, str(elapsed_time)])
    except Exception as err:  # noqa: BLE001 - best-effort logging boundary
        print(f"Failed to log record: {err!r}")
    return output
# Serve the classifier as a text-in / text-out web UI.
# NOTE(review): with allow_flagging='auto', Gradio flags every submission
# through SimpleCSVLogger. That appears to duplicate what predict() already
# records via log_record(). Because the flagging directory is the cloned
# dataset repo, that CSV may also be included by repo.push_to_hub().
# Confirm both are intended.
demo = gradio.Interface(
    fn=predict,
    inputs="text",
    outputs="text",
    allow_flagging='auto',
    # Reuse dataset_dir instead of repeating the "logs" literal so the two
    # locations cannot drift apart.
    flagging_dir=dataset_dir,
    flagging_callback=gradio.SimpleCSVLogger(),
)
demo.launch()
|