import asyncio
import time
import json
from typing import List, Literal

from fastapi import FastAPI
from fastapi.responses import StreamingResponse, JSONResponse
from pydantic import BaseModel
from huggingface_hub import InferenceClient
app = FastAPI()

# Client for the Hugging Face Inference API, pointed at the hosted model.
client = InferenceClient(
    "mistralai/Mistral-7B-Instruct-v0.2"
)
class Message(BaseModel):
role: Literal["user", "assistant"]
content: str
class Payload(BaseModel):
stream: bool = False
model: Literal["mistral-7b-instruct-v0.2"] = "mistral-7b-instruct-v0.2"
messages: List[Message]
temperature: float = 0.9
frequency_penalty: float = 1.2
top_p: float = 0.9
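
# A request body therefore looks like:
# {"model": "mistral-7b-instruct-v0.2", "stream": true,
#  "messages": [{"role": "user", "content": "Hi"}],
#  "temperature": 0.9, "frequency_penalty": 1.2, "top_p": 0.9}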
# Wrap a blocking iterator so it can be consumed asynchronously: each
# __next__ call runs in a worker thread, keeping the event loop free.
async def stream(it):
    while True:
        try:
            value = await asyncio.to_thread(it.__next__)
            yield value
        except StopIteration:
            break
# Build a Mistral-instruct prompt: user turns go inside [INST] ... [/INST],
# assistant turns follow as plain text terminated by </s>. Callers pass the
# plain dicts from payload.model_dump()["messages"], hence the dict access.
def format_prompt(messages: List[dict]):
    prompt = "<s>"
    for message in messages:
        if message["role"] == "user":
            prompt += f"[INST] {message['content']} [/INST]"
        else:
            prompt += f" {message['content']}</s> "
    return prompt
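
# For example, format_prompt([{"role": "user", "content": "Hi"},
#                             {"role": "assistant", "content": "Hello!"}])
# returns "<s>[INST] Hi [/INST] Hello!</s> ".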
# Build one OpenAI-style "chat.completion.chunk" object for the streaming endpoint.
def make_chunk_obj(i, delta, fr):
return {
"id": str(time.time_ns()),
"object": "chat.completion.chunk",
"created": round(time.time()),
"model": "mistral-7b-instruct-v0.2",
"system_fingerprint": "wtf",
"choices": [
{
"index": i,
"delta": {
"content": delta
},
"finish_reason": fr
}
]
}
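
# Serialized as a server-sent event, a chunk looks roughly like:
# data: {"id": "...", "object": "chat.completion.chunk", "created": 1700000000,
#        "model": "mistral-7b-instruct-v0.2", "system_fingerprint": "wtf",
#        "choices": [{"index": 0, "delta": {"content": "Hello"}, "finish_reason": null}]}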
def generate(
    messages,
    temperature=0.9,
    max_new_tokens=256,
    top_p=0.95,
    repetition_penalty=1.0,
):
    temperature = float(temperature)
    # Avoid zero/near-zero temperature, which the generation backend may reject.
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=None,
    )

    formatted_prompt = format_prompt(messages)
    # "token_stream" avoids shadowing the stream() helper defined above.
    token_stream = client.text_generation(
        formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False
    )
    for response in token_stream:
        t = response.token.text
        # Suppress the raw end-of-sequence token in the output.
        yield t if t != "</s>" else ""
# Non-streaming variant: collect the full generation into a single string.
def generate_norm(*args) -> str:
    return "".join(generate(*args))
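
# e.g. generate_norm([{"role": "user", "content": "Hi"}], 0.9, 256, 0.95, 1.0)
# returns the assistant's reply as a single string.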
@app.get('/')
async def index():
return JSONResponse({ "message": "hello", "url": "https://aweirddev-mistral-7b-instruct-v0-2-leicht.hf.space" })
@app.post('/chat/completions')
async def c_cmp(payload: Payload):
    if not payload.stream:
        # Non-streaming: run the whole generation, then return one
        # OpenAI-style "chat.completion" object.
        return JSONResponse(
            {
                "id": str(time.time_ns()),
                "object": "chat.completion",
                "created": round(time.time()),
                "model": payload.model,
                "system_fingerprint": "wtf",
                "choices": [
                    {
                        "index": 0,
                        "message": {
                            "role": "assistant",
                            "content": generate_norm(
                                payload.model_dump()["messages"],
                                payload.temperature,
                                4096,  # max_new_tokens
                                payload.top_p,
                                payload.frequency_penalty,  # passed through as repetition_penalty
                            ),
                        },
                        "finish_reason": "stop",
                    }
                ],
            }
        )
    def streamer():
        result = generate(
            payload.model_dump()["messages"],
            payload.temperature,
            4096,  # max_new_tokens
            payload.top_p,
            payload.frequency_penalty,  # passed through as repetition_penalty
        )
        # Stream OpenAI-style "chat.completion.chunk" events over SSE.
        for item in result:
            yield "data: " + json.dumps(make_chunk_obj(0, item, None)) + "\n\n"
        # Final chunk carries the finish reason, then the [DONE] sentinel.
        yield "data: " + json.dumps(make_chunk_obj(0, "", "stop")) + "\n\n"
        yield "data: [DONE]\n\n"

    return StreamingResponse(streamer(), media_type="text/event-stream")
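
# Example request (a sketch; assumes the app is served locally on port 8000,
# e.g. via `uvicorn app:app`):
#
#   curl http://localhost:8000/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"messages": [{"role": "user", "content": "Hello!"}], "stream": false}'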