import os  # Read the optional API key from the environment
from typing import Optional  # Optional fields on the request model
from fastapi import FastAPI, HTTPException  # Web framework and error responses
from pydantic import BaseModel  # Request-body validation
from huggingface_hub import InferenceClient  # Client for the Hugging Face Inference API
import uvicorn  # ASGI server used to run the app

app = FastAPI()  # Create the FastAPI application
# Define the primary and fallback models
primary = "mistralai/Mixtral-8x7B-Instruct-v0.1"
fallbacks = ["mistralai/Mistral-7B-Instruct-v0.2", "mistralai/Mistral-7B-Instruct-v0.1"]
# Define the data model for the request body
class Item(BaseModel):
    input: Optional[str] = None
    system_prompt: Optional[str] = None
    system_output: Optional[str] = None
    history: Optional[list] = None
    templates: Optional[list] = None
    temperature: float = 0.0
    max_new_tokens: int = 1048
    top_p: float = 0.15
    repetition_penalty: float = 1.0
    key: Optional[str] = None
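
# A sketch of a request body this model accepts; the field values below are
# illustrative, not taken from the source:
#
#   {
#     "input": "What is the capital of France?",
#     "system_prompt": "You are a concise assistant.",
#     "temperature": 0.7,
#     "key": "my-secret-key"
#   }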
# Function to generate the response JSON
def generate_response_json(item, output, tokens, model_name):
    # removeprefix/removesuffix drop the <s>/</s> markers as whole substrings
    # (lstrip/rstrip would strip any of the characters '<', 's', '>')
    cleaned = output.strip().removeprefix("<s>").removesuffix("</s>").strip()
    return {
        "settings": {
            "input": item.input if item.input is not None else "",
            "system prompt": item.system_prompt if item.system_prompt is not None else "",
            "system output": item.system_output if item.system_output is not None else "",
            "temperature": f"{item.temperature}" if item.temperature is not None else "",
            "max new tokens": f"{item.max_new_tokens}" if item.max_new_tokens is not None else "",
            "top p": f"{item.top_p}" if item.top_p is not None else "",
            "repetition penalty": f"{item.repetition_penalty}" if item.repetition_penalty is not None else "",
            "do sample": "True",
            "seed": "42",
        },
        "response": {
            "output": cleaned,
            "unstripped": output,
            "tokens": tokens,
            "model": "primary" if model_name == primary else "fallback",
            "name": model_name,
        },
    }
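
# For illustration, a response might look like this (values are hypothetical):
#
#   {
#     "settings": {"input": "Hi", "system prompt": "", ..., "seed": "42"},
#     "response": {"output": "Hello!", "unstripped": "\nHello!</s>", "tokens": 3,
#                  "model": "primary", "name": "mistralai/Mixtral-8x7B-Instruct-v0.1"}
#   }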
# Endpoint for generating text
@app.post("/")
async def generate_text(item: Item = None):
    try:
        if item is None:
            raise HTTPException(status_code=400, detail="JSON body is required.")
        if not item.input and not item.system_prompt:
            raise HTTPException(status_code=400, detail="Parameter `input` or `system_prompt` is required.")
        # If a KEY is set in the environment, require a matching key before doing any work
        if "KEY" in os.environ and item.key != os.environ["KEY"]:
            raise HTTPException(status_code=401, detail="Valid key is required.")
input_ = ""
if item.system_prompt != None and item.system_output != None:
input_ = f"<s>[INST] {item.system_prompt} [/INST] {item.system_output}</s>"
elif item.system_prompt != None:
input_ = f"<s>[INST] {item.system_prompt} [/INST]</s>"
elif item.system_output != None:
input_ = f"<s>{item.system_output}</s>"
if item.templates != None:
for num, template in enumerate(item.templates, start=1):
input_ += f"\n<s>[INST] Beginning of archived conversation {num} [/INST]</s>"
for i in range(0, len(template), 2):
input_ += f"\n<s>[INST] {template[i]} [/INST]"
input_ += f"\n{template[i + 1]}</s>"
input_ += f"\n<s>[INST] End of archived conversation {num} [/INST]</s>"
input_ += f"\n<s>[INST] Beginning of active conversation [/INST]</s>"
if item.history != None:
for input_, output_ in item.history:
input_ += f"\n<s>[INST] {input_} [/INST]"
input_ += f"\n{output_}"
input_ += f"\n<s>[INST] {item.input} [/INST]"
        # Clamp temperature to the smallest value the endpoint accepts
        temperature = float(item.temperature)
        if temperature < 1e-2:
            temperature = 1e-2
        top_p = float(item.top_p)
        generate_kwargs = dict(
            temperature=temperature,
            max_new_tokens=item.max_new_tokens,
            top_p=top_p,
            repetition_penalty=item.repetition_penalty,
            do_sample=True,
            seed=42,
        )
        # Stream tokens from the primary model, counting them as they arrive
        tokens = 0
        client = InferenceClient(primary)
        stream = client.text_generation(
            input_, **generate_kwargs, stream=True, details=True, return_full_text=True
        )
        output = ""
        for response in stream:
            tokens += 1
            output += response.token.text
        return generate_response_json(item, output, tokens, primary)
    except HTTPException as http_error:
        # Propagate validation errors unchanged
        raise http_error
    except Exception as primary_error:
        # The primary model failed: retry the same request against each fallback
        tokens = 0
        error = f"All models failed. {primary_error}"
        for model in fallbacks:
            try:
                client = InferenceClient(model)
                stream = client.text_generation(
                    input_, **generate_kwargs, stream=True, details=True, return_full_text=True
                )
                output = ""
                for response in stream:
                    tokens += 1
                    output += response.token.text
                return generate_response_json(item, output, tokens, model)
            except Exception as fallback_error:
                error = f"All models failed. {fallback_error}"
                continue
        raise HTTPException(status_code=500, detail=error)
if "KEY" in os.environ:
if item.key != os.environ["KEY"]:
raise HTTPException(status_code=401, detail="Valid key is required.")

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
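
# A minimal sketch of calling the endpoint once the server is running,
# assuming the host/port configured above:
#
#   curl -X POST http://localhost:8000/ \
#     -H "Content-Type: application/json" \
#     -d '{"input": "Hello"}'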