import traceback

import gradio as gr
import numpy as np
import spaces
import torch
from PIL import Image
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

model_name = "unsloth/llama-3-8b-bnb-4bit"

# Initialize the tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Quantization config for 4-bit loading
bnb_config = BitsAndBytesConfig(load_in_4bit=True)

# Load the model with the quantization configuration
model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=bnb_config)
model.eval()

ERROR_MSG = "An error occurred, please try again."


@spaces.GPU
def chat(img, msgs, params=None):
    default_params = {
        "stream": False,
        "sampling": False,
        "num_beams": 3,
        "repetition_penalty": 1.2,
        "max_new_tokens": 1024,
    }
    if params is None:
        params = default_params

    if img is None:
        return "Error, invalid image, please upload a new image"

    try:
        image = img.convert("RGB")
        # Convert the PIL image to a normalized float32 NCHW tensor on the GPU
        image_tensor = (
            torch.tensor(np.array(image))
            .float()
            .div(255)
            .unsqueeze(0)
            .permute(0, 3, 1, 2)
            .to(device="cuda")
        )
        # NOTE: `model.chat` assumes a multimodal checkpoint that exposes a
        # chat API (e.g. MiniCPM-V loaded with trust_remote_code=True); a
        # plain Llama-3 causal LM loaded via AutoModelForCausalLM does not
        # provide this method.
        answer = model.chat(
            image=image_tensor,
            msgs=msgs,
            tokenizer=tokenizer,
            **params,
        )
        # Accumulate the output; this works both when `answer` is a plain
        # string and when it is a generator of streamed text chunks.
        generated_text = ""
        for char in answer:
            generated_text += char
        return generated_text
    except Exception as err:
        print(err)
        traceback.print_exc()
        return ERROR_MSG


def upload_img(image, _app_session):
    # Gradio delivers the image as a numpy array; convert it back to PIL
    image = Image.fromarray(image)
    _app_session["sts"] = None
    _app_session["ctx"] = []
    _app_session["img"] = image
    return "Image uploaded successfully, you can talk to me now", _app_session


def respond(question, _app_cfg, params_form, num_beams, repetition_penalty,
            repetition_penalty_2, top_p, top_k, temperature):
    if _app_cfg.get("img", None) is None:
        return "Please upload an image to start", _app_cfg

    # Extend the conversation context with the new user turn
    _context = _app_cfg["ctx"].copy()
    _context.append({"role": "user", "content": question})

    if params_form == "Beam Search":
        params = {
            "sampling": False,
            "stream": False,
            "num_beams": num_beams,
            "repetition_penalty": repetition_penalty,
            "max_new_tokens": 896,
        }
    else:
        params = {
            "sampling": True,
            "stream": True,
            "top_p": top_p,
            "top_k": top_k,
            "temperature": temperature,
            "repetition_penalty": repetition_penalty_2,
            "max_new_tokens": 896,
        }

    response = chat(_app_cfg["img"], _context, params)
    _app_cfg["ctx"] = _context
    return response, _app_cfg


app_cfg = {}

# Build the Gradio UI; components and event listeners must be created inside
# a Blocks context for the click handlers to be registered.
with gr.Blocks() as demo:
    image_upload = gr.Image(type="numpy", label="Upload an image")
    question_input = gr.Textbox(label="Enter your question")
    params_form = gr.Radio(choices=["Beam Search", "Sampling"],
                           label="Parameters Form", value="Beam Search")
    num_beams = gr.Slider(2, 10, step=1, value=3, label="Number of Beams")
    repetition_penalty = gr.Slider(1.0, 2.0, step=0.1, value=1.2,
                                   label="Repetition Penalty")
    repetition_penalty_2 = gr.Slider(1.0, 2.0, step=0.1, value=1.2,
                                     label="Repetition Penalty (Sampling)")
    top_p = gr.Slider(0.5, 1.0, step=0.1, value=0.9, label="Top P")
    top_k = gr.Slider(10, 50, step=5, value=40, label="Top K")
    temperature = gr.Slider(0.5, 1.0, step=0.1, value=0.7, label="Temperature")
    chatbot_output = gr.Textbox(label="Response")
    app_cfg_store = gr.State(app_cfg)

    upload_button = gr.Button("Upload Image")
    upload_button.click(
        upload_img,
        inputs=[image_upload, app_cfg_store],
        outputs=[chatbot_output, app_cfg_store],
    )

    respond_button = gr.Button("Send Question")
    respond_button.click(
        respond,
        inputs=[question_input, app_cfg_store, params_form, num_beams,
                repetition_penalty, repetition_penalty_2, top_p, top_k,
                temperature],
        outputs=[chatbot_output, app_cfg_store],
    )

demo.launch(share=True)