import logging

import gradio as gr
import torch
import wandb
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# NOTE: committing a live API key is unsafe; prefer `wandb login` or the
# WANDB_API_KEY environment variable.
wandb.login(key='926692a60ef1695538fcd1ae6e75482b0741a3d3')
wandb.init(project='journal-finetune', entity='benbankston2')

# Initialize logging
logging.basicConfig(level=logging.INFO)

base_model_id = "microsoft/phi-2"

# Load the base model with 4-bit NF4 quantization so it fits on a single GPU.
model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    device_map="auto",
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_quant_type="nf4",
    ),
    torch_dtype=torch.bfloat16,
    # FA2 does not work yet
    # attn_implementation="flash_attention_2",
)

# Attach the fine-tuned LoRA adapter to the quantized base model. No explicit
# `.to("cuda")` here: `device_map="auto"` already placed the weights, and
# `.to()` raises an error on 4-bit bitsandbytes models.
model = PeftModel.from_pretrained(model, "phi2-journal-finetune/checkpoint-175")

tokenizer = AutoTokenizer.from_pretrained(
    base_model_id,
    add_bos_token=True,
    trust_remote_code=True,
    use_fast=False,
)
tokenizer.pad_token = tokenizer.eos_token


def generate_text(prompt, temperature=0.7, top_p=1.0, max_new_tokens=128):
    """Generate a completion for `prompt` with the fine-tuned model."""
    logging.info(f"Generating text for prompt: {prompt}")
    model_input = tokenizer(prompt, return_tensors="pt").to("cuda")
    # Sampling arguments belong to `generate`, not `tokenizer.decode`, and
    # `eos_token_id` must be a token id rather than the token string.
    output_ids = model.generate(
        **model_input,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=1.11,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,
    )
    response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    logging.info(f"Generated text: {response}")
    return response


def message_and_history(input_text, history, temperature, top_p, max_new_tokens):
    """Manage message history and generate responses."""
    if history is None:
        history = []
    # Only the latest message is sent to the model; earlier turns are kept
    # for display but not fed back as context.
    output = generate_text(input_text, temperature, top_p, max_new_tokens)
    # gr.Chatbot expects (user_message, bot_message) pairs.
    history.append((input_text, output))
    return history, history


def setup_interface():
    with gr.Blocks(css='''
        body { font-family: 'Arial', sans-serif; background: #f1f1f1; }
        .container {
            max-width: 800px; margin: auto; padding: 20px;
            background-size: cover; background-repeat: no-repeat;
            border-radius: 8px; box-shadow: 0 4px 8px rgba(0,0,0,0.1);
        }
        h1 { color: #003F87; text-align: center; margin-bottom: 20px; }
        .gr-textbox {
            box-shadow: inset 0 2px 3px rgba(0,0,0,0.1); border-radius: 4px;
            border: 1px solid #7FB2E5; padding: 10px; width: auto;
            background-color: #fff;
        }
        .gradio-chatbox { background-color: #f0f0f0; }
        /* Custom CSS for the accordion title */
        .parameter-accordion .gr-accordion-title { font-weight: bold; font-size: 18px; }
        .gradio-chatbox-message-user { background-color: #4A90E2; color: white; }
        .gradio-chatbox-message-bot { background-color: #FFFFFF; color: black; }
        .gradio-chatbox-message { border-radius: 10px; padding: 10px; margin-bottom: 8px; }
    ''', theme=gr.themes.Soft()) as block:
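        # Page header. The original HTML markup was stripped in transit; a
        # plain <h1> is assumed here so the `h1` rule in the CSS above applies.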
""") gr.Markdown("

Fizz Chatbot

") gr.Markdown("
Disclaimer: some information may be inaccurate
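        # Generation parameters, collapsed by default. The sliders are passed
        # to the click handler below so they actually affect generation.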
") with gr.Accordion("Parameters", open=False, visible=True, elem_classes=["parameter-accordion"]) as parameter_row: temperature = gr.Slider( minimum = 0.0, maximum = 1.0, value = 0.7, step=0.1, interactive = True, label="Temperature" ) top_p = gr.Slider( minimum = 0.0, maximum = 1.0, value = 1.0, step=0.1, interactive = True, label="Top P" ) max_new_tokens = gr.Slider( minimum = 16, maximum = 1028, value = 128, step= 32, interactive = True, label="Max tokens" ) chatbot = gr.Chatbot(label="W&L AI") message = gr.Textbox(label="", placeholder="Ask me anything about W&L here...", elem_id="input_box") submit = gr.Button("Submit Query", elem_classes="specific_button") submit.click( fn=message_and_history, inputs=[message, gr.State()], outputs=[chatbot, gr.State()] ) gr.Row([chatbot]) gr.Row([message, submit]) return block app = setup_interface() app.launch(debug=True)