Spaces:
Sleeping
Sleeping
# Apache Software License 2.0
#
# Copyright (c) ZenML GmbH 2023. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import click | |
import numpy as np | |
import os | |
from transformers import AutoModelForSequenceClassification, AutoTokenizer | |
from os.path import dirname | |
import gradio as gr | |
def sentiment_analysis(
    tokenizer_name_or_path, model_name_or_path, labels, title, description, interpretation, examples
):
    """Launch a Gradio demo that scores text with a locally stored classifier.

    The tokenizer and model are loaded once, from directories resolved
    against the directory containing this file, and then reused for every
    request handled by the interface.

    Args:
        tokenizer_name_or_path: Tokenizer directory, joined onto this file's
            directory.
        model_name_or_path: Model directory, joined onto this file's
            directory.
        labels: Comma-separated label names; they are paired in order with
            the scores produced by the model.
        title: Title passed to ``gr.Interface``.
        description: Description passed to ``gr.Interface``.
        interpretation: Passed through as ``gr.Interface(interpretation=...)``.
        examples: A single example text; it is wrapped in a one-element list
            before being handed to ``gr.Interface``.
    """
    labels = labels.split(",")
    examples = [examples]

    # Anchor on the absolute location of this file. When __file__ is a bare
    # relative name (e.g. "app.py"), dirname() is "" and a plain
    # f"{dirname(__file__)}/model/" would point at the filesystem root.
    base_dir = dirname(os.path.abspath(__file__))
    # The trailing "" keeps the trailing separator the paths always had.
    model_path = os.path.join(base_dir, model_name_or_path, "")
    tokenizer_path = os.path.join(base_dir, tokenizer_name_or_path, "")

    # Load once up front instead of on every request: loading from disk is
    # by far the most expensive step and its result never changes.
    print(f"Loading model from {model_path}")
    print(f"Loading tokenizer from {tokenizer_path}")
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
    model = AutoModelForSequenceClassification.from_pretrained(model_path)

    def preprocess(text):
        """Replace @-mentions with "@user" and http-prefixed tokens with "http"."""
        new_text = []
        for t in text.split(" "):
            t = "@user" if t.startswith("@") and len(t) > 1 else t
            t = "http" if t.startswith("http") else t
            new_text.append(t)
        return " ".join(new_text)

    def softmax(x):
        """Return the softmax of ``x`` along axis 0."""
        # Subtracting the max first keeps np.exp from overflowing.
        e_x = np.exp(x - np.max(x))
        return e_x / e_x.sum(axis=0)

    def analyze_text(text):
        """Return a ``{label: probability}`` mapping for ``text``."""
        text = preprocess(text)
        # truncation=True keeps over-long user input within the model's
        # maximum sequence length instead of failing inside the model.
        encoded_input = tokenizer(text, return_tensors="pt", truncation=True)
        output = model(**encoded_input)
        # output[0][0]: first output field, first (only) item of the batch --
        # presumably the per-label logits for the single input text.
        scores_ = softmax(output[0][0].detach().numpy())
        return {label: float(score) for label, score in zip(labels, scores_)}

    demo = gr.Interface(
        fn=analyze_text,
        inputs=[gr.TextArea("Write your text or tweet here", label="Analyze Text")],
        outputs=["label"],
        title=title,
        description=description,
        interpretation=interpretation,
        examples=examples,
    )
    demo.launch(share=True, debug=True)
if __name__ == "__main__":
    # sentiment_analysis has seven required parameters, so calling it with no
    # arguments raises TypeError. Expose every parameter as a required
    # command-line option (click is already imported by this file) and let
    # click parse sys.argv and pass the values through by name.
    option_names = (
        "tokenizer_name_or_path",
        "model_name_or_path",
        "labels",
        "title",
        "description",
        "interpretation",
        "examples",
    )
    cli = sentiment_analysis
    # Decorators apply bottom-up, so wrap in reverse to list the options in
    # signature order in --help.
    for option_name in reversed(option_names):
        cli = click.option(f"--{option_name}", required=True)(cli)
    click.command()(cli)()