|
import os
import threading

from fastapi import FastAPI, Form, HTTPException
from llama_cpp import Llama
from pydantic import BaseModel
import selfies as sf
import torch
|
|
|
# FastAPI application object; the route handlers in this module are
# registered on it through the @app.get / @app.post decorators.
app = FastAPI(
    title="Retrosynthesis Prediction API",
    version="1.0",
)
|
|
|
|
|
# Path to the GGUF weights. Overridable through the MODEL_PATH environment
# variable; the default is resolved relative to the process working directory.
model_name = os.environ.get("MODEL_PATH", "model.gguf")

# llama-cpp-python's Llama defaults to a 512-token context window, which cannot
# hold a prompt plus the 512 new tokens requested by the /generate/ endpoint.
# Use a larger window by default (override with MODEL_N_CTX).
# The model is loaded once, at import time, and shared by all requests.
test_model = Llama(
    model_name,
    n_ctx=int(os.environ.get("MODEL_N_CTX", "2048")),
)
|
|
|
class RequestBody(BaseModel):
    """JSON payload accepted by ``POST /generate/``."""

    # Input molecule as a SMILES string; the endpoint converts it to SELFIES
    # with ``selfies.encoder`` before prompting the model.
    prompt: str
|
|
|
# Generation settings for /generate/.
MAX_NEW_TOKENS = 512
NUM_PREDICTIONS = 10

# Marker that precedes the prediction in each completion, and the prefix of
# the marker that terminates it.
_INST_END = "[/INST]"
_INST_STOP_PREFIX = "[/IN"

# A single Llama instance is shared by every request. The route below is a
# plain ``def`` (run in FastAPI's threadpool), so serialize calls into the model.
_model_lock = threading.Lock()


def _extract_prediction(completion: str) -> str:
    """Return the text between the first ``[/INST]`` and the next ``[/IN``.

    If the opening marker is absent, extraction starts at the beginning of
    *completion*. If no terminating marker follows, it runs to the end.
    The result is stripped of surrounding whitespace.
    """
    start = completion.find(_INST_END)
    start = 0 if start == -1 else start + len(_INST_END)
    # Skip one extra character before searching for the terminator, as the
    # original parsing did.
    end = completion.find(_INST_STOP_PREFIX, start + 1)
    if end == -1:
        end = len(completion)
    return completion[start:end].strip()


@app.post("/generate/")
def generate_text(request: RequestBody):
    """Sample SELFIES predictions from the model for the given SMILES.

    ``request.prompt`` is converted to SELFIES with ``selfies.encoder``, and
    that string is used verbatim as the completion prompt. The model is
    sampled ``NUM_PREDICTIONS`` times. llama-cpp-python's completion API has
    no beam search, so these are independent samples. Each completion is
    reduced to its prediction by ``_extract_prediction``.

    Declared as plain ``def`` so FastAPI runs the blocking inference in its
    threadpool rather than on the event loop.

    Returns:
        ``{"input": <SELFIES prompt>, "predict_1": ..., ..., "predict_N": ...}``

    Raises:
        HTTPException: 400 if the SMILES cannot be encoded to SELFIES,
            500 if generation fails.
    """
    try:
        prompt = sf.encoder(request.prompt)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Invalid SMILES: {e}") from e
    # Older selfies releases return None instead of raising on bad input.
    if not prompt:
        raise HTTPException(
            status_code=400,
            detail="Invalid SMILES: could not be encoded to SELFIES",
        )

    result = {"input": prompt}
    try:
        for i in range(1, NUM_PREDICTIONS + 1):
            with _model_lock:
                completion = test_model.create_completion(
                    prompt,
                    max_tokens=MAX_NEW_TOKENS,
                )
            # With echo disabled (the default), the returned text excludes the
            # prompt, so no prompt-length slicing is needed.
            text = completion["choices"][0]["text"]
            result[f"predict_{i}"] = _extract_prediction(text)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    return result
|
|
|
@app.get("/")
async def read_root():
    """Serve the root route with a fixed welcome message."""
    greeting = "Welcome to the RetroLLM app!"
    return dict(message=greeting)
|
|