lehmanc25 committed
Commit 25485b9
1 Parent(s): b50d725

working version

Files changed (1)
  1. app.py +129 -24
app.py CHANGED
@@ -1,29 +1,134 @@
-import gradio as gr
-from transformers import AutoModelForCausalLM, AutoTokenizer
-
-# Assuming you've uploaded your model to Hugging Face with the username 'lehmanc25'
-tokenizer = AutoTokenizer.from_pretrained("lehmanc25/my-phi2-model")
-model = AutoModelForCausalLM.from_pretrained("lehmanc25/my-phi2-model")
-
-def chat_with_phi2(input_text):
-    encoded_input = tokenizer(input_text, return_tensors='pt')
-    output = model.generate(**encoded_input, max_length=50)
-    response = tokenizer.decode(output[0], skip_special_tokens=True)
-    return response
-
-# Using Gradio Blocks to build a more dynamic interface
-block = gr.Blocks()
-
-with block:
-    gr.Markdown("# W&L ChatGPT-like AI")
-    gr.Markdown("Ask any question about Washington and Lee University!")
-    input_text = gr.Textbox(label="Your question:", lines=2, placeholder="Ask something about W&L...")
-    output_text = gr.Text(label="AI Response:")
-    button = gr.Button("Ask")
-    button.click(
-        fn=chat_with_phi2,
-        inputs=input_text,
-        outputs=output_text
-    )
-
-block.launch()
+import logging
+import gradio as gr
+import wandb
+import transformers
+from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
+import peft
+from peft import PeftModel
+import torch
+
+wandb.login()
+wandb.init(project='journal-finetune', entity='benbankston2')
+
+
+# Initialize logging
+logging.basicConfig(level=logging.INFO)
+
+base_model_id = "microsoft/phi-2"
+model = AutoModelForCausalLM.from_pretrained(
+    base_model_id,
+    device_map="auto",
+    quantization_config=BitsAndBytesConfig(
+        load_in_4bit=True,
+        bnb_4bit_compute_dtype=torch.bfloat16,
+        bnb_4bit_quant_type="nf4",
+    ),
+    torch_dtype=torch.bfloat16,
+    # FA2 does not work yet
+    # attn_implementation="flash_attention_2",
+)
+
+
+#model = pipeline("text-generation", model=model_name)
+model = PeftModel.from_pretrained(model, "phi2-journal-finetune/checkpoint-175")
+# Note: device_map="auto" has already placed the quantized weights on GPU;
+# recent transformers versions reject .to() on 4-bit models.
+model.to("cuda")
+
+tokenizer = AutoTokenizer.from_pretrained(base_model_id, add_bos_token=True, trust_remote_code=True, use_fast=False)
+tokenizer.pad_token = tokenizer.eos_token
+
+def generate_text(prompt):
+    logging.info(f"Generating text for prompt: {prompt}")
+    model_input = tokenizer(prompt, return_tensors="pt").to("cuda")
+    #response = model(prompt, max_new_tokens=100, temperature=0.6, top_p=0.8, repetition_penalty=2.5, do_sample=True)
+    # Generation arguments go to model.generate(), not tokenizer.decode(), and
+    # eos_token_id expects a token id rather than the pad-token string.
+    # temperature and early_stopping are omitted: they have no effect under
+    # greedy, single-beam decoding.
+    output_ids = model.generate(
+        **model_input,
+        max_new_tokens=256,
+        repetition_penalty=1.11,
+        eos_token_id=tokenizer.eos_token_id,
+    )
+    response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
+    logging.info(f"Generated text: {response}")
+    return response
+
+def message_and_history(input_text, history, feedback=None):
+    """Manage message history and generate responses."""
+    if history is None:
+        history = []
+    output = generate_text(input_text)
+    # gr.Chatbot renders each tuple as a (user_message, bot_message) pair, so
+    # the question and the reply belong in one tuple, not in separate rows
+    # labeled "User" and "Fizz Bot".
+    history.append((input_text, output))
+    return history, history
+
+def setup_interface():
+    with gr.Blocks(css='''
+        body { font-family: 'Arial', sans-serif; background: #f1f1f1; }
+        .container { max-width: 800px; margin: auto; padding: 20px; background-size: cover; background-repeat: no-repeat; border-radius: 8px; box-shadow: 0 4px 8px rgba(0,0,0,0.1); }
+        h1 { color: #003F87; text-align: center; margin-bottom: 20px; }
+        .gr-textbox { box-shadow: inset 0 2px 3px rgba(0,0,0,0.1); border-radius: 4px; border: 1px solid #7FB2E5; padding: 10px; width: auto; background-color: #fff; }
+        .gradio-chatbox { background-color: #f0f0f0; }
+        .parameter-accordion .gr-accordion-title { font-weight: bold; font-size: 18px; } /* leading dot added so the class selector applies */
+        .gradio-chatbox-message-user { background-color: #4A90E2; color: white; }
+        .gradio-chatbox-message-bot { background-color: #FFFFFF; color: black; }
+        .gradio-chatbox-message { border-radius: 10px; padding: 10px; margin-bottom: 8px; }
+    ''', theme="Soft") as block:
+
+        gr.Markdown("""
+        <div style="background-image: url('https://my.wlu.edu/Images/communications/publications/graphic-identity/300-dpi-wordmark-blue.png');
+                    background-size: contain;
+                    background-repeat: no-repeat;
+                    background-position: center;
+                    text-align: center;
+                    height: 100px;
+                    line-height: 100px;
+                    font-size: 36px;
+                    color: white;
+                    font-family: Arial, sans-serif;">
+        </div>
+        """)
+        gr.Markdown("<h1>Fizz Chatbot</h1>")
+        gr.Markdown("<h6><i>Disclaimer: some information may be inaccurate</i></h6>")
+        with gr.Accordion("Parameters", open=False, visible=True, elem_classes=["parameter-accordion"]) as parameter_row:
+            temperature = gr.Slider(
+                minimum=0.0,
+                maximum=1.0,
+                value=0.7,
+                step=0.1,
+                interactive=True,
+                label="Temperature"
+            )
+            top_p = gr.Slider(
+                minimum=0.0,
+                maximum=1.0,
+                value=1.0,
+                step=0.1,
+                interactive=True,
+                label="Top P"
+            )
+            max_new_tokens = gr.Slider(
+                minimum=16,
+                maximum=1028,
+                value=128,
+                step=32,
+                interactive=True,
+                label="Max tokens"
+            )
+        chatbot = gr.Chatbot(label="W&L AI")
+        message = gr.Textbox(label="", placeholder="Ask me anything about W&L here...", elem_id="input_box")
+        submit = gr.Button("Submit Query", elem_classes="specific_button")
+        # A single shared State is required: passing one fresh gr.State() to
+        # inputs and another to outputs means the history the callback reads
+        # is never the one it writes back.
+        state = gr.State()
+        submit.click(
+            fn=message_and_history,
+            inputs=[message, state],
+            outputs=[chatbot, state]
+        )
+        # gr.Row() does not accept children as positional arguments; the
+        # components above are already laid out in order.
+
+    return block
+
+app = setup_interface()
+app.launch(debug=True)
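
Note: the Temperature, Top P, and Max tokens sliders in this commit are created but never passed to generate_text(), so they have no effect on the output. A minimal sketch of one way to wire them through, assuming the signatures above (do_sample=True is an added assumption; without sampling, temperature and top_p are ignored by generate()):

def generate_text(prompt, temperature=0.7, top_p=1.0, max_new_tokens=128):
    model_input = tokenizer(prompt, return_tensors="pt").to("cuda")
    output_ids = model.generate(
        **model_input,
        max_new_tokens=int(max_new_tokens),
        do_sample=True,  # assumption: enable sampling so temperature/top_p apply
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=1.11,
        eos_token_id=tokenizer.eos_token_id,
    )
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)

def message_and_history(input_text, history, temperature, top_p, max_new_tokens):
    history = history or []
    output = generate_text(input_text, temperature, top_p, max_new_tokens)
    history.append((input_text, output))
    return history, history

# In setup_interface(), Gradio passes each listed component's current value
# as a positional argument to the callback:
submit.click(
    fn=message_and_history,
    inputs=[message, state, temperature, top_p, max_new_tokens],
    outputs=[chatbot, state]
)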
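Note: wandb.login() with no arguments can prompt for an API key when none is cached, which hangs in a headless deployment. A sketch of a guarded login, assuming the key is supplied through the WANDB_API_KEY environment variable or a Space secret (the variable name is the wandb convention, not something in this commit):

import os
import wandb

wandb_key = os.environ.get("WANDB_API_KEY")  # assumed secret/env var
if wandb_key:
    wandb.login(key=wandb_key)
    wandb.init(project='journal-finetune', entity='benbankston2')
else:
    logging.warning("WANDB_API_KEY not set; skipping wandb logging")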