from llama_cpp import Llama
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import selfies as sf

app = FastAPI(title="Retrosynthesis Prediction API", version="1.0")

# Load the GGUF model with the llama.cpp bindings
model_path = "model.gguf"  # Replace with the path to your GGUF model file
test_model = Llama(model_path=model_path)

class RequestBody(BaseModel):
    prompt: str
    temperature: float = 1.0
    top_k: int = 50
    top_p: float = 1.0

@app.post("/generate/")
async def generate_text(request: RequestBody):
    try:
        prompt = sf.encoder(request.prompt)
        outputs = test_model(
            prompt,
            max_new_tokens=512, 
            num_beams=10, 
            early_stopping=True, 
            num_return_sequences=10, 
            do_sample=True,
            top_k = request.top_k,
            top_p = request.top_p,
            temperature = request.temperature
        )
        
        result = {'input': prompt}
        for i in range(10):
            output1 = outputs[i][len(prompt):]
            first_inst_index = output1.find("[/INST]")
            second_inst_index = output1.find("[/IN", first_inst_index + len("[/INST]") + 1)
            predicted_selfies = output1[first_inst_index + len("[/INST]"):second_inst_index].strip()
            result[f'predict_{i+1}'] = sf.decoder(predicted_selfies)
        
        return result
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/")
async def read_root():
    return {"message": "Welcome to the RetroLLM app!"}