Spaces: Running on T4
from pydantic import BaseModel | |
from fastapi.middleware.cors import CORSMiddleware | |
from langchain.chains import LLMChain | |
from langchain.prompts import PromptTemplate | |
from langchain_google_genai import ( | |
ChatGoogleGenerativeAI, | |
HarmBlockThreshold, | |
HarmCategory, | |
) | |
from TextGen import app | |
class Generate(BaseModel):
    """Pydantic payload wrapping a single piece of LLM-generated text."""

    # The model's response text, as returned by the LangChain chain.
    text: str
def generate_text(prompt: str):
    """Send *prompt* to Gemini via a LangChain ``LLMChain`` and return the reply.

    Args:
        prompt: Raw user text. It is supplied as the *value* of the ``Prompt``
            template variable, never as the template itself. Literal braces
            (JSON, code snippets, ``{Prompt}``) therefore reach the model
            verbatim instead of being interpreted as placeholders.

    Returns:
        ``Generate(text=...)`` holding the LLM response, or the dict
        ``{"detail": "Please provide a prompt."}`` when *prompt* is empty,
        whitespace-only, or ``None``.
    """
    # Guard clause: reject empty/blank input before building any LLM objects.
    if not prompt or not prompt.strip():
        return {"detail": "Please provide a prompt."}

    # Fixed pass-through template; the user text is substituted as data.
    # Using the user text as the template would treat any `{`/`}` in it as
    # format placeholders and fail (or silently substitute) at run time.
    template = PromptTemplate(template="{Prompt}", input_variables=["Prompt"])

    # Built per request. Blocking is turned off only for the
    # dangerous-content harm category; no other categories are set here.
    llm = ChatGoogleGenerativeAI(
        model="gemini-pro",
        safety_settings={
            HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
        },
    )
    # NOTE(review): LLMChain/.run are deprecated in newer LangChain releases
    # in favour of `template | llm` and `.invoke`; kept to match the imports.
    llmchain = LLMChain(prompt=template, llm=llm)
    llm_response = llmchain.run({"Prompt": prompt})
    return Generate(text=llm_response)
# Allow browser clients on any origin to call this API.
# Credentials are disabled: the CORS spec does not permit a wildcard origin
# together with credentials. Starlette's CORSMiddleware works around that by
# echoing back the caller's Origin, which would let any site make
# credentialed requests. The endpoints visible here read no cookies or auth
# headers, so credentials are not needed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
def api_home():
    """Return the static welcome payload for the service's landing endpoint.

    NOTE(review): no route decorator is visible on this function; confirm it
    is registered on ``app`` (e.g. ``@app.get("/")``) elsewhere.
    """
    greeting = 'Welcome to FastAPI TextGen Tutorial!'
    return dict(detail=greeting)
def inference(input_prompt: str):
    """Delegate *input_prompt* to ``generate_text`` and pass its result through.

    NOTE(review): no route decorator is visible on this function; confirm it
    is registered on ``app`` elsewhere.
    """
    result = generate_text(prompt=input_prompt)
    return result