# fastapi-demo / app.py
# Source: Hugging Face Space (ka1kuk), commit 8f39c3d ("Update app.py"), 4.49 kB
import traceback

import gradio as gr
import numpy as np
import spaces
import torch
from PIL import Image
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
# 4-bit pre-quantized Llama-3 8B checkpoint (bitsandbytes format).
model_name = "unsloth/llama-3-8b-bnb-4bit"
# Initialize tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Initialize BitsAndBytesConfig if needed
bnb_config = BitsAndBytesConfig(load_in_4bit=True)
# Load model with the correct configuration
# NOTE(review): no device_map is given, so placement relies on the
# quantized loader's defaults — confirm the model lands on GPU.
model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=bnb_config)
model.eval()  # inference only; disables dropout etc.
# Generic user-facing message returned when chat() hits an exception.
ERROR_MSG = "An error occurred, please try again."
@spaces.GPU
def chat(img, msgs, params=None):
    """Run one multimodal chat turn against the loaded model.

    Parameters
    ----------
    img : PIL.Image.Image | None
        The user-uploaded image; ``None`` yields an error string.
    msgs : list[dict]
        Conversation history as ``{"role", "content"}`` dicts.
    params : dict | None
        Generation parameters forwarded to ``model.chat``; defaults to
        non-streaming 3-beam search when omitted.

    Returns
    -------
    str
        The model's answer, an invalid-image notice, or ERROR_MSG on failure.
    """
    default_params = {"stream": False, "sampling": False, "num_beams": 3, "repetition_penalty": 1.2, "max_new_tokens": 1024}
    if params is None:
        params = default_params
    if img is None:
        return "Error, invalid image, please upload a new image"
    try:
        image = img.convert('RGB')
        # HWC uint8 -> NCHW float32 in [0, 1] on the GPU.
        image_tensor = torch.tensor(np.array(image)).float().div(255).unsqueeze(0).permute(0, 3, 1, 2).to(device='cuda')
        answer = model.chat(
            image=image_tensor,
            msgs=msgs,
            tokenizer=tokenizer,
            **params
        )
        # answer may be a plain string or an iterable of chunks (stream=True);
        # join handles both and replaces the original quadratic char-by-char
        # accumulation loop.
        return "".join(answer)
    except Exception as err:
        print(err)
        traceback.print_exc()
        return ERROR_MSG
def upload_img(image, _app_session):
    """Store a freshly uploaded image in the session and reset the chat state.

    Converts the numpy array from the Gradio image widget into a PIL image,
    clears the status flag and conversation context, and returns a
    confirmation message together with the updated session dict.
    """
    pil_image = Image.fromarray(image)
    _app_session.update(sts=None, ctx=[], img=pil_image)
    return 'Image uploaded successfully, you can talk to me now', _app_session
def respond(question, _app_cfg, params_form, num_beams, repetition_penalty, repetition_penalty_2, top_p, top_k, temperature):
    """Answer a user question about the previously uploaded image.

    Parameters
    ----------
    question : str
        The user's text question.
    _app_cfg : dict
        Session state holding 'img' (PIL image) and 'ctx' (message history).
    params_form : str
        Either 'Beam Search' or 'Sampling'; selects the generation params.
    num_beams, repetition_penalty, repetition_penalty_2, top_p, top_k, temperature
        Slider values forwarded into the generation parameter dict.

    Returns
    -------
    tuple[str, dict]
        The model response (or a prompt to upload an image) and the
        updated session state.
    """
    # Guard clause: nothing to talk about until an image is uploaded.
    if _app_cfg.get('img', None) is None:
        return 'Please upload an image to start', _app_cfg
    # Copy-then-append works for both empty and non-empty histories; the
    # original's if/else over an empty context was redundant.
    _context = _app_cfg['ctx'].copy()
    _context.append({"role": "user", "content": question})
    if params_form == 'Beam Search':
        params = {
            'sampling': False,
            'stream': False,
            'num_beams': num_beams,
            'repetition_penalty': repetition_penalty,
            "max_new_tokens": 896
        }
    else:
        params = {
            'sampling': True,
            'stream': True,
            'top_p': top_p,
            'top_k': top_k,
            'temperature': temperature,
            'repetition_penalty': repetition_penalty_2,
            "max_new_tokens": 896
        }
    response = chat(_app_cfg['img'], _context, params)
    _app_cfg['ctx'] = _context
    return response, _app_cfg
app_cfg = {}

# Build the UI inside one gr.Blocks context. The original created the
# components at module level, wired .click() handlers on them outside any
# Blocks graph (events are only dispatched for components that belong to a
# Blocks app), and then passed fn=None to gr.Interface, which is invalid.
# A single Blocks demo with the same components and wiring fixes both.
with gr.Blocks() as demo:
    app_cfg_store = gr.State(app_cfg)
    image_upload = gr.Image(type="numpy", label="Upload an image")
    question_input = gr.Textbox(label="Enter your question")
    params_form = gr.Radio(choices=["Beam Search", "Sampling"], label="Parameters Form", value="Beam Search")
    num_beams = gr.Slider(2, 10, step=1, value=3, label="Number of Beams")
    repetition_penalty = gr.Slider(1.0, 2.0, step=0.1, value=1.2, label="Repetition Penalty")
    repetition_penalty_2 = gr.Slider(1.0, 2.0, step=0.1, value=1.2, label="Repetition Penalty (Sampling)")
    top_p = gr.Slider(0.5, 1.0, step=0.1, value=0.9, label="Top P")
    top_k = gr.Slider(10, 50, step=5, value=40, label="Top K")
    temperature = gr.Slider(0.5, 1.0, step=0.1, value=0.7, label="Temperature")
    chatbot_output = gr.Textbox(label="Response")
    upload_button = gr.Button("Upload Image")
    upload_button.click(upload_img, inputs=[image_upload, app_cfg_store], outputs=[chatbot_output, app_cfg_store])
    respond_button = gr.Button("Send Question")
    respond_button.click(
        respond,
        inputs=[question_input, app_cfg_store, params_form, num_beams, repetition_penalty, repetition_penalty_2, top_p, top_k, temperature],
        outputs=[chatbot_output, app_cfg_store]
    )

demo.launch(share=True)