WLUChatbot / app.py
lehmanc25's picture
working version
25485b9 verified
raw
history blame
No virus
5.17 kB
import logging
import gradio as gr
import wandb
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import peft
from peft import PeftModel
import torch
# Authenticate with Weights & Biases and open a run in the 'journal-finetune'
# project. NOTE(review): wandb.login() needs credentials (e.g. WANDB_API_KEY)
# when running non-interactively — confirm for the deployment environment.
wandb.login()
wandb.init(project='journal-finetune', entity='benbankston2')

# Initialize logging
logging.basicConfig(level=logging.INFO)

base_model_id = "microsoft/phi-2"
# Fine-tuned PEFT adapter checkpoint applied on top of the base model.
adapter_path = "phi2-journal-finetune/checkpoint-175"

# Load the base model quantized to 4-bit NF4 with bfloat16 compute.
# device_map="auto" lets accelerate place the weights on the available device(s).
model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    device_map="auto",
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_quant_type="nf4",
    ),
    torch_dtype=torch.bfloat16,
    # FA2 does not work yet
    # attn_implementation="flash_attention_2",
)
# Wrap the quantized base model with the fine-tuned adapter weights.
# No blanket model.to("cuda") afterwards: device_map="auto" has already placed
# the weights. Moving bitsandbytes-quantized models with .to() is discouraged by
# transformers, and with multiple GPUs it would fight the device map.
model = PeftModel.from_pretrained(model, adapter_path)

tokenizer = AutoTokenizer.from_pretrained(base_model_id, add_bos_token=True, trust_remote_code=True, use_fast=False)
# Reuse EOS as the padding token (presumably the tokenizer ships without one);
# generate_text relies on pad_token_id being set.
tokenizer.pad_token = tokenizer.eos_token
def generate_text(prompt, max_new_tokens=256, temperature=None, top_p=None):
    """Generate a reply to ``prompt`` with the module-level ``model``/``tokenizer``.

    With the default arguments, decoding is greedy with a repetition penalty of
    1.11 and at most 256 new tokens. These are the settings that actually took
    effect before: the other options used to be passed to ``decode()`` and were
    never applied to generation.

    Args:
        prompt: Text for the model to continue.
        max_new_tokens: Maximum number of tokens to generate.
        temperature: When truthy (not None and non-zero), enables sampling at
            this temperature. Otherwise decoding is greedy.
        top_p: Nucleus-sampling threshold. Only applied when sampling is enabled.

    Returns:
        The generated continuation only (the prompt is not echoed back), with
        special tokens removed.
    """
    logging.info("Generating text for prompt: %s", prompt)
    model_input = tokenizer(prompt, return_tensors="pt").to("cuda")

    # Generation options belong to generate(), not decode(). eos/pad must be
    # token *ids*; the token string is not a valid eos_token_id.
    generation_kwargs = {
        "max_new_tokens": max_new_tokens,
        "repetition_penalty": 1.11,
        "eos_token_id": tokenizer.eos_token_id,
        "pad_token_id": tokenizer.pad_token_id,
    }
    # temperature=0 would be rejected by sampling, so treat it as greedy.
    if temperature:
        generation_kwargs["do_sample"] = True
        generation_kwargs["temperature"] = temperature
        if top_p is not None:
            generation_kwargs["top_p"] = top_p
    # early_stopping only affects beam search (not used here), so it is omitted.

    output_ids = model.generate(**model_input, **generation_kwargs)

    # generate() returns prompt + continuation; decode only the new tokens so the
    # chatbot reply does not repeat the user's message.
    prompt_length = model_input["input_ids"].shape[-1]
    response = tokenizer.decode(output_ids[0][prompt_length:], skip_special_tokens=True)
    logging.info("Generated text: %s", response)
    return response
def message_and_history(input_text, history, feedback=None):
    """Generate a reply to ``input_text`` and record the exchange in ``history``.

    Args:
        input_text: The user's latest message.
        history: List of ``(user_message, bot_message)`` pairs, the format
            ``gr.Chatbot`` renders, or ``None`` to start a new conversation.
            Appended to in place.
        feedback: Unused. Kept so existing callers remain compatible.

    Returns:
        ``(history, history)``: the same list twice, once for the Chatbot
        display and once for the session State.
    """
    if history is None:
        history = []
    # NOTE(review): only the latest message is sent to the model; earlier
    # turns are not included in the prompt.
    output = generate_text(input_text)
    # One (user, bot) pair per exchange, so gr.Chatbot shows the user's text on
    # the user side and the model's reply on the bot side.
    history.append((input_text, output))
    return history, history
def setup_interface():
    """Build the Gradio Blocks UI for the Fizz chatbot.

    The layout contains a W&L wordmark banner, a title and disclaimer, and a
    collapsed "Parameters" accordion. Below that are the chat window and a
    textbox plus submit button wired to ``message_and_history``.

    Returns:
        The constructed (not yet launched) ``gr.Blocks`` app.
    """
    with gr.Blocks(css='''
body { font-family: 'Arial', sans-serif; background: #f1f1f1; }
.container { max-width: 800px; margin: auto; padding: 20px; background-size: cover; background-repeat: no-repeat; border-radius: 8px; box-shadow: 0 4px 8px rgba(0,0,0,0.1); }
h1 { color: #003F87; text-align: center; margin-bottom: 20px; }
.gr-textbox { box-shadow: inset 0 2px 3px rgba(0,0,0,0.1); border-radius: 4px; border: 1px solid #7FB2E5; padding: 10px; width: auto; background-color: #fff; }
.gradio-chatbox { background-color: #f0f0f0; }
.parameter-accordion .gr-accordion-title { font-weight: bold; font-size: 18px; } /* Custom CSS for accordion title */
.gradio-chatbox-message-user { background-color: #4A90E2; color: white; }
.gradio-chatbox-message-bot { background-color: #FFFFFF; color: black; }
.gradio-chatbox-message { border-radius: 10px; padding: 10px; margin-bottom: 8px; }
''', theme="Soft") as block:
        gr.Markdown("""
<div style="background-image: url('https://my.wlu.edu/Images/communications/publications/graphic-identity/300-dpi-wordmark-blue.png');
background-size: contain;
background-repeat: no-repeat;
background-position: center;
text-align: center;
height: 100px;
line-height: 100px;
font-size: 36px;
color: white;
font-family: Arial, sans-serif;">
</div>
""")
        gr.Markdown("<h1>Fizz Chatbot</h1>")
        gr.Markdown("<h6><i>Disclaimer: some information may be inaccurate</i></h6>")

        # NOTE(review): these sliders are displayed but not yet passed to the
        # click handler, so they do not affect generation.
        with gr.Accordion("Parameters", open=False, visible=True, elem_classes=["parameter-accordion"]):
            temperature = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=0.7,
                step=0.1,
                interactive=True,
                label="Temperature",
            )
            top_p = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=1.0,
                step=0.1,
                interactive=True,
                label="Top P",
            )
            max_new_tokens = gr.Slider(
                minimum=16,
                maximum=1028,
                value=128,
                step=32,
                interactive=True,
                label="Max tokens",
            )

        chatbot = gr.Chatbot(label="W&L AI")
        # Use ONE State component as both input and output so the conversation
        # persists between submits. Two separate gr.State() objects meant the
        # handler always received None.
        history = gr.State([])

        # gr.Row is a layout context manager: components must be created inside
        # it to be placed side by side.
        with gr.Row():
            message = gr.Textbox(label="", placeholder="Ask me anything about W&L here...", elem_id="input_box")
            submit = gr.Button("Submit Query", elem_classes="specific_button")

        submit.click(
            fn=message_and_history,
            inputs=[message, history],
            outputs=[chatbot, history],
        )
    return block
# Build the UI at module level so `app` stays importable (e.g. by the gradio
# CLI). Only start the blocking server when this file is run as a script.
app = setup_interface()
if __name__ == "__main__":
    app.launch(debug=True)