infidea commited on
Commit
e2a8021
1 Parent(s): 92758a0

load model first time

Browse files
Files changed (1) hide show
  1. app.py +12 -10
app.py CHANGED
@@ -5,25 +5,27 @@ from pydantic import BaseModel, Field
5
 
6
  class RequestGenerate(BaseModel):
7
  prompt: str
8
- do_sample: bool = Field(default=True, example=True)
9
- top_k: int = Field(default=1, example=1),
10
- temperature: float = Field(default=0.9, example=0.9),
11
- max_new_tokens: int = Field(default=500, example=500),
12
- repetition_penalty: float = Field(default=1.5, example=1.5),
13
 
14
  app = FastAPI()
15
 
 
 
 
 
 
 
16
  @app.get("/")
17
  def greet_json():
18
  return {"Hello": "World!"}
19
 
20
  @app.post("/generate")
21
  def generate(req: RequestGenerate):
22
- model_name_or_id = "AI4Chem/ChemLLM-7B-Chat"
23
- # model_name_or_id = "AI4Chem/CHEMLLM-2b-1_5"
24
-
25
- model = AutoModelForCausalLM.from_pretrained(model_name_or_id,trust_remote_code=True)
26
- tokenizer = AutoTokenizer.from_pretrained(model_name_or_id,trust_remote_code=True)
27
 
28
  inputs = tokenizer(req.prompt, return_tensors="pt")
29
 
 
5
 
6
  class RequestGenerate(BaseModel):
7
  prompt: str
8
+ do_sample: bool = Field(default=bool(True), example=True)
9
+ top_k: int = Field(default=int(1), example=1),
10
+ temperature: float = Field(default=float(0.9), example=0.9),
11
+ max_new_tokens: int = Field(default=int(500), example=500),
12
+ repetition_penalty: float = Field(default=float(1.5), example=1.5),
13
 
14
  app = FastAPI()
15
 
16
+ # model_name_or_id = "AI4Chem/ChemLLM-7B-Chat"
17
+ model_name_or_id = "AI4Chem/CHEMLLM-2b-1_5"
18
+
19
+ model = AutoModelForCausalLM.from_pretrained(model_name_or_id,trust_remote_code=True)
20
+ tokenizer = AutoTokenizer.from_pretrained(model_name_or_id,trust_remote_code=True)
21
+
22
  @app.get("/")
23
  def greet_json():
24
  return {"Hello": "World!"}
25
 
26
  @app.post("/generate")
27
  def generate(req: RequestGenerate):
28
+
 
 
 
 
29
 
30
  inputs = tokenizer(req.prompt, return_tensors="pt")
31