dhruv3d's picture
Update main.py
ba0b855 verified
raw
history blame
No virus
1.82 kB
from typing import Union
from pydantic import BaseModel
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from huggingface_hub import InferenceClient
# FastAPI application instance; route decorators below register onto it.
# NOTE(review): CORSMiddleware is imported but never added — confirm whether
# cross-origin access was intended.
app = FastAPI()
# Class for input data for general prompt to model
class UserPrompt(BaseModel):
    """Request body for POST /generate: the user's free-text prompt."""
    # The raw prompt string forwarded to the model.
    query: str
# Class for the output data
class OutputData(BaseModel):
    """Response body for POST /generate: the model's generated text."""
    # Full generated completion (may be empty if the prompt was rejected).
    response: str
@app.get("/")
async def start():
    """Root endpoint: directs callers to the interactive API docs."""
    welcome_message = "Please go to /docs to try the API endpoints"
    return welcome_message
# Hugging Face Inference API client bound to the Mistral-7B instruct model.
# NOTE(review): no auth token is passed — this relies on ambient HF
# credentials or anonymous access; confirm rate limits are acceptable.
client = InferenceClient(
    "mistralai/Mistral-7B-Instruct-v0.1"
)
async def generate(prompt, temperature=0.2, max_new_tokens=600, top_p=0.70, repetition_penalty=1.2):
    """Generate a completion for *prompt* via the hosted Mistral model.

    Parameters mirror the Hugging Face text-generation API. *temperature*
    is clamped to a small positive floor (1e-2) because the backend
    rejects non-positive values.

    Returns:
        The full generated text as a single string.

    NOTE(review): despite being ``async``, ``client.text_generation`` is a
    synchronous streaming call and blocks the event loop while tokens
    arrive — consider offloading to a thread pool if latency matters.
    """
    # Backend requires temperature > 0; clamp instead of erroring.
    temperature = max(float(temperature), 1e-2)
    top_p = float(top_p)
    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
    )
    stream = client.text_generation(
        prompt, **generate_kwargs, stream=True, details=True, return_full_text=False
    )
    # Join streamed token texts in one pass (avoids quadratic `+=` growth).
    return "".join(chunk.token.text for chunk in stream)
@app.post("/generate", response_model=OutputData)
async def generate_AI_response(request: Request, input_data: UserPrompt, q: Union[str, None] = None):
    """Prompt anything to the LLM & get response.

    Args:
        request: Incoming request object (currently unused).
        input_data: Body containing the user's prompt in ``query``.
        q: Optional query parameter (currently unused).

    Returns:
        OutputData with the generated text, or an empty ``response`` when
        the prompt is empty/placeholder or generation raises ValueError.
    """
    # Bug fix: the original left `response` unbound when the query was
    # empty ("" or ".") or when `generate` raised, causing an
    # UnboundLocalError at the return statement. Default it to "".
    response = ""
    query = input_data.query
    # Skip empty or placeholder ("." ) prompts.
    if query and query != ".":
        try:
            response = await generate(query)
        except ValueError as error:
            # Best-effort: report and fall through to an empty response.
            print(error)
    return OutputData(response=response)