import os
import json
import time
import requests

import uvicorn
import tensorflow as tf
import gpt_2_simple as gpt2
from fastapi import FastAPI, HTTPException, Header

api_key = os.getenv("api_key")
log_url = os.getenv("log")

app = FastAPI()
# Load the GPT-2 model
model_name = "chatbot"
sess = gpt2.start_tf_sess()  # Start a new TensorFlow session
gpt2.load_gpt2(sess, model_name=model_name)  # Load the GPT-2 model into the session

# Define the list of allowed API keys
ALLOWED_API_KEYS = [api_key, "your-api-key-1", "your-api-key-2"]

@app.get("/")  # root endpoint; path assumed
async def read_root():
    return {"Hello": "World!"}

@app.post("/generate")  # endpoint path assumed
async def generate_text(data: dict):
    # Validate the API key supplied in the request body
    api_key = data.get("api_key")
    if api_key not in ALLOWED_API_KEYS:
        raise HTTPException(status_code=401, detail="Unauthorized")

    prompt = data.get("prompt")
    if not prompt:
        raise HTTPException(status_code=400, detail="Prompt is required.")

    start_time = time.time()  # Record the start time

    # Generate text using the GPT-2 model
    generated_text = gpt2.generate(sess, model_name=model_name,
                                   prefix=prompt,
                                   length=200,
                                   temperature=0.65,
                                   truncate='<|endoftext|>',
                                   return_as_list=True)[0]

    end_time = time.time()  # Record the end time
    time_taken = end_time - start_time  # Time taken to generate the text

    # Log the generated text to the Python logging server
    log_server_url = log_url
    log_data = {
        "generated_text": generated_text,
        "time_taken": time_taken
    }
    try:
        response = requests.post(log_server_url, json=log_data)
        response.raise_for_status()
        # Log request succeeded
    except requests.exceptions.RequestException as e:
        # Log the error if the request fails
        print("Error logging generated text:", e)

    # Return the generated text along with the time taken
    return {"generated_text": generated_text, "time_taken": time_taken}

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
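
# A minimal client sketch for exercising the generation endpoint above, kept as
# comments so this module still only starts the server when executed. It assumes
# the "/generate" path used above, a server running locally on port 7860, and
# placeholder api_key/prompt values that are not taken from the original app.
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:7860/generate",
#       json={"api_key": "your-api-key-1", "prompt": "Hello, how are you?"},
#   )
#   resp.raise_for_status()
#   result = resp.json()
#   print(result["generated_text"], result["time_taken"])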