from llama_cpp import Llama
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import selfies as sf

app = FastAPI(title="Retrosynthesis Prediction API", version="1.0")

# Load the GGUF model with llama-cpp-python
model_path = "model.gguf"  # Replace with the path to your model
test_model = Llama(model_path=model_path, n_ctx=2048)


class RequestBody(BaseModel):
    prompt: str


@app.post("/generate/")
async def generate_text(request: RequestBody):
    try:
        # Convert the input SMILES string to a SELFIES string
        selfies_prompt = sf.encoder(request.prompt)

        # Wrap the SELFIES in the [INST] ... [/INST] instruction format
        # that the fine-tuned model expects
        prompt = f"[INST] {selfies_prompt} [/INST]"

        result = {"input": selfies_prompt}

        # llama-cpp-python does not expose beam search or num_return_sequences,
        # so draw ten independent samples to produce ten candidate predictions
        for i in range(10):
            output = test_model(
                prompt,
                max_tokens=512,
                temperature=0.8,
                stop=["</s>", "[INST]"],
            )
            predicted_selfies = output["choices"][0]["text"].strip()
            result[f"predict_{i + 1}"] = predicted_selfies

        return result
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/")
async def read_root():
    return {"message": "Welcome to the RetroLLM app!"}
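

# The block below is a minimal sketch for running and exercising the service; the
# module name ("main"), host, port, and the example aspirin SMILES string are
# assumptions, not part of the original API definition.
if __name__ == "__main__":
    import uvicorn

    # Serve the app directly with uvicorn when this file is executed as a script
    uvicorn.run(app, host="0.0.0.0", port=8000)

# Example client call from another process (requires the `requests` package):
#
#   import requests
#   r = requests.post(
#       "http://127.0.0.1:8000/generate/",
#       json={"prompt": "CC(=O)Oc1ccccc1C(=O)O"},  # aspirin SMILES as a placeholder input
#   )
#   print(r.json())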