Chengxb888 committed
Commit 576bbe0 · verified · 1 Parent(s): 47c0611

Update app.py

Files changed (1)
  1. app.py +9 -30
app.py CHANGED
@@ -11,36 +11,15 @@ def greet_json():
 @app.get("/hello/{msg}")
 def say_hello(msg: str):
     print("model")
-    torch.random.manual_seed(0)
-    model = AutoModelForCausalLM.from_pretrained(
-        "microsoft/Phi-3-mini-4k-instruct",
-        device_map="auto",
-        torch_dtype="auto",
-        trust_remote_code=True,
-    )
+    tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
+    model = AutoModelForCausalLM.from_pretrained(
+        "google/gemma-2b-it",
+        device_map="auto",
+        torch_dtype=torch.bfloat16
+    )
     print("token & msg")
-    tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-128k-instruct")
-
-    messages = [
-        {"role": "system", "content": "You are a helpful AI assistant."},
-        {"role": "user", "content": "Can you provide ways to eat combinations of bananas and dragonfruits?"},
-        {"role": "assistant", "content": "Sure! Here are some ways to eat bananas and dragonfruits together: 1. Banana and dragonfruit smoothie: Blend bananas and dragonfruits together with some milk and honey. 2. Banana and dragonfruit salad: Mix sliced bananas and dragonfruits together with some lemon juice and honey."},
-        {"role": "user", "content": msg},
-    ]
-    print("pipe")
-    pipe = pipeline(
-        "text-generation",
-        model=model,
-        tokenizer=tokenizer,
-    )
+    input_ids = tokenizer(msg, return_tensors="pt").to("cpu")
     print("output")
-    # generation_args = {
-    #     "max_new_tokens": 500,
-    #     "return_full_text": False,
-    #     "temperature": 0.0,
-    #     "do_sample": False,
-    # }
-
-    output = pipe(messages) #, **generation_args)
+    outputs = model.generate(**input_ids, max_length=500)
     print("complete")
-    return {"message": output[0]['generated_text']}
+    return {"message": tokenizer.decode(outputs[0])}