Elalimy's picture
Rename main.py to app.py
c4f5264 verified
raw
history blame contribute delete
No virus
1.93 kB
# Import required libraries
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from flask import Flask, request, jsonify, render_template
# Define the GPT2 class
class GPT2:
    """Next-token predictor backed by a pre-trained Hugging Face GPT-2 checkpoint.

    The tokenizer and model are loaded once at construction time. The model is
    placed on CUDA when available, otherwise on CPU.
    """

    def __init__(self, model_name="gpt2"):
        """Load the tokenizer and weights for ``model_name``.

        Args:
            model_name: Checkpoint id or local path accepted by
                ``from_pretrained``. Defaults to ``"gpt2"``.
        """
        self.model_type = "GPT2"
        # Load pre-trained model tokenizer (vocabulary)
        self.tokenizer = GPT2Tokenizer.from_pretrained(model_name)
        # Load pre-trained model (weights)
        self.model = GPT2LMHeadModel.from_pretrained(model_name)
        # Evaluation mode disables dropout, so repeated calls rank tokens identically.
        self.model.eval()
        # Use GPU if available
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

    def predict_next(self, text, k):
        """Return the ``k`` most likely tokens to follow ``text``.

        Args:
            text: Prompt string. It must encode to at least one token. Only the
                last ``config.n_positions`` tokens are fed to the model.
            k: Number of candidates wanted. Must be a positive integer; bool
                is rejected. Values above the vocabulary size are clamped to it.

        Returns:
            List of decoded tokens, most likely first. Surrounding whitespace
            is stripped from each, so distinct tokens such as " the" and
            "the" may produce equal strings.

        Raises:
            ValueError: If ``k`` is not a positive integer, or ``text``
                encodes to no tokens.
        """
        import operator

        # bool is an int subclass; reject it so a JSON `true` is not read as k=1.
        if isinstance(k, bool):
            raise ValueError("k must be a positive integer")
        try:
            # Accepts int and integer-like scalars; rejects floats and strings.
            k = operator.index(k)
        except TypeError:
            raise ValueError("k must be a positive integer") from None
        if k < 1:
            raise ValueError("k must be a positive integer")

        # Encode the text into token ids.
        indexed_tokens = self.tokenizer.encode(text)
        if not indexed_tokens:
            # An empty sequence has no last position to predict from.
            raise ValueError("text must encode to at least one token")
        # GPT-2 has a fixed context window (config.n_positions). Longer inputs
        # overflow the position embeddings, so keep only the most recent tokens.
        indexed_tokens = indexed_tokens[-self.model.config.n_positions:]

        # A batch holding a single sequence: shape (1, seq_len).
        tokens_tensor = torch.tensor([indexed_tokens], device=self.device)
        with torch.no_grad():
            outputs = self.model(tokens_tensor)
        # outputs[0] holds the LM logits; position -1 scores the token after the prompt.
        next_token_logits = outputs[0][0, -1, :]

        # topk raises when k exceeds the number of candidates, so clamp it.
        k = min(k, next_token_logits.size(-1))
        # Ranking raw logits gives the same order as ranking softmax probabilities.
        top_ids = next_token_logits.topk(k).indices.tolist()
        return [self.tokenizer.decode(token_id).strip() for token_id in top_ids]
# Module-level singletons: the Flask application object, plus one GPT2
# instance built at import time and shared by the route handlers.
app = Flask(__name__)
gpt2 = GPT2()
@app.route('/')
def index():
    """Render the ``index.html`` template for the site root."""
    page = 'index.html'
    return render_template(page)
@app.route('/predict', methods=['POST'])
def predict():
    """Return top-k next-token predictions for a JSON prompt.

    Expects a JSON object body: ``{"text": <non-empty string>, "k": <positive
    int, optional, default 5>}``.

    Responses:
        200 with ``{"predictions": [...]}`` on success.
        400 with ``{"error": ...}`` for a malformed body, missing or
            non-string ``text``, or an invalid ``k``.
        500 with ``{"error": ...}`` if the model call raises.
    """
    # silent=True returns None instead of raising on a missing or unparsable
    # JSON body, so every malformed request gets the same JSON 400 below.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    text = data.get('text')
    k = data.get('k', 5)  # Default to 5 predictions if not specified
    if not text:
        return jsonify({"error": "Text input is required"}), 400
    if not isinstance(text, str):
        return jsonify({"error": "Text input must be a string"}), 400
    # bool is an int subclass; reject it explicitly so `true` is not read as k=1.
    if isinstance(k, bool) or not isinstance(k, int) or k < 1:
        return jsonify({"error": "k must be a positive integer"}), 400

    try:
        predictions = gpt2.predict_next(text, k)
    except Exception as e:
        # Top-level request boundary: keep the traceback in the server log.
        app.logger.exception("Prediction failed")
        return jsonify({"error": str(e)}), 500
    return jsonify({"predictions": predictions})
if __name__ == '__main__':
    # Start Flask's development server. Debug mode is not forced on:
    # - the Werkzeug debugger it enables lets any client that can reach the
    #   server execute arbitrary Python;
    # - its reloader re-runs this module in a child process, loading GPT-2
    #   a second time.
    # app.run() honours the FLASK_DEBUG environment variable, so set
    # FLASK_DEBUG=1 to debug locally.
    app.run()