Spaces:
Runtime error
Runtime error
import logging
import os

import gradio as gr
import peft
import torch
import transformers
import wandb
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
# Initialize logging first so start-up messages below are captured.
logging.basicConfig(level=logging.INFO)

# Read the Weights & Biases API key from the environment (e.g. a Space secret)
# instead of committing it to source. The key that was previously hard-coded
# here has been exposed and should be revoked.
wandb.login(key=os.environ.get("WANDB_API_KEY"))
wandb.init(project='journal-finetune', entity='benbankston2')
base_model_id = "microsoft/phi-2"

# Load the base model quantized to 4-bit NF4 (bfloat16 compute). With
# device_map="auto", accelerate places the weights at load time, so the model
# is not moved again afterwards.
model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    device_map="auto",
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_quant_type="nf4",
    ),
    torch_dtype=torch.bfloat16,
    # FA2 does not work yet
    # attn_implementation="flash_attention_2",
)

# Attach the fine-tuned adapter weights from the local checkpoint directory.
model = PeftModel.from_pretrained(model, "phi2-journal-finetune/checkpoint-175")
# NOTE: deliberately no `model.to("cuda")`. A bitsandbytes-quantized model
# placed via device_map is already on its device(s); calling `.to()` is
# redundant and can be rejected or undo accelerate's placement.

tokenizer = AutoTokenizer.from_pretrained(
    base_model_id,
    add_bos_token=True,
    trust_remote_code=True,
    use_fast=False,
)
# Reuse the EOS token for padding (presumably the base tokenizer defines no
# pad token — confirm if the base model changes).
tokenizer.pad_token = tokenizer.eos_token
def generate_text(prompt, *, max_new_tokens=256, repetition_penalty=1.11,
                  temperature=None, top_p=None):
    """Generate a reply to *prompt* with the fine-tuned model.

    Decoding is greedy by default. Passing a positive ``temperature`` switches
    to sampling.

    Args:
        prompt: Text fed to the model.
        max_new_tokens: Maximum number of tokens to generate.
        repetition_penalty: Passed through to ``model.generate``.
        temperature: If given and > 0, enables sampling at this temperature.
        top_p: Nucleus-sampling threshold; only used when sampling.

    Returns:
        The newly generated text only (the prompt is not echoed back), with
        special tokens removed.
    """
    logging.info("Generating text for prompt: %s", prompt)
    # Put the inputs on the model's device rather than assuming "cuda".
    model_input = tokenizer(prompt, return_tensors="pt").to(model.device)

    # All generation options go to generate(). Previously they were passed to
    # decode() by mistake and silently ignored. eos_token_id must be the token
    # ID, not the token string.
    generate_kwargs = {
        "max_new_tokens": max_new_tokens,
        "repetition_penalty": repetition_penalty,
        "eos_token_id": tokenizer.eos_token_id,
        "pad_token_id": tokenizer.eos_token_id,
    }
    if temperature is not None and temperature > 0:
        generate_kwargs["do_sample"] = True
        generate_kwargs["temperature"] = temperature
        if top_p is not None:
            generate_kwargs["top_p"] = top_p

    output_ids = model.generate(**model_input, **generate_kwargs)

    # For a causal LM, generate() returns prompt + continuation; keep only the
    # tokens produced after the prompt.
    prompt_length = model_input["input_ids"].shape[1]
    response = tokenizer.decode(output_ids[0][prompt_length:], skip_special_tokens=True)
    logging.info("Generated text: %s", response)
    return response
def message_and_history(input_text, history, feedback = None):
    """Generate a bot reply to *input_text* and append the turn to the history.

    Args:
        input_text: The user's latest message.
        history: List of ``(user_message, bot_message)`` pairs from earlier
            turns, or ``None`` on the first call.
        feedback: Unused; kept for interface compatibility.

    Returns:
        ``(history, history)``: the updated history, once for the Chatbot
        display and once for the session state.
    """
    # Work on a copy so the caller's list is not mutated in place.
    history = [] if history is None else list(history)
    # Only the latest message is sent to the model; earlier turns are not
    # used as context.
    output = generate_text(input_text)
    # gr.Chatbot renders each entry as one (user message, bot message) pair.
    history.append((input_text, output))
    return history, history
def setup_interface():
    """Build the Gradio Blocks UI for the Fizz chatbot.

    Layout: banner image, title and disclaimer; a collapsed "Parameters"
    accordion with sliders; the chat window; and a row holding the message
    box and the submit button.

    Returns:
        The ``gr.Blocks`` app; call ``.launch()`` on it to serve.
    """
    css = '''
body { font-family: 'Arial', sans-serif; background: #f1f1f1; }
.container { max-width: 800px; margin: auto; padding: 20px; background-size: cover; background-repeat: no-repeat; border-radius: 8px; box-shadow: 0 4px 8px rgba(0,0,0,0.1); }
h1 { color: #003F87; text-align: center; margin-bottom: 20px; }
.gr-textbox { box-shadow: inset 0 2px 3px rgba(0,0,0,0.1); border-radius: 4px; border: 1px solid #7FB2E5; padding: 10px; width: auto; background-color: #fff; }
.gradio-chatbox { background-color: #f0f0f0; }
.parameter-accordion .gr-accordion-title { font-weight: bold; font-size: 18px; } /* Custom CSS for accordion title */
.gradio-chatbox-message-user { background-color: #4A90E2; color: white; }
.gradio-chatbox-message-bot { background-color: #FFFFFF; color: black; }
.gradio-chatbox-message { border-radius: 10px; padding: 10px; margin-bottom: 8px; }
'''
    with gr.Blocks(css=css, theme="Soft") as block:
        gr.Markdown("""
<div style="background-image: url('https://my.wlu.edu/Images/communications/publications/graphic-identity/300-dpi-wordmark-blue.png');
background-size: contain;
background-repeat: no-repeat;
background-position: center;
text-align: center;
height: 100px;
line-height: 100px;
font-size: 36px;
color: white;
font-family: Arial, sans-serif;">
</div>
""")
        gr.Markdown("<h1>Fizz Chatbot</h1>")
        gr.Markdown("<h6><i>Disclaimer: some information may be inaccurate</i></h6>")

        with gr.Accordion("Parameters", open=False, visible=True, elem_classes=["parameter-accordion"]):
            # NOTE(review): these sliders are rendered but not passed to the
            # submit handler, so they do not yet affect generation.
            temperature = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=0.7,
                step=0.1,
                interactive=True,
                label="Temperature",
            )
            top_p = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=1.0,
                step=0.1,
                interactive=True,
                label="Top P",
            )
            max_new_tokens = gr.Slider(
                minimum=16,
                maximum=1028,
                value=128,
                step=32,
                interactive=True,
                label="Max tokens",
            )

        chatbot = gr.Chatbot(label="W&L AI")
        # The same State object must be both an input and an output of the
        # handler. Two separate gr.State() instances never share the history.
        history_state = gr.State([])

        # gr.Row is a layout context manager; it does not accept a list of
        # components as a positional argument.
        with gr.Row():
            message = gr.Textbox(label="", placeholder="Ask me anything about W&L here...", elem_id="input_box")
            submit = gr.Button("Submit Query", elem_classes="specific_button")

        submit.click(
            fn=message_and_history,
            inputs=[message, history_state],
            outputs=[chatbot, history_state],
        )
    return block
# Build the UI at import time (so `app` stays available as a module attribute),
# but only start the blocking server when this file is run as a script.
app = setup_interface()

if __name__ == "__main__":
    app.launch(debug=True)