sreejith8100's picture
Update app.py
10f884a verified
import torch
import gradio as gr
from transformers import pipeline, AutoTokenizer, AutoModelForTokenClassification
from rapidfuzz import process
# ------------------ Load models once ------------------
# Both pipelines are created at module import so every request reuses them.

# Device index handed to both pipelines: 0 when CUDA is available, else -1.
_DEVICE = 0 if torch.cuda.is_available() else -1

# Sentiment model (model weights and tokenizer come from the same checkpoint)
_SENTIMENT_CHECKPOINT = "sreejith8100/indian_output"
sentiment_pipeline = pipeline(
    "sentiment-analysis",
    model=_SENTIMENT_CHECKPOINT,
    tokenizer=_SENTIMENT_CHECKPOINT,
    device=_DEVICE,
)

# NER model: tokenizer and model are loaded explicitly, then wrapped
_NER_CHECKPOINT = "ai4bharat/IndicNER"
ner_tokenizer = AutoTokenizer.from_pretrained(_NER_CHECKPOINT, use_fast=True)
ner_model = AutoModelForTokenClassification.from_pretrained(_NER_CHECKPOINT)
ner_pipeline = pipeline(
    "ner",
    model=ner_model,
    tokenizer=ner_tokenizer,
    # predict() reads the "word" field of each aggregated result.
    aggregation_strategy="simple",
    device=_DEVICE,
)
# Canonical entity list.
# Each entry is a single string of the form "<Latin-script name> / <Malayalam
# name>"; the whole string is what gets fuzzy-matched and returned.
CANONICAL_ENTITIES = [
    "V Abdurahiman / വി അബ്ദുറഹിമാൻ",
    "P A Mohamed Riyas / പി എ മുഹമ്മദ് റിയാസ്",
    "P Rajeev / പി രാജീവ്",
    "Saji Cherian / സജി ചെറിയാൻ",
    "Roshy Augustine / റോഷി ഓഗസ്റ്റിൻ",
    "R Bindu / ആർ ബിന്ദു",
    "A K Saseendran / എ കെ സസീന്ദ്രൻ",
    "O R Kelu / ഒ ആർ കെലു",
    "J Chinchurani / ജെ ചിഞ്ചുറാണി",
    "K N Balagopal / കെ എൻ ബാലഗോപാൽ",
    "K Krishnankutty / കെ കൃഷ്ണൻകുട്ടി",
    "Veena George / വീണാ ജോർജ്",
    "Antony Raju / ആന്റണി രാജു",
    "K Rajan / കെ രാജൻ",
    "M B Rajesh / എം ബി രാജേഷ്",
    "Chittayam Gopakumar / ചിറ്റയം ഗോപകുമാർ",
    "K Radhakrishnan / കെ രാധാകൃഷ്ണൻ",
    "Pinarayi Vijayan / പിണറായി വിജയൻ",
    "V Sivankutty / വി ശിവൻകുട്ടി",
    "K K Shailaja / കെ കെ ശൈലജ",
]
def map_entity(entity_text, known_entities=CANONICAL_ENTITIES, threshold=70):
    """Map a raw entity string to its closest canonical entity.

    Args:
        entity_text: Entity text to look up (e.g. a span produced by NER).
        known_entities: Candidate strings to fuzzy-match against. Defaults
            to CANONICAL_ENTITIES.
        threshold: Minimum rapidfuzz score a candidate must reach to be
            accepted as a match.

    Returns:
        The best-matching entry from ``known_entities``, or ``None`` when no
        candidate scores at least ``threshold``.
    """
    # extractOne returns None instead of a tuple when nothing reaches
    # score_cutoff, which also covers an empty candidate collection.
    # Unpacking the result unconditionally would raise TypeError then.
    best = process.extractOne(entity_text, known_entities, score_cutoff=threshold)
    if best is None:
        return None
    # Result tuple is (choice, score, index/key); only the choice is needed.
    return best[0]
# Translate the sentiment model's raw label ids into readable names.
# NOTE(review): the id -> name order is presumably fixed by how the checkpoint
# was trained; confirm against the model config before changing it.
label_map = {
    "LABEL_0": "POSITIVE",
    "LABEL_1": "NEGATIVE",
    "LABEL_2": "NEUTRAL",
}
# ------------------ Prediction function ------------------
def predict(sentence):
    """Classify the sentiment of a sentence and list the known entities in it.

    Args:
        sentence: Input text passed to both the sentiment and NER pipelines.

    Returns:
        A dict with:
            "sentence": the input, unchanged.
            "prediction": the sentiment label, translated through
                ``label_map`` (the raw label is kept if it has no mapping).
            "score": the sentiment score as a plain float.
            "mapped_entities": canonical entity strings found in the
                sentence, de-duplicated, in order of first appearance.
    """
    # Run sentiment; the pipeline returns a list, take the first result.
    sent_pred = sentiment_pipeline(sentence)[0]
    raw_label = sent_pred["label"]
    human_label = label_map.get(raw_label, raw_label)

    # Run NER, then fuzzy-map each detected span exactly once.
    entities = ner_pipeline(sentence)
    candidates = (map_entity(ent["word"]) for ent in entities)
    # dict.fromkeys drops duplicates while keeping first-seen order, so the
    # response is deterministic (a set would yield an arbitrary order).
    mapped_entities = list(dict.fromkeys(name for name in candidates if name))

    return {
        "sentence": sentence,
        "prediction": human_label,
        "score": float(sent_pred["score"]),
        "mapped_entities": mapped_entities,
    }
# ------------------ Gradio Interface ------------------
# One text box in, one JSON panel out; predict() does the work in between.
# NOTE(review): the description says "Upload" although the input is a text box.
_sentence_box = gr.Textbox(label="Enter a sentence")
_result_panel = gr.JSON(label="Result")

demo = gr.Interface(
    fn=predict,
    inputs=_sentence_box,
    outputs=_result_panel,
    title="Entity + Sentiment Analysis",
    description="Upload a sentence in Malayalam/English. The app detects entities and predicts sentiment.",
)

if __name__ == "__main__":
    # Serve on all network interfaces, port 7860, only when run as a script.
    demo.launch(server_name="0.0.0.0", server_port=7860)