omkar56's picture
updated main.py to give response in JSON format
ab3abd8
raw
history blame
1.9 kB
import os
import random
import secrets
from typing import Optional

from fastapi import FastAPI, Request, Body, HTTPException, Depends
from fastapi.security import APIKeyHeader
from huggingface_hub import InferenceClient
# Deployment configuration pulled from the environment.
# NOTE(review): API_URL is read but never used below — confirm whether it is dead config.
API_URL = os.environ.get("API_URL")
API_KEY = os.environ.get("API_KEY")  # shared secret clients must send in the `api_key` header
MODEL_NAME = os.environ.get("MODEL_NAME")  # HF model repo id passed to InferenceClient
# Client for the Hugging Face Inference API, bound to MODEL_NAME.
client = InferenceClient(MODEL_NAME)
app = FastAPI()
# Header-based auth scheme; auto_error=False so a missing header reaches
# get_api_key (which raises its own 401) instead of FastAPI's default 403.
security = APIKeyHeader(name="api_key", auto_error=False)
def get_api_key(api_key: Optional[str] = Depends(security)) -> str:
    """FastAPI dependency that validates the `api_key` request header.

    Returns the key on success; raises HTTPException 401 when the header
    is missing, the server has no API_KEY configured, or the value is wrong.
    """
    # compare_digest is constant-time, so attackers can't learn key
    # prefixes from response-timing differences (plain != can leak this).
    if api_key is None or API_KEY is None or not secrets.compare_digest(api_key, API_KEY):
        raise HTTPException(status_code=401, detail="Unauthorized access")
    return api_key
def format_prompt(message, history):
    """Build a Mistral/Llama-style instruction prompt.

    Each (user, assistant) turn in *history* is rendered as
    "[INST] user [/INST] assistant</s> ", then the new *message* is
    appended as a final open instruction. The whole prompt starts with "<s>".
    """
    segments = ["<s>"]
    for user_turn, assistant_turn in history:
        segments.append(f"[INST] {user_turn} [/INST]")
        segments.append(f" {assistant_turn}</s> ")
    segments.append(f"[INST] {message} [/INST]")
    return "".join(segments)
@app.post("/api/v1/generate_text", response_model=dict)
def generate_text(
    request: Request,
    body: dict = Body(...),
    api_key: str = Depends(get_api_key)
):
    """Generate text for `body["prompt"]` via the HF Inference API.

    Accepted body fields (all optional except effectively `prompt`):
    prompt, sysPrompt, temperature (default 0.5), top_p (0.95),
    max_new_tokens (512), repetition_penalty (1.0).

    Returns {"generated_text": <str>} as JSON.
    """
    prompt = body.get("prompt", "")
    # NOTE(review): sysPrompt is read but never used — confirm whether it
    # should be injected into the prompt (e.g. as a system turn).
    sys_prompt = body.get("sysPrompt", "")
    temperature = body.get("temperature", 0.5)
    top_p = body.get("top_p", 0.95)
    max_new_tokens = body.get("max_new_tokens", 512)
    repetition_penalty = body.get("repetition_penalty", 1.0)
    # Fixed debug message: was 'temperature + {temperature}', which read
    # as a broken expression rather than a key=value log line.
    print(f"temperature = {temperature}")
    history = []  # multi-turn history not supported yet; always single-turn
    formatted_prompt = format_prompt(prompt, history)
    # stream=False means the client returns the full generated string,
    # not an iterator — so the result is named accordingly.
    generated_text = client.text_generation(
        formatted_prompt,
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=random.randint(0, 10**7),  # fresh seed per request for varied sampling
        stream=False,
        details=False,
        return_full_text=False
    )
    return {"generated_text": generated_text}