TobyYang7 committed on
Commit
03d6908
1 Parent(s): 50504f5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -14
app.py CHANGED
@@ -4,6 +4,7 @@ from llava_llama3.model.builder import load_pretrained_model
4
  from PIL import Image
5
  import torch
6
 
 
7
  model_path = "TheFinAI/FinLLaVA"
8
  device = "cuda"
9
  conv_mode = "llama_3"
@@ -12,6 +13,7 @@ max_new_tokens = 512
12
  load_8bit = False
13
  load_4bit = False
14
 
 
15
  tokenizer, llava_model, image_processor, context_len = load_pretrained_model(
16
  model_path,
17
  None,
@@ -21,7 +23,8 @@ tokenizer, llava_model, image_processor, context_len = load_pretrained_model(
21
  device=device
22
  )
23
 
24
- def predict(image, text):
 
25
  output = chat_llava(
26
  args=None,
27
  image_file=image,
@@ -31,23 +34,26 @@ def predict(image, text):
31
  image_processor=image_processor,
32
  context_len=context_len
33
  )
34
- return output
 
35
 
 
 
 
 
 
 
36
 
37
- chatbot = gr.Chatbot(scale=1)
38
- chat_input = gr.MultimodalTextbox(interactive=True, file_types=["image"], placeholder="Enter message or upload file...", show_label=False)
39
 
40
- with gr.Blocks(fill_height=True) as demo:
41
- gr.ChatInterface(
42
- fn=predict,
43
- title="FinLLaVA",
44
- examples=[{"text": "What is on the flower?", "files": ["./bee.jpg"]},
45
- {"text": "How to make this pastry?", "files": ["./baklava.png"]}],
46
- stop_btn="Stop Generation",
47
- multimodal=True,
48
- textbox=chat_input,
49
- chatbot=chatbot,
50
  )
51
 
 
52
  demo.queue(api_open=False)
53
  demo.launch(show_api=False, share=False)
 
4
  from PIL import Image
5
  import torch
6
 
7
+ # Model configuration
8
  model_path = "TheFinAI/FinLLaVA"
9
  device = "cuda"
10
  conv_mode = "llama_3"
 
13
  load_8bit = False
14
  load_4bit = False
15
 
16
+ # Load the pretrained model
17
  tokenizer, llava_model, image_processor, context_len = load_pretrained_model(
18
  model_path,
19
  None,
 
23
  device=device
24
  )
25
 
26
+ # Define the prediction function
27
+ def predict(image, text, history):
28
  output = chat_llava(
29
  args=None,
30
  image_file=image,
 
34
  image_processor=image_processor,
35
  context_len=context_len
36
  )
37
+ history.append((text, output))
38
+ return history, gr.update(value="")
39
 
40
+ # Create the Gradio interface
41
+ with gr.Blocks() as demo:
42
+ chatbot = gr.Chatbot(label="FinLLaVA Chatbot")
43
+ image_input = gr.Image(type="filepath", label="Upload Image")
44
+ text_input = gr.Textbox(label="Enter your message")
45
+ submit_btn = gr.Button("Submit")
46
 
47
+ # Define interaction: when submit is clicked, call predict and update the chatbot
48
+ submit_btn.click(fn=predict, inputs=[image_input, text_input, chatbot], outputs=[chatbot, text_input])
49
 
50
+ # Add example inputs
51
+ gr.Examples(
52
+ examples=[["./bee.jpg", "What is on the flower?"],
53
+ ["./baklava.png", "How to make this pastry?"]],
54
+ inputs=[image_input, text_input]
 
 
 
 
 
55
  )
56
 
57
+ # Launch the Gradio app
58
  demo.queue(api_open=False)
59
  demo.launch(show_api=False, share=False)