Bull / app.py
Thebull's picture
Update app.py
c1829c0 verified
raw
history blame
1.05 kB
from flask import Flask, request, jsonify
from transformers import DistilBertTokenizerFast, TFDistilBertForSequenceClassification
import tensorflow as tf
import numpy as np
import torch
app = Flask(__name__)

# TensorFlow device string. The model below is a TF/Keras model, so probe TF for
# GPUs (not torch) and use TF's device naming ("/CPU:0", not "CPU"). TF places
# ops on an available GPU by default; this value is kept for reference/use by
# callers that want explicit placement via tf.device(device).
device = "/GPU:0" if tf.config.list_physical_devices("GPU") else "/CPU:0"

model_name = "distilbert-base-uncased"
tokenizer = DistilBertTokenizerFast.from_pretrained(model_name)

# Use the Keras model object directly: it is callable on tokenizer output and
# exposes `.logits`. It has no `.signatures` attribute (that belongs to objects
# returned by tf.saved_model.load) and no `.to()` method (that is PyTorch).
# from_pt=True converts the PyTorch checkpoint weights to TF.
# NOTE(review): "distilbert-base-uncased" is a base checkpoint; its
# sequence-classification head is presumably freshly initialised, so predictions
# are not meaningful until a fine-tuned checkpoint is used -- confirm the model.
model = TFDistilBertForSequenceClassification.from_pretrained(model_name, from_pt=True)
@app.route("/predict", methods=["POST"])
def predict():
    """Classify the text supplied in the request's JSON body.

    Expects a JSON object with a string ``"input"`` field, e.g.
    ``{"input": "some text"}``.

    Returns:
        ``{"response": <int>}`` -- the index of the highest-probability class
        from the model's logits. If the body is not a JSON object with a string
        ``"input"`` field, returns ``{"error": ...}`` with HTTP status 400.
    """
    # silent=True yields None (instead of raising) for a missing/invalid JSON
    # body, so every malformed request is answered by the single 400 below.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or not isinstance(data.get("input"), str):
        return jsonify({"error": "expected a JSON object with a string 'input' field"}), 400
    input_text = data["input"]

    # Tokenize straight to TF tensors: the model is a TF/Keras model and cannot
    # consume torch tensors. truncation=True keeps long texts within the
    # tokenizer's maximum sequence length instead of failing inside the model.
    encoded = tokenizer(input_text, return_tensors="tf", truncation=True)

    # training=False runs the model in inference mode (no dropout). No
    # torch.no_grad() equivalent is needed: TF records no gradients outside a
    # GradientTape.
    logits = model(dict(encoded), training=False).logits

    probabilities = tf.nn.softmax(logits, axis=-1).numpy()
    # Batch of one -> take row 0. int() converts numpy.int64, which jsonify
    # cannot serialize.
    prediction = int(np.argmax(probabilities[0]))
    return jsonify({"response": prediction})
if __name__ == "__main__":
    # Start Flask's development server. Debug mode enables the Werkzeug
    # interactive debugger, which can execute arbitrary code, so it is no longer
    # forced on: app.run() enables it when the FLASK_DEBUG environment variable
    # is set (e.g. FLASK_DEBUG=1) and debug is not passed explicitly.
    app.run()