# Uploaded by dhruv3d — commit 1349a51 ("Update main.py", verified).
import asyncio
import logging
from typing import Union

from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from huggingface_hub import InferenceClient
from pydantic import BaseModel
# Application object; every route in this module is registered on it via @app.<method>.
app: FastAPI = FastAPI()
# Request body model for the POST /generation endpoint.
class UserPrompt(BaseModel):
    """Input payload carrying the prompt text to send to the model."""

    # The user's prompt; the /generation endpoint passes it to generate().
    query: str
# Response body model for the POST /generation endpoint (its response_model).
class OutputData(BaseModel):
    """Output payload wrapping the model's generated text."""

    # Text returned by generate() for the submitted prompt.
    response: str
# Root route: returns a plain hint pointing visitors at the interactive Swagger UI.
# (Kept comment-only on purpose: a docstring here would change the OpenAPI description.)
@app.get("/")
async def start():
    hint = "Please go to /docs to try the API endpoints"
    return hint
# Shared Hugging Face inference client, bound to the Mistral-7B-Instruct v0.1 model id.
client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.1")
async def generate(prompt, temperature=0.7, max_new_tokens=1200, top_p=0.80, repetition_penalty=1.2):
    """Generate a completion for ``prompt`` using the module-level ``client``.

    ``client.text_generation`` is a synchronous, blocking HTTP call. It is run
    in the event loop's default thread-pool executor so that a long generation
    does not freeze every other request served by this process.

    Args:
        prompt: Text sent to the model unchanged.
        temperature: Sampling temperature; converted to float and raised to at
            least 0.01 (sampling requires a strictly positive temperature).
        max_new_tokens: Maximum number of tokens to generate.
        top_p: Nucleus-sampling threshold; converted to float.
        repetition_penalty: Repetition penalty forwarded to the server.

    Returns:
        The generated text, with special tokens (such as an end-of-sequence
        marker) left out.
    """
    # NOTE(review): Mistral-Instruct checkpoints are usually prompted with an
    # "[INST] ... [/INST]" template; the prompt is sent raw here — confirm intended.
    generate_kwargs = dict(
        temperature=max(float(temperature), 1e-2),
        max_new_tokens=max_new_tokens,
        top_p=float(top_p),
        repetition_penalty=repetition_penalty,
        do_sample=True,
    )

    def _consume_stream():
        # Runs in a worker thread: performs the blocking streaming request.
        stream = client.text_generation(
            prompt, **generate_kwargs, stream=True, details=True, return_full_text=False
        )
        # Skip special tokens so markers like "</s>" do not leak into the reply.
        return "".join(chunk.token.text for chunk in stream if not chunk.token.special)

    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, _consume_stream)
"""
Prompt anything to the LLM & get response
"""
@app.post("/generation", response_model=OutputData)
async def generate_AI_response(request: Request, input_data: UserPrompt, q: Union[str, None] = None):
try:
query = input_data.query
if query and query != "" and query != ".":
response = await generate(query)
# return StreamingResponse(generate(query), media_type='text/event-stream')
except ValueError as error:
print(error)
return OutputData(response=response)