import logging

import gradio as gr
import torch
import wandb
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

wandb.login(key='926692a60ef1695538fcd1ae6e75482b0741a3d3')
wandb.init(project='journal-finetune', entity='benbankston2')

# Initialize logging
logging.basicConfig(level=logging.INFO)

# Load the Phi-2 base model with 4-bit NF4 quantization so it fits on a single GPU.
base_model_id = "microsoft/phi-2"
model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    device_map="auto",
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_quant_type="nf4",
    ),
    torch_dtype=torch.bfloat16,
    # FA2 does not work yet
    # attn_implementation="flash_attention_2",
)

# Attach the fine-tuned LoRA adapter. No .to("cuda") here: device_map="auto"
# already placed the weights, and .to() is not supported on 4-bit models.
model = PeftModel.from_pretrained(model, "phi2-journal-finetune/checkpoint-175")

tokenizer = AutoTokenizer.from_pretrained(
    base_model_id,
    add_bos_token=True,
    trust_remote_code=True,
    use_fast=False,
)
tokenizer.pad_token = tokenizer.eos_token


def generate_text(prompt):
    """Generate a completion for `prompt` with the fine-tuned model."""
    logging.info(f"Generating text for prompt: {prompt}")
    model_input = tokenizer(prompt, return_tensors="pt").to("cuda")
    # Generation arguments belong to model.generate(); decoding options such as
    # skip_special_tokens belong to tokenizer.decode(). early_stopping only
    # applies to beam search, and temperature only takes effect with
    # do_sample=True, so both are dropped here.
    output_ids = model.generate(
        **model_input,
        max_new_tokens=256,
        repetition_penalty=1.11,
        eos_token_id=tokenizer.eos_token_id,  # pad_token (a string) is not a valid token id
    )
    response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    logging.info(f"Generated text: {response}")
    return response


def message_and_history(input_text, history, feedback=None):
    """Manage message history and generate responses."""
    if history is None:
        history = []
    # Flatten prior (speaker, message) turns in case the full conversation
    # should be used as context; generation currently conditions only on the
    # latest message (pass full_context instead of input_text to change that).
    full_context = ''.join(list(sum(history, ())) + [input_text])
    output = generate_text(input_text)
    history.append(("User", input_text))
    history.append(("Fizz Bot", output))
    return history, history


def setup_interface():
    with gr.Blocks(css='''
        body { font-family: 'Arial', sans-serif; background: #f1f1f1; }
        .container {
            max-width: 800px; margin: auto; padding: 20px;
            background-size: cover; background-repeat: no-repeat;
            border-radius: 8px; box-shadow: 0 4px 8px rgba(0,0,0,0.1);
        }
        h1 { color: #003F87; text-align: center; margin-bottom: 20px; }
        .gr-textbox {
            box-shadow: inset 0 2px 3px rgba(0,0,0,0.1); border-radius: 4px;
            border: 1px solid #7FB2E5; padding: 10px; width: auto;
            background-color: #fff;
        }
        .gradio-chatbox { background-color: #f0f0f0; }
        /* Custom CSS for accordion title */
        .parameter-accordion .gr-accordion-title { font-weight: bold; font-size: 18px; }
        .gradio-chatbox-message-user { background-color: #4A90E2; color: white; }
        .gradio-chatbox-message-bot { background-color: #FFFFFF; color: black; }
        .gradio-chatbox-message { border-radius: 10px; padding: 10px; margin-bottom: 8px; }
    ''', theme=gr.themes.Soft()) as block:
        gr.Markdown("""