from llama_cpp import Llama
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import selfies as sf

app = FastAPI(title="Retrosynthesis Prediction API", version="1.0")

# Load the GGUF model via llama.cpp
model_path = "model.gguf"  # Replace with the path to your model file
test_model = Llama(model_path=model_path)
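# Optional: if the model needs a larger context window or GPU offload, the
# llama-cpp-python constructor accepts those as keyword arguments; the values
# below are illustrative only, not tuned for this model.
# test_model = Llama(model_path=model_path, n_ctx=2048, n_gpu_layers=-1)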

class RequestBody(BaseModel):
    prompt: str

@app.post("/generate/")
async def generate_text(request: RequestBody):
    try:
        # Convert the input SMILES string into its SELFIES representation
        prompt = sf.encoder(request.prompt)
        result = {"input": prompt}
        # Draw 10 candidate completions by sampling; llama.cpp generates one
        # sequence per call, so the model is invoked once per candidate
        for i in range(10):
            completion = test_model(
                prompt,
                max_tokens=512,
                echo=False,
            )
            output1 = completion["choices"][0]["text"]
            # Extract the predicted SELFIES string between the [/INST] markers
            first_inst_index = output1.find("[/INST]")
            second_inst_index = output1.find("[/IN", first_inst_index + len("[/INST]") + 1)
            if second_inst_index == -1:
                second_inst_index = len(output1)
            predicted_selfies = output1[first_inst_index + len("[/INST]"):second_inst_index].strip()
            result[f"predict_{i+1}"] = predicted_selfies
        return result
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/")
async def read_root():
    return {"message": "Welcome to the RetroLLM app!"}