# HuggingFace Spaces demo (the hosted Space was showing "Runtime error" when captured).
import traceback

import gradio as gr
import numpy as np
import spaces
import torch
from PIL import Image
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
# Checkpoint to serve: a 4-bit pre-quantized Llama-3 8B.
# NOTE(review): this is a text-only causal LM; the chat() helper below calls
# model.chat(image=...), which plain AutoModelForCausalLM does not provide —
# verify the intended checkpoint (a vision-language model loaded with
# trust_remote_code would expose such a method).
model_name = "unsloth/llama-3-8b-bnb-4bit"

# Tokenizer matching the checkpoint.
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Explicit 4-bit quantization config (the checkpoint is already bnb-4bit).
bnb_config = BitsAndBytesConfig(load_in_4bit=True)

# Load the model and switch to inference mode (disables dropout etc.).
model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=bnb_config)
model.eval()

# User-facing message returned by chat() when generation fails.
ERROR_MSG = "An error occurred, please try again."
def chat(img, msgs, params=None):
    """Run one chat turn over the uploaded image.

    Args:
        img: PIL image to talk about, or ``None`` if nothing was uploaded.
        msgs: conversation history as a list of ``{"role", "content"}`` dicts.
        params: generation keyword arguments; a beam-search default is used
            when ``None``.

    Returns:
        The generated answer string, an "invalid image" message when ``img``
        is ``None``, or ``ERROR_MSG`` if generation raised.
    """
    if params is None:
        params = {
            "stream": False,
            "sampling": False,
            "num_beams": 3,
            "repetition_penalty": 1.2,
            "max_new_tokens": 1024,
        }
    if img is None:
        return "Error, invalid image, please upload a new image"
    try:
        image = img.convert('RGB')
        # HWC uint8 -> float32 NCHW in [0, 1] on the GPU.
        # NOTE(review): hard-codes CUDA; fails on CPU-only hosts — confirm
        # the deployment always has a GPU.
        image_tensor = (
            torch.tensor(np.array(image))
            .float()
            .div(255)
            .unsqueeze(0)
            .permute(0, 3, 1, 2)
            .to(device='cuda')
        )
        # NOTE(review): model.chat(...) is not part of the standard
        # AutoModelForCausalLM API — confirm the loaded checkpoint exposes it.
        answer = model.chat(
            image=image_tensor,
            msgs=msgs,
            tokenizer=tokenizer,
            **params,
        )
        # answer may be a plain string or a stream of chunks; join() handles
        # both without the original quadratic char-by-char concatenation.
        return "".join(answer)
    except Exception as err:
        print(err)
        traceback.print_exc()
        return ERROR_MSG
def upload_img(image, _app_session):
    """Store a freshly uploaded image in the per-session state.

    Args:
        image: HxWxC numpy array delivered by the Gradio image component.
        _app_session: mutable dict holding this session's state.

    Returns:
        ``(status_message, updated_session)`` matching the Gradio outputs.
    """
    # A new image invalidates any previous conversation, so reset state.
    _app_session['sts'] = None
    _app_session['ctx'] = []
    _app_session['img'] = Image.fromarray(image)
    return 'Image uploaded successfully, you can talk to me now', _app_session
def respond(question, _app_cfg, params_form, num_beams, repetition_penalty, repetition_penalty_2, top_p, top_k, temperature):
    """Answer a question about the previously uploaded image.

    Args:
        question: the user's question text.
        _app_cfg: session state dict (``img``, ``ctx``, ``sts``).
        params_form: ``'Beam Search'`` or anything else for sampling.
        num_beams, repetition_penalty: beam-search knobs.
        repetition_penalty_2, top_p, top_k, temperature: sampling knobs.

    Returns:
        ``(response_text, updated_config)`` matching the Gradio outputs.
    """
    if _app_cfg.get('img', None) is None:
        return 'Please upload an image to start', _app_cfg
    # Extend a copy of the running conversation with the new user turn.
    # (Appending to the copied list covers both the empty and non-empty
    # cases; the original's if/else branches were redundant.)
    _context = _app_cfg['ctx'].copy()
    _context.append({"role": "user", "content": question})
    if params_form == 'Beam Search':
        params = {
            'sampling': False,
            'stream': False,
            'num_beams': num_beams,
            'repetition_penalty': repetition_penalty,
            'max_new_tokens': 896,
        }
    else:
        params = {
            'sampling': True,
            'stream': True,
            'top_p': top_p,
            'top_k': top_k,
            'temperature': temperature,
            'repetition_penalty': repetition_penalty_2,
            'max_new_tokens': 896,
        }
    response = chat(_app_cfg['img'], _context, params)
    # Record the assistant turn too (the original kept only user turns, so
    # follow-up questions never saw earlier answers). Skip failed turns.
    if response != ERROR_MSG:
        _context.append({"role": "assistant", "content": response})
    _app_cfg['ctx'] = _context
    return response, _app_cfg
# Default per-session state; gr.State gives each browser session its own copy.
app_cfg = {}

# Build the UI inside a Blocks context. The original created components and
# wired .click() handlers at module level and then passed fn=None to
# gr.Interface: event listeners registered outside a Blocks/Interface context
# are ignored, and gr.Interface requires a callable fn — the combination is
# the Space's startup "Runtime error".
with gr.Blocks() as demo:
    app_cfg_store = gr.State(app_cfg)

    image_upload = gr.Image(type="numpy", label="Upload an image")
    question_input = gr.Textbox(label="Enter your question")
    params_form = gr.Radio(choices=["Beam Search", "Sampling"], label="Parameters Form", value="Beam Search")
    num_beams = gr.Slider(2, 10, step=1, value=3, label="Number of Beams")
    repetition_penalty = gr.Slider(1.0, 2.0, step=0.1, value=1.2, label="Repetition Penalty")
    repetition_penalty_2 = gr.Slider(1.0, 2.0, step=0.1, value=1.2, label="Repetition Penalty (Sampling)")
    top_p = gr.Slider(0.5, 1.0, step=0.1, value=0.9, label="Top P")
    top_k = gr.Slider(10, 50, step=5, value=40, label="Top K")
    temperature = gr.Slider(0.5, 1.0, step=0.1, value=0.7, label="Temperature")
    chatbot_output = gr.Textbox(label="Response")

    upload_button = gr.Button("Upload Image")
    upload_button.click(
        upload_img,
        inputs=[image_upload, app_cfg_store],
        outputs=[chatbot_output, app_cfg_store],
    )

    respond_button = gr.Button("Send Question")
    respond_button.click(
        respond,
        inputs=[question_input, app_cfg_store, params_form, num_beams, repetition_penalty, repetition_penalty_2, top_p, top_k, temperature],
        outputs=[chatbot_output, app_cfg_store],
    )

demo.launch(share=True)