from llama_cpp import Llama
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import selfies as sf

app = FastAPI(title="Retrosynthesis Prediction API", version="1.0")

# Load the GGUF model with llama.cpp (tokenization is handled by the GGUF file itself)
model_name = "model.gguf"  # Replace with your model path
test_model = Llama(model_path=model_name)


class RequestBody(BaseModel):
    prompt: str
    temperature: float = 1.0
    top_k: int = 50
    top_p: float = 1.0


@app.post("/generate/")
async def generate_text(request: RequestBody):
    try:
        # Convert the input SMILES string to SELFIES before prompting the model
        prompt = sf.encoder(request.prompt)
        result = {"input": prompt}
        # llama_cpp returns a single completion per call, so sample 10 candidates in a loop
        for i in range(10):
            output = test_model(
                prompt,
                max_tokens=512,
                top_k=request.top_k,
                top_p=request.top_p,
                temperature=request.temperature,
            )
            generated = output["choices"][0]["text"]
            # Extract the SELFIES string emitted between the [/INST] markers
            first_inst_index = generated.find("[/INST]")
            second_inst_index = generated.find("[/IN", first_inst_index + len("[/INST]") + 1)
            predicted_selfies = generated[first_inst_index + len("[/INST]"):second_inst_index].strip()
            # Decode each SELFIES candidate back to SMILES
            result[f"predict_{i+1}"] = sf.decoder(predicted_selfies)
        return result
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/")
async def read_root():
return {"message": "Welcome to the RetroLLM app!"}
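

# A minimal sketch of how the API could be exercised locally, assuming this file is
# named main.py, the server is started with uvicorn on its default port 8000, and the
# SMILES string below (aspirin) is only an illustrative example input:
#
#   uvicorn main:app --host 0.0.0.0 --port 8000
#
#   curl -X POST http://localhost:8000/generate/ \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "CC(=O)OC1=CC=CC=C1C(=O)O", "temperature": 0.8, "top_k": 50, "top_p": 0.95}'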