phuongnv committed on
Commit
e4fefdd
1 Parent(s): 36d332d

Update main.py

Files changed (1)
main.py +40 -17
main.py CHANGED
@@ -1,22 +1,45 @@
-from ctransformers import AutoModelForCausalLM
-from fastapi import FastAPI, Form
+from ctransformers import AutoModelForCausalLM, AutoTokenizer
+from fastapi import FastAPI, Form, HTTPException
 from pydantic import BaseModel
+import torch
+import selfies as sf
 
-#Model loading
-llm = AutoModelForCausalLM.from_pretrained("model.gguf", max_new_tokens = 512)
-
+app = FastAPI()
+
+# Load the model and tokenizer
+model_name = "model.gguf" # Replace with your model name
+test_model = AutoModelForCausalLM.from_pretrained(model_name)
+test_tokenizer = AutoTokenizer.from_pretrained(model_name)
 
-#Pydantic object
-class validation(BaseModel):
+class RequestBody(BaseModel):
     prompt: str
-
-#Fast API
-app = FastAPI()
 
-#Zephyr completion
-@app.post("/llm_on_cpu")
-async def stream(item: validation):
-    E_INST = "</s>"
-    user, assistant = "<|user|>", "<|assistant|>"
-    prompt = f"{E_INST}\n{user}\n{item.prompt.strip()}{E_INST}\n{assistant}\n"
-    return llm(prompt)
+@app.post("/generate/")
+async def generate_text(request: RequestBody):
+    try:
+        prompt = request.prompt
+        input_ids = test_tokenizer(prompt, return_tensors='pt', truncation=False).input_ids
+        outputs = test_model.generate(
+            input_ids=input_ids,
+            max_new_tokens=512,
+            num_beams=10,
+            early_stopping=True,
+            num_return_sequences=10,
+            do_sample=True
+        )
+
+        result = {'input': prompt}
+        for i in range(10):
+            output1 = test_tokenizer.batch_decode(outputs.detach().numpy(), skip_special_tokens=True)[i][len(prompt):]
+            first_inst_index = output1.find("[/INST]")
+            second_inst_index = output1.find("[/IN", first_inst_index + len("[/INST]") + 1)
+            predicted_selfies = output1[first_inst_index + len("[/INST]"):second_inst_index].strip()
+            result[f'predict_{i+1}'] = predicted_selfies
+
+        return result
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
+@app.get("/")
+async def read_root():
+    return {"message": "Welcome to the LLM FastAPI application!"}