# CHEMISTral7Bv0.3 / chemistral_api.py
# Last commit by Clemspace (32fe622): "added inference + api wrapper"
import logging
import os
import subprocess
import tempfile
from typing import Optional

from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, JSONResponse
from pydantic import BaseModel

from inference_transform import (
    extract_and_convert_to_sdf,
    is_valid_smiles,
    process_pdb,
    process_sdf,
    process_smiles,
)
# Configure root logging at INFO so the handlers' logger.info / logger.error /
# logger.exception calls below are actually emitted.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()

# CORS: every origin, method and header is accepted.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive; restrict allow_origins before exposing this service publicly.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)

# Example SDF conformer file on the server's filesystem; /download_sdf serves it.
sdf_file_path = "/root/CHEMISTral7Bv0.3/example/Conformer3D_COMPOUND_CID_240.sdf"
class InferenceRequest(BaseModel):
    """Request body schema for an inference call: a prompt plus sampling options.

    NOTE(review): none of the endpoints visible in this module use this model;
    they accept the same fields as form parameters instead. Confirm whether it
    is still needed.
    """

    # Text sent to the model.
    prompt: str
    # Same default as the endpoints' ``max_tokens`` form field.
    max_tokens: int = 256
    # Same default as the endpoints' ``temperature`` form field.
    temperature: float = 1.0
@app.post("/predict_base")
async def predict_base(
    prompt: str = Form(...),
    max_tokens: int = Form(256),
    temperature: float = Form(1.0),
    file: Optional[UploadFile] = File(None),
):
    """Run the base 7B-v0.3 chat script (no LoRA adapter) on ``prompt``.

    Before inference the prompt may be enriched:

    * if a file is uploaded, it is saved into a private per-request temporary
      directory; for ``.pdb`` / ``.sdf`` files (case-insensitive) the output
      of ``process_pdb`` / ``process_sdf`` is appended to the prompt;
    * otherwise ``extract_and_convert_to_sdf(prompt)`` is tried and a truthy
      result is appended; a ``ValueError`` from it is logged and ignored.

    Args:
        prompt: User prompt (form field).
        max_tokens: Forwarded to the chat script as ``--max_tokens``.
        temperature: Forwarded to the chat script as ``--temperature``.
        file: Optional uploaded structure file.

    Returns:
        ``{"response": <stripped stdout of the script>,
        "sdf_file_path": <module-level sdf_file_path>}``.

    Raises:
        HTTPException: 500 with the script's stderr as detail when it exits
            non-zero, or 500 with the error text for any other failure.
    """
    try:
        # A fresh directory per request: concurrent uploads with the same name
        # cannot clobber each other, and everything is removed afterwards.
        # The directory lives until the script has finished, in case anything
        # appended to the prompt refers to the saved file.
        with tempfile.TemporaryDirectory() as upload_dir:
            if file:
                # basename() drops client-supplied directory parts such as
                # "../../etc/x" so the write cannot escape upload_dir.
                filename = os.path.basename(file.filename or "")
                if filename in ("", ".", ".."):
                    filename = "upload"
                file_path = os.path.join(upload_dir, filename)
                with open(file_path, "wb") as f:
                    f.write(await file.read())
                lowered = filename.lower()
                if lowered.endswith(".pdb"):
                    prompt += f" {process_pdb(file_path)}"
                elif lowered.endswith(".sdf"):
                    prompt += f" {process_sdf(file_path)}"
            else:
                try:
                    sdf_file = extract_and_convert_to_sdf(prompt)
                    if sdf_file:
                        prompt += f" {sdf_file}"
                except ValueError as e:
                    logger.info(str(e))

            # Passed as an argv list (no shell), so the prompt cannot inject
            # shell syntax.
            # NOTE(review): a prompt starting with "-" may still be parsed as an
            # option by the chat script's CLI parser; confirm how it handles that.
            command = [
                "python",
                "/root/CHEMISTral7Bv0.3/mistral_chat_script.py",
                "/root/mistral_models/7B-v0.3/",
                prompt,
                f"--max_tokens={max_tokens}",
                f"--temperature={temperature}",
                "--instruct",
            ]
            logger.info("Running command: %s", " ".join(command))
            # subprocess.run blocks for the whole inference; run it in a worker
            # thread so the event loop can keep serving other requests.
            result = await run_in_threadpool(
                subprocess.run, command, capture_output=True, text=True
            )

        if result.returncode != 0:
            logger.error("Command failed with return code %s", result.returncode)
            logger.error("stderr: %s", result.stderr)
            raise HTTPException(status_code=500, detail=result.stderr)

        return {
            "response": result.stdout.strip(),
            "sdf_file_path": sdf_file_path,
        }
    except HTTPException:
        # Already carries the intended status and detail; re-wrapping it would
        # turn the detail into "500: <stderr>".
        raise
    except Exception as e:
        logger.exception("Exception occurred during inference.")
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/predict")
async def predict_alternative(
    prompt: str = Form(...),
    max_tokens: int = Form(256),
    temperature: float = Form(1.0),
    file: Optional[UploadFile] = File(None),
):
    """Run the LoRA-adapted chat script on ``prompt`` and return the example SDF.

    Prompt enrichment is the same as in ``/predict_base``:

    * an uploaded ``.pdb`` / ``.sdf`` file (case-insensitive) is saved into a
      private per-request temporary directory, and the output of
      ``process_pdb`` / ``process_sdf`` is appended;
    * without a file, a truthy ``extract_and_convert_to_sdf(prompt)`` result is
      appended; a ``ValueError`` from it is logged and ignored.

    The chat script is run with the checkpoint-300 LoRA adapter. Its stdout is
    logged, and the response body is the fixed file at the module-level
    ``sdf_file_path``, sent as a download.

    NOTE(review): the model output is not returned to the client, and the file
    sent back does not depend on the prompt. Confirm this demo behavior is
    intended.

    Raises:
        HTTPException: 500 with the script's stderr when it exits non-zero,
            500 when the SDF file is missing, or 500 with the error text for
            any other failure.
    """
    try:
        # Per-request directory: avoids name collisions between concurrent
        # uploads, and the file is removed once the request is done.
        with tempfile.TemporaryDirectory() as upload_dir:
            if file:
                # Strip client-supplied directory components (path traversal).
                filename = os.path.basename(file.filename or "")
                if filename in ("", ".", ".."):
                    filename = "upload"
                file_path = os.path.join(upload_dir, filename)
                with open(file_path, "wb") as f:
                    f.write(await file.read())
                lowered = filename.lower()
                if lowered.endswith(".pdb"):
                    prompt += f" {process_pdb(file_path)}"
                elif lowered.endswith(".sdf"):
                    prompt += f" {process_sdf(file_path)}"
            else:
                try:
                    sdf_file = extract_and_convert_to_sdf(prompt)
                    if sdf_file:
                        prompt += f" {sdf_file}"
                except ValueError as e:
                    logger.info(str(e))

            # NOTE(review): a prompt starting with "-" may be parsed as an
            # option by the chat script's CLI parser; confirm how it handles that.
            command = [
                "python",
                "/root/CHEMISTral7Bv0.3/mistral_chat_script.py",
                "/root/mistral_models/7B-v0.3/",
                prompt,
                f"--max_tokens={max_tokens}",
                f"--temperature={temperature}",
                "--instruct",
                "--lora_path=/root/CHEMISTral7Bv0.3/runs/checkpoints/checkpoint_000300/consolidated/lora.safetensors",
            ]
            logger.info("Running command: %s", " ".join(command))
            # Keep the blocking subprocess off the event loop.
            result = await run_in_threadpool(
                subprocess.run, command, capture_output=True, text=True
            )

        if result.returncode != 0:
            logger.error("Command failed with return code %s", result.returncode)
            logger.error("stderr: %s", result.stderr)
            raise HTTPException(status_code=500, detail=result.stderr)

        # Logged so the model output is not silently discarded.
        logger.info("Model response: %s", result.stdout.strip())

        # FileResponse only stats the file when the response is sent, after
        # this handler has returned, so check for it here to get a clean error.
        if not os.path.isfile(sdf_file_path):
            raise HTTPException(
                status_code=500, detail=f"SDF file not found: {sdf_file_path}"
            )
        # Return the file as a direct download.
        return FileResponse(
            sdf_file_path,
            media_type="chemical/x-mdl-sdfile",
            filename="Conformer3D_COMPOUND_CID_240.sdf",
        )
    except HTTPException:
        # Preserve the status and detail chosen above instead of re-wrapping.
        raise
    except Exception as e:
        logger.exception("Exception occurred during inference.")
        raise HTTPException(status_code=500, detail=str(e)) from e
# @app.post("/predict")
# async def predict_alternative(
# prompt: str = Form(...),
# max_tokens: int = Form(256),
# temperature: float = Form(1.0),
# file: Optional[UploadFile] = File(None)
# ):
# try:
# global sdf_file_path
# if file:
# file_path = f"/tmp/{file.filename}"
# with open(file_path, "wb") as f:
# f.write(file.file.read())
# if file.filename.endswith(".pdb"):
# prompt += f" {process_pdb(file_path)}"
# elif file.filename.endswith(".sdf"):
# prompt += f" {process_sdf(file_path)}"
# else:
# try:
# sdf_file = extract_and_convert_to_sdf(prompt)
# if sdf_file:
# prompt += f" {sdf_file}"
# except ValueError as e:
# logger.info(str(e))
# command = [
# "python",
# "/root/CHEMISTral7Bv0.3/mistral_chat_script.py",
# "/root/mistral_models/7B-v0.3/",
# prompt,
# f"--max_tokens={max_tokens}",
# f"--temperature={temperature}",
# "--instruct",
# "--lora_path=/root/CHEMISTral7Bv0.3/runs/checkpoints/checkpoint_000300/consolidated/lora.safetensors"
# ]
# logger.info(f"Running command: {' '.join(command)}")
# result = subprocess.run(command, capture_output=True, text=True)
# if result.returncode != 0:
# logger.error(f"Command failed with return code {result.returncode}")
# logger.error(f"stderr: {result.stderr}")
# raise HTTPException(status_code=500, detail=result.stderr)
# response = result.stdout.strip()
# sdf_file_path = "/root/CHEMISTral7Bv0.3/example/Conformer3D_COMPOUND_CID_240.sdf"
# return {
# "response": response,
# "sdf_file_path": sdf_file_path
# }
# except Exception as e:
# logger.exception("Exception occurred during inference.")
# raise HTTPException(status_code=500, detail=str(e))
@app.get("/download_sdf")
async def download_sdf():
    """Send the example SDF file at the module-level ``sdf_file_path`` as a download.

    Raises:
        HTTPException: 404 if the file does not exist on disk.
    """
    # Constructing a FileResponse does not touch the filesystem. A missing file
    # would only fail later, while the response is being sent, where no
    # try/except in this handler can catch it. Check explicitly instead.
    if not os.path.isfile(sdf_file_path):
        logger.error("SDF file not found: %s", sdf_file_path)
        raise HTTPException(status_code=404, detail="SDF file not found.")
    return FileResponse(
        path=sdf_file_path,
        filename="Conformer3D_COMPOUND_CID_240.sdf",
        # Same content type that /predict uses for this file.
        media_type="chemical/x-mdl-sdfile",
    )
def _serve(host="0.0.0.0", port=8000):
    """Serve ``app`` with uvicorn; used when this file is run as a script."""
    # Imported here so that merely importing this module never requires uvicorn.
    import uvicorn

    uvicorn.run(app, host=host, port=port)


if __name__ == "__main__":
    _serve()