bala1802 commited on
Commit
70846c9
1 Parent(s): 529bc21

Upload inference.py

Browse files
Files changed (1) hide show
  1. inference.py +12 -0
inference.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import pipeline
2
+
3
+ import config
4
+
5
+ def predict(prompt, model, tokenizer, max_length):
6
+ pipe = pipeline(task = config.TASK,
7
+ model = model,
8
+ tokenizer = tokenizer,
9
+ max_length = max_length)
10
+ result = pipe(f"<s>[INST] {prompt} [/INST]")
11
+ return result[0]['generated_text']
12
+