import os  # Used to read the optional `KEY` secret from the environment
from typing import Optional  # Nullable fields in the request model
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel  # Request body validation
from huggingface_hub import InferenceClient  # Client for the Hugging Face Inference API
import uvicorn  # ASGI server used to run the app directly

app = FastAPI()  # Create a FastAPI instance

# Define the primary and fallback models
primary = "mistralai/Mixtral-8x7B-Instruct-v0.1"
fallbacks = ["mistralai/Mistral-7B-Instruct-v0.2", "mistralai/Mistral-7B-Instruct-v0.1"]

# Define the data model for the request body
class Item(BaseModel):
    # Optional[...] keeps these fields nullable under both Pydantic v1 and v2
    input: Optional[str] = None
    system_prompt: Optional[str] = None
    system_output: Optional[str] = None
    history: Optional[list] = None
    templates: Optional[list] = None
    temperature: float = 0.0
    max_new_tokens: int = 1048
    top_p: float = 0.15
    repetition_penalty: float = 1.0
    key: Optional[str] = None
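
# Illustrative request body (hypothetical values; only `input` or `system_prompt`
# is required, and `key` only matters when the KEY environment variable is set):
# {
#     "input": "What is the capital of France?",
#     "system_prompt": "You are a concise assistant.",
#     "temperature": 0.7,
#     "max_new_tokens": 256
# }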

# Function to generate the response JSON
def generate_response_json(item, output, tokens, model_name):
    # removeprefix/removesuffix strip the literal BOS/EOS markers;
    # lstrip('<s>')/rstrip('</s>') would strip any of those characters, not the token
    cleaned = output.strip().removeprefix("<s>").removesuffix("</s>").strip()
    return {
        "settings": {
            "input": item.input if item.input is not None else "",
            "system prompt": item.system_prompt if item.system_prompt is not None else "",
            "system output": item.system_output if item.system_output is not None else "",
            "temperature": f"{item.temperature}" if item.temperature is not None else "",
            "max new tokens": f"{item.max_new_tokens}" if item.max_new_tokens is not None else "",
            "top p": f"{item.top_p}" if item.top_p is not None else "",
            "repetition penalty": f"{item.repetition_penalty}" if item.repetition_penalty is not None else "",
            "do sample": "True",
            "seed": "42"
        },
        "response": {
            "output": cleaned,
            "unstripped": output,
            "tokens": tokens,
            "model": "primary" if model_name == primary else "fallback",
            "name": model_name
        }
    }
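
# A sketch of the JSON shape the function above produces (illustrative values):
# {
#     "settings": {"input": "...", "temperature": "0.7", ..., "seed": "42"},
#     "response": {"output": "...", "unstripped": "...", "tokens": 42,
#                  "model": "primary", "name": "mistralai/Mixtral-8x7B-Instruct-v0.1"}
# }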

# Endpoint for generating text
@app.post("/")
async def generate_text(item: Item = None):
    try:
        if item is None:
            raise HTTPException(status_code=400, detail="JSON body is required.")

        # Enforce the optional API key before doing any work
        if "KEY" in os.environ and item.key != os.environ["KEY"]:
            raise HTTPException(status_code=401, detail="Valid key is required.")

        # At least one of `input` or `system_prompt` must be a non-empty string
        if not item.input and not item.system_prompt:
            raise HTTPException(status_code=400, detail="Parameter `input` or `system_prompt` is required.")
        
        # Assemble the prompt in the Mistral instruction format
        input_ = ""
        if item.system_prompt is not None and item.system_output is not None:
            input_ = f"<s>[INST] {item.system_prompt} [/INST] {item.system_output}</s>"
        elif item.system_prompt is not None:
            input_ = f"<s>[INST] {item.system_prompt} [/INST]</s>"
        elif item.system_output is not None:
            input_ = f"<s>{item.system_output}</s>"

        # Archived conversations are flat lists of alternating user/assistant turns
        if item.templates is not None:
            for num, template in enumerate(item.templates, start=1):
                input_ += f"\n<s>[INST] Beginning of archived conversation {num} [/INST]</s>"
                for i in range(0, len(template) - 1, 2):  # step in pairs; ignore a trailing unpaired turn
                    input_ += f"\n<s>[INST] {template[i]} [/INST]"
                    input_ += f"\n{template[i + 1]}</s>"
                input_ += f"\n<s>[INST] End of archived conversation {num} [/INST]</s>"

        input_ += "\n<s>[INST] Beginning of active conversation [/INST]</s>"
        if item.history is not None:
            # Distinct loop variables: reusing `input_` here would overwrite the prompt built so far
            for user_turn, bot_turn in item.history:
                input_ += f"\n<s>[INST] {user_turn} [/INST]"
                input_ += f"\n{bot_turn}</s>"
        input_ += f"\n<s>[INST] {item.input} [/INST]"
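
        # For a single-turn request with a system prompt, the assembled prompt
        # looks roughly like this (illustrative):
        #   <s>[INST] You are a concise assistant. [/INST]</s>
        #   <s>[INST] Beginning of active conversation [/INST]</s>
        #   <s>[INST] What is the capital of France? [/INST]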

        # Text-generation backends require a strictly positive temperature, so clamp to a small floor
        temperature = float(item.temperature)
        if temperature < 1e-2:
            temperature = 1e-2
        top_p = float(item.top_p)

        generate_kwargs = dict(
            temperature=temperature,
            max_new_tokens=item.max_new_tokens,
            top_p=top_p,
            repetition_penalty=item.repetition_penalty,
            do_sample=True,
            seed=42,
        )
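
        # For reference, a non-streaming equivalent would be a sketch like:
        #   text = InferenceClient(primary).text_generation(input_, **generate_kwargs)
        # Streaming is used below so the token count can be reported in the response.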

        # Try the primary model first, streaming tokens so the count can be tracked
        tokens = 0
        client = InferenceClient(primary)
        stream = client.text_generation(input_, **generate_kwargs, stream=True, details=True, return_full_text=True)
        output = ""
        for response in stream:
            tokens += 1
            output += response.token.text
        return generate_response_json(item, output, tokens, primary)

    except HTTPException as http_error:
        # Validation errors pass through unchanged
        raise http_error

    except Exception:
        # The primary model failed; try each fallback in order
        error = "All models failed."

        for model in fallbacks:
            try:
                tokens = 0  # Reset per model so partial streams from failed attempts are not counted
                client = InferenceClient(model)
                stream = client.text_generation(input_, **generate_kwargs, stream=True, details=True, return_full_text=True)
                output = ""
                for response in stream:
                    tokens += 1
                    output += response.token.text
                return generate_response_json(item, output, tokens, model)

            except Exception as e:
                error = f"All models failed. Last error: {e}"
                continue

        raise HTTPException(status_code=500, detail=error)

    if "KEY" in os.environ:
        if item.key != os.environ["KEY"]:
            raise HTTPException(status_code=401, detail="Valid key is required.")

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
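
# Example query against a locally running instance (hypothetical payload):
#   curl -X POST http://localhost:8000/ \
#        -H "Content-Type: application/json" \
#        -d '{"input": "Hello!", "temperature": 0.7}'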