|
from llama_cpp import Llama |
|
from fastapi import FastAPI, Form, HTTPException |
|
from pydantic import BaseModel |
|
import torch |
|
import selfies as sf |
|
|
|
# FastAPI application instance; all route decorators below attach to it.
app = FastAPI(title="Retrosynthesis Prediction API", version="1.0")


# Path to the GGUF-quantized model weights.
# NOTE(review): relative path — assumes the server process is started from
# the directory containing model.gguf; confirm the deployment working dir.
model_name = "model.gguf"

# Load the llama.cpp model once at import time so every request reuses it.
test_model = Llama(model_name)
|
|
|
class RequestBody(BaseModel):
    """Request schema for POST /generate/."""

    # Molecule string for the target product; the endpoint runs it through
    # sf.encoder before generation (presumably SMILES in — TODO confirm).
    prompt: str

    # Sampling temperature forwarded to the model.
    temperature: float = 1.0

    # Top-k sampling cutoff forwarded to the model.
    top_k: int = 50

    # Nucleus (top-p) sampling cutoff forwarded to the model.
    top_p: float = 1.0
|
|
|
@app.post("/generate/")
async def generate_text(request: RequestBody):
    """Generate 10 candidate retrosynthesis predictions for a molecule.

    The incoming prompt is converted to SELFIES, the model samples ten
    completions, and the SELFIES span between the "[/INST]" markers of
    each completion is decoded back before being returned.

    Returns:
        dict with the encoded ``input`` plus ``predict_1`` .. ``predict_10``.

    Raises:
        HTTPException: 500 carrying the underlying error message on any failure.
    """
    try:
        prompt = sf.encoder(request.prompt)
        result = {'input': prompt}
        # llama_cpp.Llama.__call__ accepts llama.cpp-style kwargs only; the
        # original transformers-style kwargs (max_new_tokens, num_beams,
        # early_stopping, num_return_sequences, do_sample) raised TypeError.
        # There is no num_return_sequences either, so sample one completion
        # per iteration to keep the 10-candidate behavior.
        for i in range(10):
            completion = test_model(
                prompt,
                max_tokens=512,  # llama.cpp name for max_new_tokens
                top_k=request.top_k,
                top_p=request.top_p,
                temperature=request.temperature,
            )
            # llama.cpp returns a completion dict; the generated text
            # (which does NOT include the prompt) is choices[0]["text"].
            text = completion["choices"][0]["text"]
            predicted_selfies = _extract_selfies(text)
            result[f'predict_{i+1}'] = sf.decoder(predicted_selfies)

        return result
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


def _extract_selfies(text: str) -> str:
    """Return the span between the first "[/INST]" and the following "[/IN".

    Handles absent markers explicitly: the original code let str.find()
    return -1, which produced garbage slices (e.g. dropping the final
    character, or starting at index 6).
    """
    marker = "[/INST]"
    start = text.find(marker)
    if start == -1:
        # No marker at all: treat the whole completion as the prediction.
        return text.strip()
    start += len(marker)
    end = text.find("[/IN", start + 1)
    if end == -1:
        # No closing marker: take the remainder of the completion.
        end = len(text)
    return text[start:end].strip()
|
|
|
@app.get("/")
async def read_root():
    """Landing endpoint: confirms the service is up."""
    greeting = {"message": "Welcome to the RetroLLM app!"}
    return greeting
|
|