import gradio as gr
from fastapi import FastAPI
from pydantic import BaseModel
from contextlib import asynccontextmanager
from dotenv import load_dotenv
from llm import create_chain
from store import create_store, get_retreiver, save_retreiver

# Load environment variables from a local .env file, if present.
load_dotenv()

# Alternative startup hook (preferred in newer FastAPI releases): a lifespan
# context manager passed as FastAPI(lifespan=lifespan). Kept here for reference.
# @asynccontextmanager
# async def lifespan(app: FastAPI):
#     # Load the vector store and chain on startup
#     global chain, store
#     store = create_store()
#     chain = create_chain(store.as_retriever())
#     yield
#     # Clean up the models and release the resources

app = FastAPI()


@app.on_event("startup")
async def startup():
    # Build the vector store and the LLM chain once, at application startup.
    global chain, store
    store = create_store()
    chain = create_chain(store.as_retriever())


class Request(BaseModel):
    prompt: str


class Response(BaseModel):
    response: str


def greet(message, history):
    # Echo helper; mirrors the (message, history) signature Gradio expects.
    return message


@app.post("/predict")
async def invoke_api(message: str, history: str | None = None):
    # `history` is unused; it is kept to mirror the chat callback signature.
    return chain.invoke(message)


@app.post("/test", response_model=Response)
async def predict_api(request: Request):
    # Read the prompt from the request body (not from the class itself) and
    # wrap the result in the declared response model.
    response = greet(request.prompt, [])
    return Response(response=response)


async def invoke(message, history):
    # Chat callback for the Gradio UI.
    return chain.invoke(message)


@app.get("/save_retreiver")
async def save_retreiver_api():
    save_retreiver(store)
    return {"status": "retriever saved"}


demo = gr.ChatInterface(
    fn=invoke,
    title="LLM App",
    undo_btn="Delete Previous",
    clear_btn="Clear",
)

# Mount the Gradio UI at the root path.
app = gr.mount_gradio_app(app, demo, path="/")

# if __name__ == "__main__":
#     import uvicorn
#     uvicorn.run(
#         app="main:app",
#         host="localhost",  # or os.getenv("UVICORN_HOST")
#         port=8000,         # or int(os.getenv("UVICORN_PORT"))
#     )
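
# --- Usage sketch (illustrative, commented out so importing this module has no
# side effects). Assumes the app is served locally, e.g. `uvicorn main:app`,
# on port 8000, and that the `requests` package is installed.
#
# import requests
#
# # Call the /test endpoint; the body must match the `Request` model.
# resp = requests.post("http://localhost:8000/test", json={"prompt": "Hello"})
# print(resp.json())  # expected: {"response": "Hello"}
#
# # Call the /predict endpoint; `message` is passed as a query parameter.
# resp = requests.post("http://localhost:8000/predict", params={"message": "Hi"})
# print(resp.json())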