import os # Import the os module for working with the operating system
import secrets  # Constant-time comparison for the optional access key
from typing import Optional  # Nullable field types for the request model

import uvicorn # Import uvicorn for running the FastAPI application
from fastapi import FastAPI, HTTPException # Import necessary modules from FastAPI
from fastapi.concurrency import run_in_threadpool  # Run blocking client calls off the event loop
from huggingface_hub import InferenceClient # Import InferenceClient from huggingface_hub
from pydantic import BaseModel # Import BaseModel from pydantic for data validation
# ASGI application that hosts the text-generation endpoint.
app = FastAPI()

# Model tried first for every request.
primary = "mistralai/Mixtral-8x7B-Instruct-v0.1"
# Models tried, in this order, only when the primary model fails.
fallbacks = [
    "mistralai/Mistral-7B-Instruct-v0.2",
    "mistralai/Mistral-7B-Instruct-v0.1",
]
# Define the data model for the request body
class Item(BaseModel):
    """JSON request body for the text-generation endpoint.

    Every field has a default. For the ``Optional`` fields, omitting the field
    and sending an explicit ``null`` both mean "not provided".
    """

    input: Optional[str] = None  # current user message, sent as the final [INST] turn
    system_prompt: Optional[str] = None  # leading instruction, wrapped in [INST] ... [/INST]
    system_output: Optional[str] = None  # text placed after the system prompt
    history: Optional[list] = None  # [input, output] pairs of the active conversation
    templates: Optional[list] = None  # archived conversations, each a flat list of alternating turns
    temperature: float = 0.0
    max_new_tokens: int = 1048
    top_p: float = 0.15
    repetition_penalty: float = 1.0
    key: Optional[str] = None  # compared against the KEY environment variable when it is set
# Function to generate the response JSON
def generate_response_json(item, output, tokens, model_name):
    """Build the endpoint's JSON response.

    Args:
        item: the request body; its generation settings are echoed back.
        output: the raw generated text.
        tokens: number of streamed tokens that produced *output*.
        model_name: name of the model that produced *output*.

    Returns:
        A dict with a ``"settings"`` section (request parameters rendered as
        strings, ``""`` for ones that are ``None``) and a ``"response"``
        section (stripped and raw output, token count, and which model ran).
    """

    def as_text(value):
        # Render a setting as a string, using "" for unset (None) values.
        return "" if value is None else f"{value}"

    return {
        "settings": {
            "input": as_text(item.input),
            "system prompt": as_text(item.system_prompt),
            "system output": as_text(item.system_output),
            "temperature": as_text(item.temperature),
            "max new tokens": as_text(item.max_new_tokens),
            "top p": as_text(item.top_p),
            "repetition penalty": as_text(item.repetition_penalty),
            # Hard-coded; must match the do_sample / seed values passed to the model.
            "do sample": "True",
            "seed": "42",
        },
        "response": {
            # str.strip() already removes surrounding newlines and whitespace.
            "output": output.strip(),
            "unstripped": output,
            "tokens": tokens,
            "model": "primary" if model_name == primary else "fallback",
            "name": model_name,
        },
    }
def _check_key(item):
    """Enforce the optional shared-secret key.

    When the ``KEY`` environment variable is set, ``item.key`` must equal it,
    otherwise a 401 is raised. When ``KEY`` is unset, every request is allowed.
    """
    expected = os.environ.get("KEY")
    if expected is None:
        return
    supplied = item.key
    # Constant-time comparison so response timing does not leak the key.
    if supplied is None or not secrets.compare_digest(
        supplied.encode("utf-8"), expected.encode("utf-8")
    ):
        raise HTTPException(status_code=401, detail="Valid key is required.")


def _build_prompt(item):
    """Assemble the ``[INST] ... [/INST]`` prompt string for *item*.

    Layout, one segment per line:
      1. the system prompt (wrapped in ``[INST]``) and/or the system output;
      2. each entry of ``templates`` as an archived conversation between
         "Beginning/End of archived conversation N" markers, followed by a
         "Beginning of active conversation" marker;
      3. each ``[input, output]`` pair from ``history``;
      4. the current ``input``, when one was given.

    Raises:
        HTTPException: 400 if a template is not a list with an even number of
            turns, or a history entry is not an ``[input, output]`` pair.
    """
    if item.system_prompt is not None and item.system_output is not None:
        head = f"[INST] {item.system_prompt} [/INST] {item.system_output}"
    elif item.system_prompt is not None:
        head = f"[INST] {item.system_prompt} [/INST]"
    elif item.system_output is not None:
        head = f"{item.system_output}"
    else:
        head = ""

    lines = []
    if item.templates is not None:
        for num, template in enumerate(item.templates, start=1):
            # Turns alternate user/model, so a template must pair up exactly.
            if not isinstance(template, list) or len(template) % 2:
                raise HTTPException(
                    status_code=400,
                    detail=f"Template {num} must be a list with an even number of turns.",
                )
            lines.append(f"[INST] Beginning of archived conversation {num} [/INST]")
            for user_turn, model_turn in zip(template[::2], template[1::2]):
                lines.append(f"[INST] {user_turn} [/INST]")
                lines.append(f"{model_turn}")
            lines.append(f"[INST] End of archived conversation {num} [/INST]")
        lines.append("[INST] Beginning of active conversation [/INST]")

    if item.history is not None:
        for turn in item.history:
            # Distinct names: reusing the prompt variable here would discard it.
            try:
                user_turn, model_turn = turn
            except (TypeError, ValueError):
                raise HTTPException(
                    status_code=400,
                    detail="Each `history` entry must be an [input, output] pair.",
                ) from None
            lines.append(f"[INST] {user_turn} [/INST]")
            lines.append(f"{model_turn}")

    # Only add the user turn when there is one (avoids "[INST] None [/INST]").
    if item.input:
        lines.append(f"[INST] {item.input} [/INST]")

    return head + "".join(f"\n{line}" for line in lines)


def _stream_completion(model, prompt, generate_kwargs):
    """Stream a completion of *prompt* from *model*.

    Returns:
        ``(text, token_count)`` for this single attempt.
    """
    client = InferenceClient(model)
    stream = client.text_generation(
        prompt, **generate_kwargs, stream=True, details=True, return_full_text=True
    )
    pieces = [response.token.text for response in stream]
    return "".join(pieces), len(pieces)


# Endpoint for generating text
@app.post("/")
async def generate_text(item: Item = None):
    """Generate a completion, trying the primary model and then each fallback.

    Raises:
        HTTPException: 400 for a missing body, missing input, or malformed
            templates/history; 401 when ``KEY`` is set and ``item.key`` does
            not match; 500 when every model fails.
    """
    if item is None:
        raise HTTPException(status_code=400, detail="JSON body is required.")
    # Authenticate before doing any other work.
    _check_key(item)
    # `not` treats None and "" alike, so mixed None/"" combinations are rejected too.
    if not item.input and not item.system_prompt:
        raise HTTPException(status_code=400, detail="Parameter `input` or `system prompt` is required.")

    prompt = _build_prompt(item)
    generate_kwargs = dict(
        # Clamped away from zero; presumably the backend rejects temperature <= 0
        # when sampling is enabled.
        temperature=max(float(item.temperature), 1e-2),
        max_new_tokens=item.max_new_tokens,
        top_p=float(item.top_p),
        repetition_penalty=item.repetition_penalty,
        do_sample=True,
        seed=42,
    )

    error = "All models failed."
    last_exc = None
    for model in [primary, *fallbacks]:
        try:
            # The client call blocks; run it in a worker thread so the event
            # loop keeps serving other requests meanwhile.
            output, tokens = await run_in_threadpool(
                _stream_completion, model, prompt, generate_kwargs
            )
        except Exception as exc:  # any backend failure moves on to the next model
            last_exc = exc
            error = f"All models failed. {exc}" if str(exc) else "All models failed."
            continue
        return generate_response_json(item, output, tokens, model)
    raise HTTPException(status_code=500, detail=error) from last_exc
if __name__ == "__main__":
    # HOST / PORT environment variables override the defaults (0.0.0.0:8000),
    # mirroring how the KEY setting is read from the environment.
    uvicorn.run(
        app,
        host=os.environ.get("HOST", "0.0.0.0"),
        port=int(os.environ.get("PORT", "8000")),
    )