LH-Tech-AI committed
Commit 56bd50e (verified) · Parent: b65ba6b

Create use.py

Files changed (1): use.py (+46, −0)
use.py ADDED
@@ -0,0 +1,46 @@
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ def run_pin_inference(prompt, model_id="LH-Tech-AI/Pin-Tiny", subfolder="Pin-25M"):
+     # 1. Device setup
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     print(f"Using device: {device}")
+
+     # 2. Load tokenizer and model (Pin uses the GPT-2 tokenizer)
+     tokenizer = AutoTokenizer.from_pretrained("gpt2")
+     tokenizer.pad_token = tokenizer.eos_token  # GPT-2 defines no pad token
+
+     model = AutoModelForCausalLM.from_pretrained(model_id, subfolder=subfolder).to(device)
+
+     # 3. Wrap the prompt in the [INST] ... [/INST] template
+     formatted_prompt = f"[INST] {prompt} [/INST]"
+     inputs = tokenizer(formatted_prompt, return_tensors="pt").to(device)
+
+     # 4. Generate
+     with torch.no_grad():
+         outputs = model.generate(
+             **inputs,
+             max_new_tokens=64,
+             temperature=0.7,
+             do_sample=True,
+             pad_token_id=tokenizer.eos_token_id,
+             eos_token_id=tokenizer.encode("[")[0]  # treat "[" as a stop token so generation halts before a new [INST] turn
+         )
+
+     # 5. Decode and extract the reply
+     full_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+     if "[/INST]" in full_text:
+         response = full_text.split("[/INST]")[-1].split("[INST]")[0].strip()
+     else:
+         response = full_text
+
+     return response
+
+ # --- Sample test ---
+ if __name__ == "__main__":
+     user_query = "What is the weather like today?"
+     answer = run_pin_inference(user_query, model_id="LH-Tech-AI/Pin", subfolder="Pin-25M")
+
+     print(f"\nUser: {user_query}")
+     print(f"Pin: {answer}")
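
Note that run_pin_inference reloads the tokenizer and the model weights on every call, so repeated queries pay the full load cost each time. Below is a minimal load-once sketch; the PinChat class name is hypothetical (not part of this repo), and it assumes the same GPT-2 tokenizer, repo layout, and [INST] ... [/INST] template that use.py itself uses.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

class PinChat:
    """Hypothetical helper: load Pin once, then reuse it for many prompts."""

    def __init__(self, model_id="LH-Tech-AI/Pin", subfolder="Pin-25M"):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = AutoTokenizer.from_pretrained("gpt2")
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id, subfolder=subfolder
        ).to(self.device)

    def ask(self, prompt, max_new_tokens=64):
        inputs = self.tokenizer(
            f"[INST] {prompt} [/INST]", return_tensors="pt"
        ).to(self.device)
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=0.7,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id,
                eos_token_id=self.tokenizer.encode("[")[0],
            )
        text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Same extraction logic as use.py: keep only the text after [/INST]
        return text.split("[/INST]")[-1].split("[INST]")[0].strip()

# Example usage: the weights load once in __init__, so follow-up
# questions skip the reload cost entirely.
# chat = PinChat()
# print(chat.ask("What is the weather like today?"))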