# Web page residue captured together with the file (not part of the program):
# subu4444's picture
# Update app.py
# 4ae8795
# raw
# history blame contribute delete
# No virus
# 3.75 kB
import gradio as gr
import json
import torch
from transformers import AutoModelForQuestionAnswering, AutoTokenizer, pipeline
from transformers import AutoModelForTokenClassification
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
# --- QnA model (context based, extractive) ---
# Upper bound on the tokenized input length passed to the QnA tokenizer.
max_seq_length = 512
q_n_a_model_name = "deepset/roberta-base-squad2"
q_n_a_model = AutoModelForQuestionAnswering.from_pretrained(q_n_a_model_name)
q_n_a_tokenizer = AutoTokenizer.from_pretrained(q_n_a_model_name)

# --- Sequence classification model ---
# NOTE(review): "distilbert-base-uncased" appears to be a base checkpoint
# without a fine-tuned classification head, in which case the scores shown
# by the classifier tab are not meaningful -- confirm the intended checkpoint.
classification_model_name = "distilbert-base-uncased"
classification_tokenizer = DistilBertTokenizer.from_pretrained(classification_model_name)
classification_model = DistilBertForSequenceClassification.from_pretrained(classification_model_name)

# Input textbox shared by the Interfaces built further down in this file.
# gr.Textbox takes no `outputs` argument (that keyword belongs to
# gr.Interface), so only the label and the number of lines are configured.
context = gr.Textbox(
    label="Add the Context (Paragraph or texts) for which you want to get insights",
    lines=10,
)
def q_n_a_fn(context, text):
    """Extract the answer to the question ``text`` from the passage ``context``.

    The question/context pair is tokenized, run through ``q_n_a_model`` and
    the highest-scoring start and end positions are decoded back to text.

    Args:
        context: Passage in which to look for the answer.
        text: The question being asked about ``context``.

    Returns:
        The decoded answer span with surrounding whitespace removed, or an
        empty string when either input is blank or the predicted span is
        empty (end position before start position).
    """
    # Nothing sensible can be extracted from a blank question or passage.
    if not (context and context.strip()) or not (text and text.strip()):
        return ""

    # Set the device (CPU or GPU); moving an already-placed model is a no-op.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    q_n_a_model.to(device)

    # SQuAD-style QA models expect the question as the first segment and the
    # context as the second. No padding is requested: a single example does
    # not need it, and it would only add pad positions to score.
    # truncation=True trims the longer segment (normally the context) so the
    # pair fits in max_seq_length tokens.
    inputs = q_n_a_tokenizer(
        text,
        context,
        return_tensors="pt",
        max_length=max_seq_length,
        truncation=True,
    ).to(device)

    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        outputs = q_n_a_model(**inputs)

    # argmax is taken over the logits of this very input, so both indices are
    # always valid positions in input_ids.
    start_idx = int(torch.argmax(outputs.start_logits))
    end_idx = int(torch.argmax(outputs.end_logits))
    if end_idx < start_idx:
        # Inverted span: the model found no consistent answer.
        return ""

    # Decode the answer tokens into a human-readable answer.
    answer_ids = inputs["input_ids"][0][start_idx : end_idx + 1]
    return q_n_a_tokenizer.decode(answer_ids, skip_special_tokens=True).strip()
def classification_fn(context):
    """Score ``context`` with the sequence-classification model.

    Args:
        context: Text to classify.

    Returns:
        A dict with the keys ``"POSITIVE"`` and ``"NEGATIVE"`` holding the
        softmax probabilities of class index 0 and class index 1.
    """
    # truncation=True keeps long inputs within the tokenizer's configured
    # maximum length instead of handing the model an over-long sequence.
    inputs = classification_tokenizer(context, return_tensors="pt", truncation=True)

    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        logits = classification_model(**inputs).logits

    # Softmax across the class dimension; [0] selects the single input row.
    class_probabilities = torch.softmax(logits, dim=1)[0].tolist()

    # NOTE(review): the index -> label mapping is hard-coded here; confirm it
    # matches the label order of the checkpoint that is loaded.
    return {"POSITIVE": class_probabilities[0], "NEGATIVE": class_probabilities[1]}
def translate_fn(context, text):
    """Placeholder translator: hand ``context`` back untouched.

    Args:
        context: Text submitted for translation.
        text: Second input value supplied by the caller; currently ignored.

    Returns:
        ``context`` exactly as it was received.
    """
    # No translation is performed yet; the input is echoed as the result.
    untranslated = context
    return untranslated
# Assemble the tabbed demo: every tab holds a single Row wrapping one Interface.
with gr.Blocks(theme="gradio/soft") as demo:
    gr.Markdown("<h1>Basic NLP Operations</h1>")
    gr.Markdown("Bringing basic NLP operations together.")

    with gr.Tab("Question and Answer"), gr.Row():
        question_box = gr.Textbox(label="Ask question", lines=1)
        gr.Interface(fn=q_n_a_fn, inputs=[context, question_box], outputs="text")

    with gr.Tab("Classifier"), gr.Row():
        label_view = gr.Label()
        gr.Interface(fn=classification_fn, inputs=[context], outputs=[label_view])

    with gr.Tab("Translation"), gr.Row():
        language_picker = gr.Radio(
            ["French", "Hindi", "Spanish"],
            label="Languages",
            info="Select language",
        )
        gr.Interface(fn=translate_fn, inputs=[context, language_picker], outputs="text")

    # NOTE: the remaining tabs are all wired to classification_fn and differ
    # only in the output component type requested from the Interface.
    for tab_title, output_kind in (
        ("Summarization", "label"),
        ("Text To Speech", "audio"),
        ("Text To Text", "text"),
    ):
        with gr.Tab(tab_title), gr.Row():
            gr.Interface(fn=classification_fn, inputs=[context], outputs=output_kind)

# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()