# as / app.py
# asv7j's picture
# Update app.py
# 6fb2b0a verified
# raw
# history blame
# 2.08 kB
import json
import uvicorn
import tensorflow as tf
import gpt_2_simple as gpt2
from fastapi import FastAPI, HTTPException, Header
import os
import requests
# Runtime configuration, read once at import time.
# Both values are None when the corresponding environment variable is unset.
api_key = os.getenv("api_key")  # secret that clients send in the request body
log_url = os.getenv("log")  # URL that receives a JSON record per generation

app = FastAPI()

# Load GPT-2 model
model_name = "chatbot"
sess = gpt2.start_tf_sess()  # Start a new TensorFlow session
gpt2.load_gpt2(sess, model_name=model_name)  # Load the GPT-2 model

# Define a list of allowed API keys.
# An unset "api_key" environment variable yields None; keeping None in this
# list would authorize any request that simply omits "api_key" (dict.get then
# returns None as well), so empty/missing entries are filtered out.
# NOTE(review): the two "your-api-key-*" entries look like template
# placeholders that work as publicly known credentials -- confirm whether any
# client relies on them and remove them if not.
ALLOWED_API_KEYS = [
    key for key in (api_key, "your-api-key-1", "your-api-key-2") if key
]
@app.get("/")
async def read_root():
    """Serve the root path with a fixed greeting payload."""
    greeting = {"Hello": "World!"}
    return greeting
import time
@app.post("/generate")
async def generate_text(data: dict):
    """Generate text from ``data["prompt"]`` with the loaded GPT-2 model.

    The request body is a JSON object with two fields:

    * ``api_key`` -- must be one of ``ALLOWED_API_KEYS``.
    * ``prompt`` -- non-empty string used as the generation prefix.

    After generation, the text and timing are posted as JSON to ``log_url``
    on a best-effort basis: a failed log request is printed and never fails
    the API call.

    Returns:
        A dict with ``generated_text`` (the first generated sample) and
        ``time_taken`` (seconds spent in generation, as a float).

    Raises:
        HTTPException: 401 when the API key is missing or not allowed;
            400 when the prompt is missing, empty, or not a string.
    """
    # Reject a missing/empty key explicitly: if ALLOWED_API_KEYS ever contains
    # None (e.g. an unset environment variable), a bare membership test would
    # accept requests that send no key at all.
    supplied_key = data.get("api_key")
    if not supplied_key or supplied_key not in ALLOWED_API_KEYS:
        raise HTTPException(status_code=401, detail="Unauthorized")

    prompt = data.get("prompt")
    if not prompt:
        raise HTTPException(status_code=400, detail="Prompt is required.")
    if not isinstance(prompt, str):
        # A non-string prefix would otherwise surface as a server error from
        # inside the generation call.
        raise HTTPException(status_code=400, detail="Prompt must be a string.")

    # perf_counter is monotonic, so the measured duration cannot be skewed by
    # wall-clock adjustments during generation.
    start_time = time.perf_counter()

    # Generate text using GPT-2 model
    samples = gpt2.generate(
        sess,
        model_name=model_name,
        prefix=prompt,
        length=200,
        temperature=0.65,
        truncate='<|endoftext|>',
        return_as_list=True,
    )
    generated_text = samples[0]

    time_taken = time.perf_counter() - start_time

    # Log generated text to the logging server (best effort).
    log_data = {
        "generated_text": generated_text,
        "time_taken": time_taken,
    }
    log_timeout_seconds = 10
    if log_url:
        try:
            # requests.post blocks the event loop inside this async handler;
            # the timeout bounds that stall and prevents an unresponsive log
            # server from hanging the request forever.
            response = requests.post(
                log_url, json=log_data, timeout=log_timeout_seconds
            )
            response.raise_for_status()
        except requests.exceptions.RequestException as e:
            # Log error if request fails
            print("Error logging generated text:", e)

    # Return the generated text along with the time taken
    return {"generated_text": generated_text, "time_taken": time_taken}
if __name__ == "__main__":
    # Script entry point: serve the app on all interfaces, port 7860.
    server_options = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run(app, **server_options)