import logging
import gradio as gr
import wandb
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel
import torch
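# Track this demo in Weights & Biases (assumes a cached login or a
# WANDB_API_KEY in the environment).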
wandb.login()
wandb.init(project='journal-finetune', entity='benbankston2')
# Initialize logging
logging.basicConfig(level=logging.INFO)
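# Load the phi-2 base model in 4-bit NF4 via bitsandbytes; device_map="auto"
# places the quantized weights on whatever accelerator is available.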
base_model_id = "microsoft/phi-2"
model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    device_map="auto",
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_quant_type="nf4",
    ),
    torch_dtype=torch.bfloat16,
    # FA2 does not work yet
    # attn_implementation="flash_attention_2",
)
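# Attach the LoRA adapter weights from the journal fine-tuning checkpoint on
# top of the quantized base model.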
model = PeftModel.from_pretrained(model, "phi2-journal-finetune/checkpoint-175")
# device_map="auto" already placed the weights on the GPU; calling .to("cuda")
# on a 4-bit quantized model raises a ValueError in transformers, so skip it.
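# phi-2's tokenizer has no dedicated pad token, so EOS is reused for padding.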
tokenizer = AutoTokenizer.from_pretrained(base_model_id, add_bos_token=True, trust_remote_code=True, use_fast=False)
tokenizer.pad_token = tokenizer.eos_token
def generate_text(prompt, temperature=0.7, top_p=1.0, max_new_tokens=128):
    """Generate a completion for `prompt` with the fine-tuned model."""
    logging.info(f"Generating text for prompt: {prompt}")
    model_input = tokenizer(prompt, return_tensors="pt").to("cuda")
    # Sampling knobs belong to model.generate(), not tokenizer.decode(), and
    # eos_token_id expects a token id rather than the pad-token string.
    # early_stopping is dropped: it only applies to beam search.
    output_ids = model.generate(
        **model_input,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=1.11,
        eos_token_id=tokenizer.eos_token_id,
    )
    response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    logging.info(f"Generated text: {response}")
    return response
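# Optional sanity check before wiring up the UI (hypothetical prompt, not part
# of the app flow):
# print(generate_text("What is Washington and Lee University known for?"))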
def message_and_history(input_text, history, temperature=0.7, top_p=1.0, max_new_tokens=128):
    """Append the new exchange to the chat history and return it twice:
    once for the Chatbot display and once for the session state."""
    if history is None:
        history = []
    # Only the latest prompt is sent to the model; the flattened history is
    # kept around in case full-context generation is re-enabled later.
    context = ''.join(text for pair in history for text in pair)
    output = generate_text(input_text, temperature, top_p, max_new_tokens)
    # gr.Chatbot renders (user_message, bot_message) pairs.
    history.append((input_text, output))
    return history, history
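# Build the Gradio Blocks UI: custom CSS, the W&L banner, a parameter
# accordion, and the chat widgets.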
def setup_interface():
    with gr.Blocks(css='''
        body { font-family: 'Arial', sans-serif; background: #f1f1f1; }
        .container { max-width: 800px; margin: auto; padding: 20px; background-size: cover; background-repeat: no-repeat; border-radius: 8px; box-shadow: 0 4px 8px rgba(0,0,0,0.1); }
        h1 { color: #003F87; text-align: center; margin-bottom: 20px; }
        .gr-textbox { box-shadow: inset 0 2px 3px rgba(0,0,0,0.1); border-radius: 4px; border: 1px solid #7FB2E5; padding: 10px; width: auto; background-color: #fff; }
        .gradio-chatbox { background-color: #f0f0f0; }
        .parameter-accordion .gr-accordion-title { font-weight: bold; font-size: 18px; } /* custom CSS for the accordion title; class selector needs the leading dot */
        .gradio-chatbox-message-user { background-color: #4A90E2; color: white; }
        .gradio-chatbox-message-bot { background-color: #FFFFFF; color: black; }
        .gradio-chatbox-message { border-radius: 10px; padding: 10px; margin-bottom: 8px; }
    ''', theme="soft") as block:
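        # W&L wordmark banner rendered as an inline-styled HTML block.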
gr.Markdown("""
<div style="background-image: url('https://my.wlu.edu/Images/communications/publications/graphic-identity/300-dpi-wordmark-blue.png');
background-size: contain;
background-repeat: no-repeat;
background-position: center;
text-align: center;
height: 100px;
line-height: 100px;
font-size: 36px;
color: white;
font-family: Arial, sans-serif;">
</div>
""")
gr.Markdown("<h1>Fizz Chatbot</h1>")
gr.Markdown("<h6><i>Disclaimer: some information may be inaccurate</i></h6>")
        with gr.Accordion("Parameters", open=False, visible=True, elem_classes=["parameter-accordion"]) as parameter_row:
            temperature = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=0.7,
                step=0.1,
                interactive=True,
                label="Temperature",
            )
            top_p = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=1.0,
                step=0.1,
                interactive=True,
                label="Top P",
            )
            max_new_tokens = gr.Slider(
                minimum=16,
                maximum=1024,
                value=128,
                step=32,
                interactive=True,
                label="Max new tokens",
            )
        chatbot = gr.Chatbot(label="W&L AI")
        message = gr.Textbox(label="", placeholder="Ask me anything about W&L here...", elem_id="input_box")
        submit = gr.Button("Submit Query", elem_classes="specific_button")
        # A single State component must be shared between inputs and outputs,
        # otherwise the chat history is discarded on every click.
        state = gr.State()
        submit.click(
            fn=message_and_history,
            inputs=[message, state, temperature, top_p, max_new_tokens],
            outputs=[chatbot, state],
        )
        return block
app = setup_interface()
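# debug=True keeps the process attached and prints server-side tracebacks.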
app.launch(debug=True)