RashiAgarwal committed on
Commit
eaa8416
1 Parent(s): 81e4d50

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -0
app.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, logging
3
+ import gradio as gr
4
+
5
+
6
# Base model: Microsoft phi-2. `trust_remote_code=True` is required because
# the phi-2 repository ships custom modeling code.
model_name = "microsoft/phi-2"
model = AutoModelForCausalLM.from_pretrained(
model_name,
trust_remote_code=True
)
# Disable the KV cache on the config.
# NOTE(review): presumably carried over from the LoRA training script; for
# pure inference this is unnecessary (and slows generation) — confirm intent.
model.config.use_cache = False


tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# Use EOS as the pad token (common when the base tokenizer defines none).
tokenizer.pad_token = tokenizer.eos_token

# Load the adapter (trained LoRA weights) on top of the base model.
# NOTE(review): this is a Colab/Google-Drive path and will not exist on a
# deployed Space — confirm the checkpoint is bundled with the app or the
# path is updated before deploying.
ckpt = '/content/drive/MyDrive/S27/results/checkpoint-500'
model.load_adapter(ckpt)
# adapter_path = 'checkpoint-500'
# model.load_adapter(adapter_path)
22
+
23
# Generation pipeline cache: built lazily on first use and reused afterwards.
# The original rebuilt the HF pipeline on every request, re-wrapping the
# model and tokenizer per click for no benefit.
_pipe = None


def inference(prompt):
    """Generate a completion for *prompt* with the fine-tuned phi-2 model.

    The prompt is wrapped in the ``<s>[INST] ... [/INST]`` instruction
    template (presumably the format the adapter was trained on — confirm),
    and the raw generated text (template included) is returned.

    Args:
        prompt: user-supplied instruction text from the Gradio textbox.

    Returns:
        The pipeline's ``generated_text`` string for the first (only) result.
    """
    global _pipe
    if _pipe is None:
        # `model` and `tokenizer` come from module scope above; max_length
        # caps prompt + generated tokens at 200.
        _pipe = pipeline(
            task="text-generation",
            model=model,
            tokenizer=tokenizer,
            max_length=200,
        )
    result = _pipe(f"<s>[INST] {prompt} [/INST]")
    return result[0]['generated_text']
28
+
29
# Minimal Gradio front-end: one prompt box, one generate button, one output
# box, wired to the `inference` function defined above.
with gr.Blocks() as demo:
    prompt_box = gr.Textbox(label="Prompt")
    output_box = gr.Textbox(label="Output Box")
    generate_button = gr.Button("Generate")
    # Clicking the button runs inference(prompt_box) and writes the result
    # into output_box; api_name exposes it as a named API endpoint.
    generate_button.click(
        fn=inference,
        inputs=prompt_box,
        outputs=output_box,
        api_name="inference",
    )

demo.launch()