import gradio as gr
import torch
from rapidfuzz import process
from transformers import pipeline, AutoTokenizer, AutoModelForTokenClassification
# Sentiment classifier, built once at import time so every request reuses it.
# The same checkpoint id supplies both the model weights and the tokenizer.
# Device index follows the transformers pipeline convention: 0 is the first
# CUDA GPU, -1 is the CPU.
sentiment_pipeline = pipeline(
    "sentiment-analysis",
    model="sreejith8100/indian_output",
    tokenizer="sreejith8100/indian_output",
    device=0 if torch.cuda.is_available() else -1,
)
# Named-entity recogniser built from the ai4bharat/IndicNER checkpoint.
# Tokenizer and model are loaded explicitly and then handed to the pipeline.
ner_tokenizer = AutoTokenizer.from_pretrained("ai4bharat/IndicNER", use_fast=True)
ner_model = AutoModelForTokenClassification.from_pretrained("ai4bharat/IndicNER")

# aggregation_strategy="simple" asks the pipeline to group token-level
# predictions into entity spans; predict() reads each span's "word" field.
# Device index: 0 is the first CUDA GPU, -1 is the CPU.
ner_pipeline = pipeline(
    "ner",
    model=ner_model,
    tokenizer=ner_tokenizer,
    aggregation_strategy="simple",
    device=0 if torch.cuda.is_available() else -1,
)
# Canonical names that detected entity mentions are normalised to.
# Each entry pairs a Latin-script spelling with its Malayalam-script form,
# separated by " / ", so a mention written in either script can fuzzy-match
# the same entry.
CANONICAL_ENTITIES = [
    "V Abdurahiman / വി അബ്ദുറഹിമാൻ",
    "P A Mohamed Riyas / പി എ മുഹമ്മദ് റിയാസ്",
    "P Rajeev / പി രാജീവ്",
    "Saji Cherian / സജി ചെറിയാൻ",
    "Roshy Augustine / റോഷി ഓഗസ്റ്റിൻ",
    "R Bindu / ആർ ബിന്ദു",
    "A K Saseendran / എ കെ സസീന്ദ്രൻ",
    "O R Kelu / ഒ ആർ കെലു",
    "J Chinchurani / ജെ ചിഞ്ചുറാണി",
    "K N Balagopal / കെ എൻ ബാലഗോപാൽ",
    "K Krishnankutty / കെ കൃഷ്ണൻകുട്ടി",
    "Veena George / വീണാ ജോർജ്",
    "Antony Raju / ആന്റണി രാജു",
    "K Rajan / കെ രാജൻ",
    "M B Rajesh / എം ബി രാജേഷ്",
    "Chittayam Gopakumar / ചിറ്റയം ഗോപകുമാർ",
    "K Radhakrishnan / കെ രാധാകൃഷ്ണൻ",
    "Pinarayi Vijayan / പിണറായി വിജയൻ",
    "V Sivankutty / വി ശിവൻകുട്ടി",
    "K K Shailaja / കെ കെ ശൈലജ",
]
def map_entity(entity_text, known_entities=CANONICAL_ENTITIES, threshold=70):
    """Return the canonical entity that best fuzzy-matches ``entity_text``.

    Args:
        entity_text: Entity mention to normalise (e.g. a span produced by NER).
        known_entities: Candidate canonical names. Only read, never mutated.
        threshold: Minimum similarity score the best candidate must reach
            (rapidfuzz's default scorer reports scores on a 0-100 scale).

    Returns:
        The best-scoring entry of ``known_entities`` when its score is at
        least ``threshold``; ``None`` when the mention is empty, there are no
        candidates, or the best score falls below ``threshold``.
    """
    # An empty mention carries no information to match on.
    if not entity_text:
        return None

    best = process.extractOne(entity_text, known_entities)
    # extractOne yields None rather than a tuple when it has no result
    # (for instance an empty candidate list); unpacking it would raise.
    if best is None:
        return None

    match, score, _ = best
    return match if score >= threshold else None
# Translates the raw class ids emitted by the sentiment pipeline into
# human-readable labels; predict() falls back to the raw id for unknown keys.
# NOTE(review): the id -> label order is assumed to match the checkpoint's
# training configuration -- confirm against the model card.
label_map = {
    "LABEL_0": "POSITIVE",
    "LABEL_1": "NEGATIVE",
    "LABEL_2": "NEUTRAL",
}
def predict(sentence):
    """Classify the sentiment of ``sentence`` and list the known entities in it.

    Args:
        sentence: Text to analyse.

    Returns:
        A dict with four keys:

        * ``"sentence"`` -- the input, unchanged.
        * ``"prediction"`` -- the sentiment label, translated through
          ``label_map`` (the raw pipeline label is kept if it is not mapped).
        * ``"score"`` -- the pipeline's score for that label, as a float.
        * ``"mapped_entities"`` -- canonical names matched by ``map_entity``,
          de-duplicated, in order of first appearance.

        For empty, whitespace-only or ``None`` input neither model is run:
        ``"prediction"`` and ``"score"`` are ``None`` and
        ``"mapped_entities"`` is an empty list.
    """
    # Nothing to analyse: skip both models instead of scoring an empty string.
    if not sentence or not sentence.strip():
        return {
            "sentence": sentence,
            "prediction": None,
            "score": None,
            "mapped_entities": [],
        }

    # The pipeline returns one result per input; a single string gives one.
    sent_pred = sentiment_pipeline(sentence)[0]
    raw_label = sent_pred["label"]
    human_label = label_map.get(raw_label, raw_label)

    # Fuzzy-match every detected mention exactly once (the matcher is the
    # expensive step), dropping mentions with no canonical counterpart.
    candidates = (map_entity(ent["word"]) for ent in ner_pipeline(sentence))
    # dict.fromkeys removes duplicates while keeping first-seen order, so the
    # output is stable from run to run.
    mapped_entities = list(dict.fromkeys(name for name in candidates if name))

    return {
        "sentence": sentence,
        "prediction": human_label,
        "score": float(sent_pred["score"]),
        "mapped_entities": mapped_entities,
    }
# Web UI: a single text box feeds predict(), whose dict result is shown as JSON.
# NOTE(review): the description says "Upload" although the input is a typed
# text box -- consider rewording the user-facing copy.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(label="Enter a sentence"),
    outputs=gr.JSON(label="Result"),
    title="Entity + Sentiment Analysis",
    description="Upload a sentence in Malayalam/English. The app detects entities and predicts sentiment.",
)
if __name__ == "__main__":
    # Start the server only when run as a script, not when imported.
    # "0.0.0.0" listens on every network interface; the port is fixed at 7860.
    demo.launch(server_name="0.0.0.0", server_port=7860)