# Spaces: Runtime error
| from fastapi import FastAPI, Request, Body, HTTPException, Depends | |
| from fastapi.security import APIKeyHeader | |
| from typing import Optional | |
| from huggingface_hub import InferenceClient | |
| import random | |
| import os | |
# Runtime configuration, read from environment variables.
# Each value is None when the corresponding variable is not set.
API_URL = os.getenv("API_URL")
API_KEY = os.getenv("API_KEY")
MODEL_NAME = os.getenv("MODEL_NAME")

# Shared application objects, created once at import time.
app = FastAPI()
client = InferenceClient(MODEL_NAME)

# Reads the key from the "api_key" request header. auto_error=False is passed
# so that get_api_key can handle a missing key itself (it checks for None).
security = APIKeyHeader(name="api_key", auto_error=False)
def get_api_key(api_key: Optional[str] = Depends(security)):
    """Validate the API key supplied by the ``security`` dependency.

    Args:
        api_key: Key extracted from the request by ``security``; ``None``
            when the client sent no key.

    Returns:
        The validated key, unchanged.

    Raises:
        HTTPException: 401 when the key is missing, when no ``API_KEY`` is
            configured on the server, or when the two do not match.
    """
    # Function-scope import keeps this block self-contained.
    import secrets

    # With no configured key nothing can match, so reject outright; this also
    # keeps None away from the byte comparison below.
    if api_key is None or API_KEY is None:
        raise HTTPException(status_code=401, detail="Unauthorized access")

    # Constant-time comparison avoids leaking how much of the key matched
    # through response timing. Encoding to bytes lets compare_digest accept
    # non-ASCII input instead of raising TypeError.
    keys_match = secrets.compare_digest(
        api_key.encode("utf-8"), API_KEY.encode("utf-8")
    )
    if not keys_match:
        raise HTTPException(status_code=401, detail="Unauthorized access")
    return api_key
def format_prompt(message, history):
    """Render a conversation as a single ``[INST]``-tagged prompt string.

    Args:
        message: The new user message; it becomes the final ``[INST]`` block.
        history: Iterable of ``(user_prompt, bot_response)`` pairs that
            precede ``message``.

    Returns:
        ``"<s>"`` followed by one ``[INST] user [/INST] reply</s> `` segment
        per history pair, then ``[INST] message [/INST]``.
    """
    past_turns = "".join(
        f"[INST] {question} [/INST] {answer}</s> "
        for question, answer in history
    )
    return f"<s>{past_turns}[INST] {message} [/INST]"
def generate_text(
    request: Request,
    body: dict = Body(...),
    api_key: str = Depends(get_api_key),
):
    """Generate text for the prompt in ``body`` using the shared ``client``.

    Recognised ``body`` keys and the defaults used when a key is absent:
    ``prompt`` (""), ``sysPrompt`` (""), ``temperature`` (0.5),
    ``top_p`` (0.95), ``max_new_tokens`` (512), ``repetition_penalty`` (1.0).

    Args:
        request: The incoming request; not read by this function.
        body: Parsed request body holding the keys listed above.
        api_key: Key validated by ``get_api_key``; it is only required, not
            otherwise used here.

    Returns:
        ``{"generated_text": <value returned by client.text_generation>}``.

    Raises:
        HTTPException: Re-raised unchanged if raised inside the handler.
    """
    # NOTE(review): no route decorator is visible on this function in the
    # reviewed source - confirm it is registered on ``app``.
    try:
        prompt = body.get("prompt", "")
        system_prompt = body.get("sysPrompt", "")
        temperature = body.get("temperature", 0.5)
        top_p = body.get("top_p", 0.95)
        max_new_tokens = body.get("max_new_tokens", 512)
        repetition_penalty = body.get("repetition_penalty", 1.0)

        # No conversation history is carried between calls; every request is
        # formatted as a single-turn prompt.
        history = []

        # Join only the parts that are present, so an empty or null system
        # prompt no longer produces a stray ", " (or "None, ") prefix.
        message = ", ".join(
            str(part)
            for part in (system_prompt, prompt)
            if part is not None and part != ""
        )
        formatted_prompt = format_prompt(message, history)

        generated_text = client.text_generation(
            formatted_prompt,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            do_sample=True,
            # A fresh seed per request, so repeated prompts do not reuse one.
            seed=random.randint(0, 10**7),
            stream=False,
            details=False,
            return_full_text=False,
        )
        return {"generated_text": generated_text}
    except HTTPException:
        # Propagate HTTP errors unchanged; a bare raise keeps the traceback.
        raise