# main.py, as shown on the Hugging Face Hub file view (raw size 1.91 kB).
# Last commit 704842e ("Update main.py") by omkar56; Hub security scan: no virus.
import logging
import os
import random
import secrets
from typing import Optional

from fastapi import FastAPI, Request, Body, HTTPException, Depends
from fastapi.security import APIKeyHeader
from huggingface_hub import InferenceClient
# Deployment configuration, read once from the environment at import time.
# Each value is None when its environment variable is not set.
API_URL = os.getenv("API_URL")  # NOTE(review): not referenced in this file
API_KEY = os.getenv("API_KEY")
MODEL_NAME = os.getenv("MODEL_NAME")

# One inference client shared by every request, bound to MODEL_NAME.
client = InferenceClient(MODEL_NAME)

app = FastAPI()

# Pulls the key from the "api_key" request header. auto_error=False means a
# missing header is handed to the dependency as None (rather than FastAPI
# rejecting the request itself), so get_api_key decides the response.
security = APIKeyHeader(name="api_key", auto_error=False)
def get_api_key(api_key: Optional[str] = Depends(security)):
    """Authorize a request by the key in its ``api_key`` header.

    FastAPI resolves ``api_key`` through the module-level ``security``
    scheme; it is None when the header is absent.

    Returns:
        The validated key, so endpoints can declare it as a dependency.

    Raises:
        HTTPException: 401 when the header is missing, when the server has
            no API_KEY configured, or when the supplied key does not match.
    """
    # Fail closed: without a configured key nothing can authenticate.
    if api_key is None or API_KEY is None:
        raise HTTPException(status_code=401, detail="Unauthorized access")
    # compare_digest takes time independent of where the inputs differ, so
    # response timing does not reveal how much of a guessed key is correct.
    # Compare bytes because compare_digest rejects non-ASCII str arguments.
    if not secrets.compare_digest(api_key.encode("utf-8"), API_KEY.encode("utf-8")):
        raise HTTPException(status_code=401, detail="Unauthorized access")
    return api_key
def format_prompt(message, history):
    """Render a conversation into a single instruction-tagged prompt string.

    Every ``(user_prompt, bot_response)`` pair in ``history`` becomes
    ``[INST] user [/INST] response</s> ``; the new ``message`` is appended
    as a final open ``[INST] ... [/INST]`` turn, and the whole string starts
    with ``<s>``. (These tags look like the Mistral/Llama-2 instruct
    convention; confirm against the model configured in MODEL_NAME.)
    """
    past_turns = "".join(
        f"[INST] {user_turn} [/INST]" f" {bot_turn}</s> "
        for user_turn, bot_turn in history
    )
    return "<s>" + past_turns + f"[INST] {message} [/INST]"
_logger = logging.getLogger(__name__)


def _read_number(body, key, default, cast):
    """Return ``body[key]`` converted with ``cast``, or ``default`` when absent.

    Raises:
        HTTPException: 422 when the supplied value cannot be converted, so a
            malformed request is reported to the caller instead of failing
            later inside the inference call.
    """
    value = body.get(key, default)
    try:
        return cast(value)
    except (TypeError, ValueError):
        raise HTTPException(
            status_code=422, detail=f"Invalid value for '{key}': {value!r}"
        ) from None


@app.post("/api/v1/generate_text", response_model=dict)
def generate_text(
    request: Request,
    body: dict = Body(...),
    api_key: str = Depends(get_api_key)
):
    """Generate a completion for the prompt in ``body`` with the configured model.

    Recognised body keys (all optional):
        prompt: user message; default "".
        sysPrompt: text placed before the prompt, joined with ", "; default "".
        temperature: sampling temperature; default 0.5, raised to at least
            0.01 because sampling is always enabled.
        top_p: nucleus-sampling threshold; default 0.95.
        max_new_tokens: generation length cap (integer); default 512.
        repetition_penalty: default 1.0.

    ``request`` is unused but kept in the signature; ``api_key`` exists so
    ``get_api_key`` runs and rejects unauthorized callers.

    Returns:
        ``{"generated_text": <model output>}``.

    Raises:
        HTTPException: 401 from ``get_api_key``; 422 when a numeric field
            cannot be parsed.
    """
    prompt = body.get("prompt", "")
    sys_prompt = body.get("sysPrompt", "")
    # A temperature of 0 is invalid while do_sample=True, so floor it.
    temperature = max(_read_number(body, "temperature", 0.5, float), 1e-2)
    top_p = _read_number(body, "top_p", 0.95, float)
    max_new_tokens = _read_number(body, "max_new_tokens", 512, int)
    repetition_penalty = _read_number(body, "repetition_penalty", 1.0, float)
    _logger.debug("temperature=%s", temperature)

    # Join with ", " only when both parts are present, avoiding a stray
    # leading ", " when no system prompt is supplied.
    user_message = ", ".join(str(part) for part in (sys_prompt, prompt) if part)

    # Single-turn endpoint: no earlier conversation is carried between calls.
    history = []
    generated_text = client.text_generation(
        format_prompt(user_message, history),
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        # Fresh seed per call so identical requests can produce different samples.
        seed=random.randint(0, 10**7),
        # With stream=False and details=False the client returns a plain string.
        stream=False,
        details=False,
        return_full_text=False,
    )
    return {"generated_text": generated_text}