# Project_final / demo.py
import streamlit as st
import torch
from transformers import BertTokenizer, BertForSequenceClassification
# from datasets import DatasetDict, Dataset
# # Load train, test, and validation JSON files
# train_data = Dataset.from_json('jsonDataTrain.json')
# test_data = Dataset.from_json('jsonDataTest.json')
# validation_data = Dataset.from_json('jsonDataVal.json')
# # Define the features
# features = ['Post', 'defamation', 'hate', 'non-hostile', 'offensive']
labels = ['hate', 'non-hostile', 'defamation', 'offensive']
id2label = {idx:label for idx, label in enumerate(labels)}
label2id = {label:idx for idx, label in enumerate(labels)}
# Load the fine-tuned model and tokenizer
model_name = "fine_tuned_hindi_bert_model"
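# NOTE (assumption): "fine_tuned_hindi_bert_model" is presumably a local
# directory holding the fine-tuned checkpoint saved with save_pretrained();
# from_pretrained() loads the tokenizer and model weights from it.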
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertForSequenceClassification.from_pretrained(model_name)
# Example input text (Hindi). English gloss: "I am a student who lives in a
# hostel and studies with determination, but my friend is a fool. He plays
# games and eats food all the time."
input_text = "मैं एक छात्र हूं जो छात्रावास में रहता हूं और दृढ़ संकल्प के साथ अपनी पढ़ाई करता हूं लेकिन मेरा दोस्त मूर्ख है। वह हर समय गेम खेलता है और खाना खाता है।"
# Let the user edit the example text through a Streamlit text box
input_text = st.text_input("Enter a Hindi post to classify", value=input_text)
# Tokenize the input text
inputs = tokenizer(input_text, return_tensors="pt")
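# `inputs` is a dict-like BatchEncoding of PyTorch tensors (input_ids,
# attention_mask, ...) that can be passed straight to the model via **inputs.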
# Perform inference
with torch.no_grad():
    outputs = model(**inputs)

# Show the raw logit for each label
logits = outputs.logits
for i in range(len(logits[0])):
    st.write(id2label[i], logits[0][i].item())
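
# --- Optional extension (a minimal sketch, not part of the original demo) ---
# Assuming the classifier was trained as a single-label, multi-class model,
# softmax turns the logits into probabilities and argmax picks the most
# likely label. If it was instead trained as a multi-label model, a per-label
# sigmoid with a threshold would be the appropriate choice.
probabilities = torch.softmax(logits, dim=-1)
predicted_idx = int(torch.argmax(probabilities, dim=-1).item())
st.write("Predicted label:", id2label[predicted_idx])
st.write("Confidence:", probabilities[0][predicted_idx].item())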