# Spaces: Sleeping (page-status text captured along with the file; not part of the program)
# Import required libraries
import os

import torch
from flask import Flask, request, jsonify, render_template
from transformers import GPT2Tokenizer, GPT2LMHeadModel
# Define the GPT2 class
class GPT2:
    """Next-sub-word predictor built on the pre-trained "gpt2" checkpoint.

    The tokenizer and the language-model head are loaded once in
    ``__init__``; the model is put in evaluation mode and moved to CUDA
    when it is available, otherwise it stays on the CPU.
    """

    def __init__(self):
        self.model_type = "GPT2"
        # Load pre-trained model tokenizer (vocabulary)
        self.tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
        # Load pre-trained model (weights)
        self.model = GPT2LMHeadModel.from_pretrained("gpt2")
        self.model.eval()  # Set the model to evaluation mode
        # Use GPU if available
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

    def predict_next(self, text, k):
        """Return the most likely next sub-words for ``text``.

        Args:
            text: Prompt to continue.
            k: Maximum number of candidates to return. Values larger than
                the vocabulary size are clamped to it.

        Returns:
            A list of at most ``k`` decoded, whitespace-stripped sub-words,
            best candidate first. The list is empty when ``k`` is 0 or when
            ``text`` encodes to no tokens.

        Raises:
            ValueError: If ``k`` is negative.
        """
        if k < 0:
            raise ValueError("k must be a non-negative integer")

        # Encode the text input
        indexed_tokens = self.tokenizer.encode(text)
        if not indexed_tokens or k == 0:
            # Nothing to condition on (or nothing requested): skip the
            # forward pass instead of feeding the model an empty tensor.
            return []

        # Keep only the most recent tokens when the prompt is longer than the
        # model's positional limit; the next-token prediction depends on the
        # tail of the prompt, and longer inputs cannot be embedded.
        context_limit = getattr(getattr(self.model, "config", None), "n_positions", None)
        if context_limit:
            indexed_tokens = indexed_tokens[-context_limit:]

        # Convert indexed tokens into a PyTorch tensor with a batch of one
        tokens_tensor = torch.tensor([indexed_tokens]).to(self.device)

        # Predict all tokens without tracking gradients
        with torch.no_grad():
            outputs = self.model(tokens_tensor)

        # First output element, single batch row, last sequence position:
        # one (unnormalised) score per vocabulary entry for the next token.
        next_token_scores = outputs[0][0, -1, :]

        # topk raises when k exceeds the number of scores, so clamp it.
        k = min(k, next_token_scores.size(-1))
        top_ids = next_token_scores.topk(k).indices.tolist()
        return [self.tokenizer.decode(token_id).strip() for token_id in top_ids]
# Initialize the model once at import time; the request handlers below
# reuse this single module-level instance.
gpt2 = GPT2()
# Create the Flask app
app = Flask(__name__)
# Without a route decorator the view is never registered and the app
# answers 404 for every URL.
@app.route('/')
def index():
    """Serve the front-end page rendered from the ``index.html`` template."""
    return render_template('index.html')
# NOTE(review): the URL path is assumed from the function name; confirm it
# matches whatever templates/index.html posts to.
@app.route('/predict', methods=['POST'])
def predict():
    """Return next-sub-word predictions for the posted text as JSON.

    Expects a JSON object with:
        text: non-empty string to continue (required).
        k: non-negative integer number of candidates (optional, default 5).

    Responses:
        200 ``{"predictions": [...]}`` on success.
        400 ``{"error": ...}`` when the body is not a JSON object, ``text``
            is missing/empty/not a string, or ``k`` is not a non-negative
            integer.
        500 ``{"error": ...}`` when the model call itself fails.
    """
    # silent=True yields None for a missing or malformed JSON body instead of
    # raising, so every client mistake can be reported as a JSON error.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Text input is required"}), 400

    text = data.get('text')
    k = data.get('k', 5)  # Default to 5 predictions if not specified

    if not text or not isinstance(text, str):
        return jsonify({"error": "Text input is required"}), 400
    # bool is a subclass of int, so reject it explicitly.
    if isinstance(k, bool) or not isinstance(k, int) or k < 0:
        return jsonify({"error": "k must be a non-negative integer"}), 400

    try:
        predictions = gpt2.predict_next(text, k)
    except Exception as e:
        # Request boundary: report the failure as JSON rather than letting
        # Flask render an HTML error page.
        return jsonify({"error": str(e)}), 500
    return jsonify({"predictions": predictions})
if __name__ == '__main__':
    # The interactive debugger lets anyone who can reach the server execute
    # arbitrary Python, so debug mode is opt-in via FLASK_DEBUG instead of
    # being switched on unconditionally.
    debug_enabled = os.environ.get("FLASK_DEBUG", "").strip().lower() in {"1", "true", "yes", "on"}
    app.run(debug=debug_enabled)