# WLUChatbot / app.py
# Last change: "Update app.py" by lehmanc25 (commit 8f4eb15, verified)
import logging
import os

import gradio as gr
import peft
import torch
import transformers
import wandb
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
# Configure logging first so that messages emitted during startup
# (including the W&B warning below) are formatted at INFO level.
logging.basicConfig(level=logging.INFO)

# Weights & Biases credentials come from the environment (e.g. a deployment
# secret) rather than being committed to source. A key was previously
# hard-coded on this line; treat it as leaked and rotate it.
_WANDB_API_KEY = os.environ.get("WANDB_API_KEY")
if _WANDB_API_KEY:
    wandb.login(key=_WANDB_API_KEY)
    wandb.init(project='journal-finetune', entity='benbankston2')
else:
    # W&B is optional for serving the chatbot; don't block startup on it.
    logging.warning("WANDB_API_KEY is not set; skipping Weights & Biases login/init.")
base_model_id = "microsoft/phi-2"

# Load the base model with 4-bit NF4 bitsandbytes quantization, computing in
# bfloat16. device_map="auto" lets accelerate place the weights, so no explicit
# .to("cuda") is needed (or supported) on the quantized model afterwards.
model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    device_map="auto",
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_quant_type="nf4",
    ),
    torch_dtype=torch.bfloat16,
    # FlashAttention-2 (attn_implementation="flash_attention_2") does not work yet.
)

# Wrap the base model with the PEFT adapter weights saved at this checkpoint path.
model = PeftModel.from_pretrained(model, "phi2-journal-finetune/checkpoint-175")

# Slow (use_fast=False) tokenizer that prepends the BOS token to each encoding.
tokenizer = AutoTokenizer.from_pretrained(
    base_model_id,
    add_bos_token=True,
    trust_remote_code=True,
    use_fast=False,
)
# Reuse EOS as the padding token (presumably the tokenizer ships without one —
# confirm) so padding-aware calls have a valid pad id.
tokenizer.pad_token = tokenizer.eos_token
def generate_text(prompt, max_new_tokens=256, temperature=None, top_p=None):
    """Generate a reply to ``prompt`` with the fine-tuned model.

    Args:
        prompt: Raw user text; it is tokenized as-is, with no chat template.
        max_new_tokens: Maximum number of tokens to generate.
        temperature: When given and greater than 0, enables sampling at this
            temperature. By default decoding is greedy, which is what the
            previous implementation effectively did.
        top_p: Nucleus-sampling threshold; only used when sampling is enabled.

    Returns:
        The generated continuation as text, with the echoed prompt and
        special tokens removed.
    """
    logging.info("Generating text for prompt: %s", prompt)
    # Put the inputs on whatever device device_map="auto" chose for the model.
    model_input = tokenizer(prompt, return_tensors="pt").to(model.device)

    # All generation options belong to generate(); tokenizer.decode() does not
    # use them, which is why they previously had no effect.
    generation_kwargs = {
        "max_new_tokens": max_new_tokens,
        "repetition_penalty": 1.11,
        "eos_token_id": tokenizer.eos_token_id,
        "pad_token_id": tokenizer.pad_token_id,
    }
    if temperature is not None and temperature > 0:
        generation_kwargs["do_sample"] = True
        generation_kwargs["temperature"] = temperature
        if top_p is not None:
            generation_kwargs["top_p"] = top_p

    output_ids = model.generate(**model_input, **generation_kwargs)[0]
    # For a causal LM, generate() returns prompt tokens followed by new tokens;
    # decode only the new ones so the prompt is not echoed back to the user.
    prompt_length = model_input["input_ids"].shape[-1]
    response = tokenizer.decode(output_ids[prompt_length:], skip_special_tokens=True)

    logging.info("Generated text: %s", response)
    return response
def message_and_history(input_text, history, feedback=None):
    """Generate a reply to ``input_text`` and append the exchange to ``history``.

    Args:
        input_text: The user's latest message.
        history: Prior exchanges as a list of ``(user_message, bot_message)``
            pairs, the pair format ``gr.Chatbot`` renders. ``None`` starts a
            new conversation.
        feedback: Unused; accepted so existing callers that pass it still work.

    Returns:
        ``(history, history)``: the same updated list twice, once for the
        chatbot display and once for the session state.
    """
    if history is None:
        history = []
    # Nothing to send for blank input; leave the conversation unchanged.
    if not input_text or not input_text.strip():
        return history, history
    # Only the latest message is sent to the model; earlier turns are not
    # included as context.
    output = generate_text(input_text)
    # gr.Chatbot draws each pair as (user bubble, bot bubble).
    history.append((input_text, output))
    return history, history
# Custom styles for the chatbot page.
_CUSTOM_CSS = '''
body { font-family: 'Arial', sans-serif; background: #f1f1f1; }
.container { max-width: 800px; margin: auto; padding: 20px; background-size: cover; background-repeat: no-repeat; border-radius: 8px; box-shadow: 0 4px 8px rgba(0,0,0,0.1); }
h1 { color: #003F87; text-align: center; margin-bottom: 20px; }
.gr-textbox { box-shadow: inset 0 2px 3px rgba(0,0,0,0.1); border-radius: 4px; border: 1px solid #7FB2E5; padding: 10px; width: auto; background-color: #fff; }
.gradio-chatbox { background-color: #f0f0f0; }
.parameter-accordion .gr-accordion-title { font-weight: bold; font-size: 18px; } /* Custom CSS for accordion title */
.gradio-chatbox-message-user { background-color: #4A90E2; color: white; }
.gradio-chatbox-message-bot { background-color: #FFFFFF; color: black; }
.gradio-chatbox-message { border-radius: 10px; padding: 10px; margin-bottom: 8px; }
'''

# Header banner showing the W&L wordmark image. Lines are kept unindented:
# in Markdown, a line indented four or more spaces would become a code block.
_HEADER_HTML = """
<div style="background-image: url('https://my.wlu.edu/Images/communications/publications/graphic-identity/300-dpi-wordmark-blue.png');
background-size: contain;
background-repeat: no-repeat;
background-position: center;
text-align: center;
height: 100px;
line-height: 100px;
font-size: 36px;
color: white;
font-family: Arial, sans-serif;">
</div>
"""


def setup_interface():
    """Build the Fizz chatbot Gradio UI.

    The layout, top to bottom, is:
    - the W&L wordmark header, title and disclaimer;
    - a collapsed "Parameters" accordion;
    - the chat transcript;
    - an input row with a textbox and a submit button.

    Returns:
        The ``gr.Blocks`` app; the caller is responsible for launching it.
    """
    with gr.Blocks(css=_CUSTOM_CSS, theme="Soft") as block:
        gr.Markdown(_HEADER_HTML)
        gr.Markdown("<h1>Fizz Chatbot</h1>")
        gr.Markdown("<h6><i>Disclaimer: some information may be inaccurate</i></h6>")

        # NOTE(review): these sliders are rendered but not passed to the
        # submit handler, so they currently have no effect on generation.
        with gr.Accordion("Parameters", open=False, visible=True, elem_classes=["parameter-accordion"]):
            temperature = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=0.7,
                step=0.1,
                interactive=True,
                label="Temperature",
            )
            top_p = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=1.0,
                step=0.1,
                interactive=True,
                label="Top P",
            )
            max_new_tokens = gr.Slider(
                minimum=16,
                maximum=1028,
                value=128,
                step=32,
                interactive=True,
                label="Max tokens",
            )

        chatbot = gr.Chatbot(label="W&L AI")
        # Use one State component as both input and output. Two separate
        # gr.State() instances never share a value, so the handler would
        # always receive None and the conversation would never accumulate.
        history_state = gr.State([])

        # gr.Row is a context manager: components created inside the `with`
        # are laid out side by side.
        with gr.Row():
            message = gr.Textbox(label="", placeholder="Ask me anything about W&L here...", elem_id="input_box")
            submit = gr.Button("Submit Query", elem_classes="specific_button")

        # Send the message on button click or on Enter in the textbox.
        for trigger in (submit.click, message.submit):
            trigger(
                fn=message_and_history,
                inputs=[message, history_state],
                outputs=[chatbot, history_state],
            )
    return block
# Build the UI at import time so the `app` object is available to importers.
app = setup_interface()

if __name__ == "__main__":
    # Start the server only when run as a script (e.g. `python app.py`),
    # not when the module is merely imported.
    app.launch(debug=True)