# Scraped file-viewer header: author reichenbach; commit 0464579 ("App Changes"); 2.35 kB.
import os
os.system('pip install tensorflow')
import json
import numpy as np
import gradio as gr
import tensorflow as tf
from tensorflow import keras
from huggingface_hub.keras_mixin import from_pretrained_keras
class CustomNonPaddingTokenLoss(keras.losses.Loss):
    """Sparse categorical cross-entropy averaged over non-padding positions.

    Positions whose label ``y_true`` is not greater than 0 are masked out of
    both the summed loss and the token count (presumably label 0 is the
    padding tag — confirm against the label mapping).
    """

    def __init__(self, name="custom_ner_loss", **kwargs):
        # Forward extra config keys (e.g. ``reduction``) so Keras can rebuild
        # this loss via ``from_config(get_config())``; the base config contains
        # more than just ``name``.
        super().__init__(name=name, **kwargs)
        # Per-token losses (no reduction); masking and averaging happen in
        # ``call``. The string "none" is accepted by both Keras 2 and Keras 3.
        self._token_loss = keras.losses.SparseCategoricalCrossentropy(
            from_logits=True, reduction="none"
        )

    def call(self, y_true, y_pred):
        """Return the mean per-token loss over positions where ``y_true > 0``."""
        loss = self._token_loss(y_true, y_pred)
        # Match the loss dtype so mixed-precision losses multiply cleanly.
        mask = tf.cast(y_true > 0, dtype=loss.dtype)
        loss = loss * mask
        # A batch made entirely of masked positions yields 0 instead of NaN.
        return tf.math.divide_no_nan(tf.reduce_sum(loss), tf.reduce_sum(mask))
def lowercase_and_convert_to_ids(tokens, lookup=None):
    """Lower-case ``tokens`` and map them to vocabulary ids.

    Args:
        tokens: Token strings; anything ``tf.strings.lower`` accepts
            (e.g. a list of ``str``).
        lookup: Callable mapping a string tensor to ids, such as a
            ``keras.layers.StringLookup``. When omitted, the module-level
            ``lookup_layer`` is used, as before.

    Returns:
        Whatever ``lookup`` returns for the lower-cased tokens.
    """
    lowered = tf.strings.lower(tokens)
    if lookup is None:
        # Backward-compatible fallback; raises NameError if no module-level
        # ``lookup_layer`` has been defined.
        lookup = lookup_layer
    return lookup(lowered)
def tokenize_and_convert_to_ids(text):
    """Split ``text`` on whitespace and convert the pieces to vocabulary ids.

    ``str.split()`` with no separator is used, so any run of whitespace
    separates tokens. Lower-casing and the id lookup are delegated to
    ``lowercase_and_convert_to_ids``.
    """
    return lowercase_and_convert_to_ids(text.split())
# Cached ``(ner_model, lookup_layer, mapping)`` tuple, filled on first use so
# the Hub download and JSON parsing are not repeated for every request.
_NER_RESOURCES = None


def _load_ner_resources():
    """Load the NER model, token lookup layer and label mapping once.

    Returns:
        Tuple ``(ner_model, lookup_layer, mapping)``; subsequent calls return
        the same cached tuple.
    """
    global _NER_RESOURCES
    if _NER_RESOURCES is None:
        with open("vocab.json", "r", encoding="utf-8") as f:
            vocab = json.load(f)
        with open("mapping.json", "r", encoding="utf-8") as f:
            mapping = json.load(f)
        # The custom loss class is registered so deserialization can resolve
        # it; compile=False skips restoring the training configuration.
        ner_model = from_pretrained_keras(
            "keras-io/ner-with-transformers",
            custom_objects={"CustomNonPaddingTokenLoss": CustomNonPaddingTokenLoss},
            compile=False,
        )
        lookup = keras.layers.StringLookup(vocabulary=vocab["tokens"])
        _NER_RESOURCES = (ner_model, lookup, mapping)
    return _NER_RESOURCES


def ner_tagging(text_1):
    """Predict one NER tag per whitespace-separated token of ``text_1``.

    Args:
        text_1: Raw input text from the Gradio textbox.

    Returns:
        List of ``mapping.json`` values (presumably tag names), one per
        token; an empty list when the input contains no tokens.
    """
    # The tokenization helpers read a *module-level* ``lookup_layer``; binding
    # it as a plain local here made every request fail with NameError.
    global lookup_layer

    if not text_1 or not text_1.split():
        # No tokens: an empty token list is not a usable string tensor for the
        # lookup layer or the model, so return early.
        return []

    ner_model, lookup_layer, mapping = _load_ner_resources()

    sample_input = tokenize_and_convert_to_ids(text_1)
    # Add a batch dimension: (num_tokens,) -> (1, num_tokens).
    sample_input = tf.reshape(sample_input, shape=[1, -1])
    output = ner_model.predict(sample_input)
    # Highest-scoring class per token for the single batch item.
    prediction = np.argmax(output, axis=-1)[0]
    # mapping.json is keyed by the stringified class index.
    return [mapping[str(i)] for i in prediction]
# Gradio 3 deprecated and Gradio 4 removed the legacy ``gr.inputs`` /
# ``gr.outputs`` namespaces; use the top-level components when available and
# fall back to the legacy ones only on old releases that lack them.
if hasattr(gr, "Textbox"):
    text_1 = gr.Textbox(lines=5)
    ner_tag = gr.Textbox()
else:
    text_1 = gr.inputs.Textbox(lines=5)
    ner_tag = gr.outputs.Textbox()

iface = gr.Interface(
    ner_tagging,
    inputs=text_1,
    outputs=ner_tag,
    examples=[
        ['EU rejects German call to boycott British lamb .'],
        ["Wednesday's U.S. Open draw ceremony revealed that both title holders should run into their first serious opposition in the third round."],
    ],
    title="Named Entity Recognition with Transformers",
    description="Named Entity Recognition with Transformers on CoNLL2003 Dataset",
    article="Author: <a href=\"https://huggingface.co/reichenbach\">Rishav Chandra Varma</a>",
)
iface.launch()