URRO committed on
Commit aa5d766
1 Parent(s): 27848f1

Create main.py

Files changed (1):
  main.py +133 -0
main.py ADDED
@@ -0,0 +1,133 @@
+ import os  # Read environment variables (optional API key)
+ from typing import Optional  # Optional request fields
+ from fastapi import FastAPI, HTTPException  # Web framework and HTTP error responses
+ from pydantic import BaseModel  # Request-body validation
+ from huggingface_hub import InferenceClient  # Hugging Face Inference API client
+ import uvicorn  # ASGI server for running the app
+
+ app = FastAPI()  # Create the FastAPI application
+
+ # Primary model, with fallbacks tried in order if it fails
+ primary = "mistralai/Mixtral-8x7B-Instruct-v0.1"
+ fallbacks = ["mistralai/Mistral-7B-Instruct-v0.2", "mistralai/Mistral-7B-Instruct-v0.1"]
+
+ # Request body; either `input` or `system_prompt` must be provided
+ # (validated in the endpoint below)
+ class Item(BaseModel):
+     input: Optional[str] = None
+     system_prompt: Optional[str] = None
+     system_output: Optional[str] = None
+     history: Optional[list] = None
+     templates: Optional[list] = None
+     temperature: float = 0.0
+     max_new_tokens: int = 1048
+     top_p: float = 0.15
+     repetition_penalty: float = 1.0
+     key: Optional[str] = None
+
+ # Build the response JSON echoing the request settings and the model output
+ def generate_response_json(item, output, tokens, model_name):
+     # Use removeprefix/removesuffix (Python 3.9+) to drop the literal
+     # sentinel tokens; lstrip('<s>')/rstrip('</s>') would strip any of the
+     # characters '<', 's', '>', '/' rather than the exact substrings
+     cleaned = output.strip().removeprefix("<s>").removesuffix("</s>").strip()
+     return {
+         "settings": {
+             "input": item.input if item.input is not None else "",
+             "system prompt": item.system_prompt if item.system_prompt is not None else "",
+             "system output": item.system_output if item.system_output is not None else "",
+             "temperature": f"{item.temperature}" if item.temperature is not None else "",
+             "max new tokens": f"{item.max_new_tokens}" if item.max_new_tokens is not None else "",
+             "top p": f"{item.top_p}" if item.top_p is not None else "",
+             "repetition penalty": f"{item.repetition_penalty}" if item.repetition_penalty is not None else "",
+             "do sample": "True",
+             "seed": "42"
+         },
+         "response": {
+             "output": cleaned,
+             "unstripped": output,
+             "tokens": tokens,
+             "model": "primary" if model_name == primary else "fallback",
+             "name": model_name
+         }
+     }
+
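+ # Illustrative shape of the returned JSON (values abridged; keys follow
+ # generate_response_json above):
+ #   {"settings": {"input": "...", "system prompt": "...", ...},
+ #    "response": {"output": "...", "unstripped": "...", "tokens": 42,
+ #                 "model": "primary", "name": "mistralai/..."}}
+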
+ # Endpoint for generating text
+ @app.post("/")
+ async def generate_text(item: Item = None):
+     try:
+         if item is None:
+             raise HTTPException(status_code=400, detail="JSON body is required.")
+
+         # Enforce the access key first, when one is configured
+         if "KEY" in os.environ and item.key != os.environ["KEY"]:
+             raise HTTPException(status_code=401, detail="Valid key is required.")
+
+         if (item.input is None and item.system_prompt is None) or (item.input == "" and item.system_prompt == ""):
+             raise HTTPException(status_code=400, detail="Parameter `input` or `system_prompt` is required.")
+
+         # Assemble a Mixtral-style prompt: [INST] ... [/INST] turns wrapped in <s>...</s>
+         input_ = ""
+         if item.system_prompt is not None and item.system_output is not None:
+             input_ = f"<s>[INST] {item.system_prompt} [/INST] {item.system_output}</s>"
+         elif item.system_prompt is not None:
+             input_ = f"<s>[INST] {item.system_prompt} [/INST]</s>"
+         elif item.system_output is not None:
+             input_ = f"<s>{item.system_output}</s>"
+
+         # Each template is a flat list of alternating user/assistant messages,
+         # consumed in pairs (an unpaired trailing message is dropped)
+         if item.templates is not None:
+             for num, template in enumerate(item.templates, start=1):
+                 input_ += f"\n<s>[INST] Beginning of archived conversation {num} [/INST]</s>"
+                 for i in range(0, len(template) - 1, 2):
+                     input_ += f"\n<s>[INST] {template[i]} [/INST]"
+                     input_ += f"\n{template[i + 1]}</s>"
+                 input_ += f"\n<s>[INST] End of archived conversation {num} [/INST]</s>"
+
+         input_ += "\n<s>[INST] Beginning of active conversation [/INST]</s>"
+         # History items are (user, assistant) pairs; distinct loop names keep
+         # the prompt accumulator `input_` from being overwritten
+         if item.history is not None:
+             for user_msg, bot_msg in item.history:
+                 input_ += f"\n<s>[INST] {user_msg} [/INST]"
+                 input_ += f"\n{bot_msg}"
+         input_ += f"\n<s>[INST] {item.input} [/INST]"
+
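+         # Illustratively, with only a system prompt, one history pair, and a
+         # new input, `input_` now has this shape:
+         #   <s>[INST] {system_prompt} [/INST]</s>
+         #   <s>[INST] Beginning of active conversation [/INST]</s>
+         #   <s>[INST] {earlier user message} [/INST]
+         #   {earlier model reply}
+         #   <s>[INST] {item.input} [/INST]
+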
+         # Clamp temperature to a small positive value; the text-generation
+         # endpoint requires temperature > 0 when sampling
+         temperature = float(item.temperature)
+         if temperature < 1e-2:
+             temperature = 1e-2
+         top_p = float(item.top_p)
+
+         generate_kwargs = dict(
+             temperature=temperature,
+             max_new_tokens=item.max_new_tokens,
+             top_p=top_p,
+             repetition_penalty=item.repetition_penalty,
+             do_sample=True,
+             seed=42,
+         )
+
+         # Stream from the primary model, counting tokens as they arrive
+         tokens = 0
+         client = InferenceClient(primary)
+         stream = client.text_generation(input_, **generate_kwargs, stream=True, details=True, return_full_text=True)
+         output = ""
+         for response in stream:
+             tokens += 1
+             output += response.token.text
+         return generate_response_json(item, output, tokens, primary)
+
+     except HTTPException as http_error:
+         # Propagate deliberate client errors unchanged
+         raise http_error
+
+     except Exception:
+         # Primary model failed; try each fallback in order
+         tokens = 0
+         error = ""
+
+         for model in fallbacks:
+             try:
+                 client = InferenceClient(model)
+                 stream = client.text_generation(input_, **generate_kwargs, stream=True, details=True, return_full_text=True)
+                 output = ""
+                 for response in stream:
+                     tokens += 1
+                     output += response.token.text
+                 return generate_response_json(item, output, tokens, model)
+
+             except Exception as e:
+                 error = f"All models failed. {e}"
+                 continue
+
+         raise HTTPException(status_code=500, detail=error)
+
+ if __name__ == "__main__":
+     uvicorn.run(app, host="0.0.0.0", port=8000)
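
For reference, a minimal client sketch for this endpoint. Assumptions: the server is running locally on port 8000, no KEY environment variable is set, and the `requests` package is available; the payload fields and the response keys follow the `Item` model and `generate_response_json` above.

import requests

payload = {
    "input": "What is the capital of France?",
    "system_prompt": "You are a concise assistant.",
    "temperature": 0.7,
    "max_new_tokens": 256,
}
response = requests.post("http://localhost:8000/", json=payload)
response.raise_for_status()
data = response.json()
print(data["response"]["output"])  # stripped model output
print(data["response"]["tokens"])  # number of streamed tokens

The returned JSON echoes the request settings under "settings" and reports whether the primary or a fallback model produced the output.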