from typing import Union
from pydantic import BaseModel
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from huggingface_hub import InferenceClient
app = FastAPI()
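
# Note: CORSMiddleware is imported above but never registered with the app.
# A typical registration (a sketch; the permissive settings are placeholders,
# not part of the original file) would look like:
#
#   app.add_middleware(
#       CORSMiddleware,
#       allow_origins=["*"],
#       allow_methods=["*"],
#       allow_headers=["*"],
#   )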

# Input schema for a general prompt to the model
class UserPrompt(BaseModel):
    query: str

# Output schema for the model response
class OutputData(BaseModel):
    response: str

@app.get("/")
async def start():
    return "Please go to /docs to try the API endpoints"

client = InferenceClient(
    "mistralai/Mistral-7B-Instruct-v0.1"
)

async def generate(prompt, temperature=0.7, max_new_tokens=1200, top_p=0.80, repetition_penalty=1.2):
    # Clamp temperature to a small positive value; sampling requires a
    # strictly positive temperature
    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        # seed=42,
    )

    # Stream tokens from the Inference API and accumulate them into one string
    stream = client.text_generation(prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
    output = ""
    for response in stream:
        output += response.token.text
        # yield output  # to stream instead, uncomment this and drop the return below
    return output
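
# A streaming variant, sketched from the commented-out `yield` above and the
# StreamingResponse import (the helper name `generate_stream` is ours, not the
# original author's): an async generator that yields each token's text as it
# arrives, suitable for StreamingResponse(..., media_type='text/event-stream').
async def generate_stream(prompt, temperature=0.7, max_new_tokens=1200, top_p=0.80, repetition_penalty=1.2):
    stream = client.text_generation(
        prompt,
        temperature=max(float(temperature), 1e-2),
        max_new_tokens=max_new_tokens,
        top_p=float(top_p),
        repetition_penalty=repetition_penalty,
        do_sample=True,
        stream=True,
        details=True,
        return_full_text=False,
    )
    for response in stream:
        # Yield each token's text as a separate response chunk
        yield response.token.text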
"""
Prompt anything to the LLM & get response
"""
@app.post("/generation", response_model=OutputData)
async def generate_AI_response(request: Request, input_data: UserPrompt, q: Union[str, None] = None):
try:
query = input_data.query
if query and query != "" and query != ".":
response = await generate(query)
# return StreamingResponse(generate(query), media_type='text/event-stream')
except ValueError as error:
print(error)
return OutputData(response=response)
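
# Example request (assumes the app is served locally, e.g. via `uvicorn app:app`;
# host and port are illustrative):
#
#   curl -X POST http://localhost:8000/generation \
#        -H "Content-Type: application/json" \
#        -d '{"query": "Explain what FastAPI is in one sentence."}'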