# HuggingFace Space status header ("Spaces: Sleeping") captured during page
# extraction — not part of the application code.
from typing import Union | |
from pydantic import BaseModel | |
from fastapi import FastAPI, Request | |
from fastapi.responses import StreamingResponse | |
from fastapi.middleware.cors import CORSMiddleware | |
from huggingface_hub import InferenceClient | |
# FastAPI application instance.
app = FastAPI()

# CORSMiddleware is imported above but was never registered — without this,
# browser clients served from other origins cannot call the API at all.
# Permissive defaults; tighten allow_origins before production use.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
# Class for input data for general prompt to model
class UserPrompt(BaseModel):
    """Request body schema: a free-form prompt for the LLM."""
    query: str  # the prompt text forwarded to the model
# Class for the output data
class OutputData(BaseModel):
    """Response body schema: the model's generated text."""
    response: str  # generated completion returned to the caller
# NOTE(review): the route decorator was evidently lost in extraction — without
# one this handler is never registered. Path "/" presumed from the message
# text; confirm against the original source.
@app.get("/")
async def start():
    """Root endpoint: point visitors at the interactive /docs UI."""
    return "Please go to /docs to try the API endpoints"
# Hugging Face Inference API client, pinned to the Mistral-7B instruct model.
# NOTE(review): no auth token is passed here — relies on ambient HF
# credentials or anonymous access; confirm rate limits are acceptable.
client = InferenceClient(
    "mistralai/Mistral-7B-Instruct-v0.1"
)
async def generate(prompt, temperature=0.2, max_new_tokens=600, top_p=0.70, repetition_penalty=1.2):
    """
    Generate a completion for `prompt` via the HF Inference API token stream.

    Args:
        prompt: Text prompt sent to the model.
        temperature: Sampling temperature; clamped to a minimum of 0.01
            because the endpoint rejects non-positive temperatures.
        max_new_tokens: Upper bound on the number of generated tokens.
        top_p: Nucleus-sampling probability mass.
        repetition_penalty: Penalty discouraging repeated tokens.

    Returns:
        The full generated text, assembled from the streamed tokens.
    """
    # The inference endpoint requires a strictly positive temperature.
    temperature = max(float(temperature), 1e-2)

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=float(top_p),
        repetition_penalty=repetition_penalty,
        do_sample=True,
    )
    # NOTE(review): text_generation is a blocking call inside an async def,
    # so it stalls the event loop while streaming — consider
    # loop.run_in_executor (or the async client) if concurrency matters.
    stream = client.text_generation(
        prompt, **generate_kwargs, stream=True, details=True, return_full_text=False
    )
    # join() avoids quadratic string concatenation over a long token stream.
    return "".join(chunk.token.text for chunk in stream)
""" | |
Prompt anything to the LLM & get response | |
""" | |
async def generate_AI_response(request: Request, input_data: UserPrompt, q: Union[str, None] = None): | |
try: | |
query = input_data.query | |
if query and query != "" and query != ".": | |
response = await generate(query) | |
# return StreamingResponse(generate(query), media_type='text/event-stream') | |
except ValueError as error: | |
print(error) | |
return OutputData(response=response) | |