# Hugging Face Spaces status: Runtime error
from flask import Flask, jsonify, request
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Load the tokenizer and the fine-tuned model once, at import time; the
# inference function below reads these module-level objects on every call.
# NOTE(review): the tokenizer comes from base "gpt2" rather than from the
# fine-tuned repo — confirm the fine-tuned model did not change the vocab.
tokenizer_path = "gpt2"
model_path = "hunthinn/movie_title_gpt2"

tokenizer = GPT2Tokenizer.from_pretrained(tokenizer_path)
# GPT-2's tokenizer defines no padding token by default; reuse end-of-sequence.
tokenizer.pad_token = tokenizer.eos_token

model = GPT2LMHeadModel.from_pretrained(model_path)
def infer_title(input, max_new_tokens=50):
    """Generate a movie-title answer for a free-text query.

    The query is wrapped in the ``"Q: <query> A:"`` prompt template and passed
    to the module-level GPT-2 ``model``. Only the tokens generated after the
    prompt are decoded and returned.

    Args:
        input: Query text. The name shadows the ``input`` builtin. It is kept
            unchanged so existing keyword callers keep working.
        max_new_tokens: Upper bound on the number of tokens generated after
            the prompt.

    Returns:
        The generated answer with surrounding whitespace stripped, or ``None``
        when *input* is empty or ``None``.
    """
    if not input:
        # Nothing to ask the model; make the previously implicit None explicit.
        return None

    encoded = tokenizer("Q: " + input + " A:", return_tensors="pt")
    prompt_ids = encoded["input_ids"]
    output = model.generate(
        prompt_ids,
        # Pass the mask and pad id explicitly. Otherwise generate() has to
        # infer them and emits a warning.
        attention_mask=encoded["attention_mask"],
        pad_token_id=tokenizer.eos_token_id,
        # Bound only the continuation. The old max_length=50 also counted
        # prompt tokens, so a long query left little or no room for an answer.
        max_new_tokens=max_new_tokens,
        num_return_sequences=1,
    )
    # Decode just the continuation. An "A:" appearing in the generated text
    # can then no longer cut off the start of the answer.
    answer_ids = output[0][prompt_ids.shape[-1]:]
    return tokenizer.decode(answer_ids, skip_special_tokens=True).strip()
app = Flask(__name__)


# The original view was never registered with a route, so Flask could not
# dispatch to it.
# NOTE(review): the URL path is not visible in the source; "/" is assumed —
# confirm it matches what clients call.
@app.route("/")
def endpoint():
    """Return a generated movie title for the ``input`` query parameter.

    Success: HTTP 200 with ``{"output": <title>}``.
    Missing or empty ``input``: HTTP 400 with ``{"output": null, "error": ...}``.
    The ``output`` key is kept in the error body so existing clients that only
    read ``output`` still see ``null`` as before.
    """
    query = request.args.get("input")
    if not query:
        # A required parameter is missing, which is a client error, not a
        # successful null result.
        return (
            jsonify({"output": None, "error": "missing 'input' query parameter"}),
            400,
        )
    return jsonify({"output": infer_title(query)})
if __name__ == "__main__":
    # Start Flask's built-in server on every network interface, port 7860.
    # NOTE(review): 7860 is presumably the port the hosting Space expects — confirm.
    app.run(port=7860, host="0.0.0.0")