TobyYang7 commited on
Commit
ea37c27
1 parent: ed5a7bf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -31
app.py CHANGED
@@ -1,22 +1,22 @@
1
- import gradio as gr
 
2
  from llava_llama3.serve.cli import chat_llava
3
  from llava_llama3.model.builder import load_pretrained_model
4
- from PIL import Image
5
  import torch
 
 
6
  import spaces
7
 
8
  # Model configuration
9
- model_path = "TheFinAI/FinLLaVA"
10
- device = "cuda"
11
- conv_mode = "llama_3"
12
- temperature = 0
13
- max_new_tokens = 512
14
  load_8bit = False
15
  load_4bit = False
16
 
17
  # Load the pretrained model
18
  tokenizer, llava_model, image_processor, context_len = load_pretrained_model(
19
- model_path,
20
  None,
21
  'llava_llama3',
22
  load_8bit,
@@ -24,38 +24,67 @@ tokenizer, llava_model, image_processor, context_len = load_pretrained_model(
24
  device=device
25
  )
26
 
27
- # Define the prediction function
28
  @spaces.GPU
29
- def bot_streaming(image, text, history):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  output = chat_llava(
31
  args=None,
32
  image_file=image,
33
- text=text,
34
  tokenizer=tokenizer,
35
  model=llava_model,
36
  image_processor=image_processor,
37
  context_len=context_len
38
  )
39
- history.append((text, output))
40
- return history, gr.update(value="")
41
-
42
- # Create the Gradio interface
43
- with gr.Blocks() as demo:
44
- chatbot = gr.Chatbot(label="FinLLaVA Chatbot")
45
- image_input = gr.Image(type="filepath", label="Upload Image")
46
- text_input = gr.Textbox(label="Enter your message")
47
- submit_btn = gr.Button("Submit")
48
-
49
- # Define interaction: when submit is clicked, call bot_streaming and update the chatbot
50
- submit_btn.click(fn=bot_streaming, inputs=[image_input, text_input, chatbot], outputs=[chatbot, text_input])
51
-
52
- # Add example inputs
53
- gr.Examples(
54
- examples=[["./bee.jpg", "What is on the flower?"],
55
- ["./baklava.png", "How to make this pastry?"]],
56
- inputs=[image_input, text_input]
 
 
 
57
  )
58
 
59
- # Launch the Gradio app
60
  demo.queue(api_open=False)
61
- demo.launch(show_api=False, share=False)
 
1
+ import time
2
+ from threading import Thread
3
  from llava_llama3.serve.cli import chat_llava
4
  from llava_llama3.model.builder import load_pretrained_model
5
+ import gradio as gr
6
  import torch
7
+ from PIL import Image
8
+
9
  import spaces
10
 
11
  # Model configuration
12
+ model_id = "TheFinAI/FinLLaVA"
13
+ device = "cuda:0"
 
 
 
14
  load_8bit = False
15
  load_4bit = False
16
 
17
  # Load the pretrained model
18
  tokenizer, llava_model, image_processor, context_len = load_pretrained_model(
19
+ model_id,
20
  None,
21
  'llava_llama3',
22
  load_8bit,
 
24
  device=device
25
  )
26
 
27
+
28
  @spaces.GPU
29
+ def bot_streaming(message, history):
30
+ print(message)
31
+ image = None
32
+
33
+ # Check if there's an image in the current message
34
+ if message["files"]:
35
+ # message["files"][-1] could be a dictionary or a string
36
+ if isinstance(message["files"][-1], dict):
37
+ image = message["files"][-1]["path"]
38
+ else:
39
+ image = message["files"][-1]
40
+ else:
41
+ # If no image in the current message, look in the history for the last image
42
+ for hist in history:
43
+ if isinstance(hist[0], tuple):
44
+ image = hist[0][0]
45
+
46
+ # Error handling if no image is found
47
+ if image is None:
48
+ raise gr.Error("You need to upload an image for LLaVA to work.")
49
+
50
+ # Load the image
51
+ image = Image.open(image)
52
+
53
+ # Generate the prompt for the model
54
+ prompt = message['text']
55
+
56
+ # Call the chat_llava function to generate the output
57
  output = chat_llava(
58
  args=None,
59
  image_file=image,
60
+ text=prompt,
61
  tokenizer=tokenizer,
62
  model=llava_model,
63
  image_processor=image_processor,
64
  context_len=context_len
65
  )
66
+
67
+ # Stream the output
68
+ buffer = ""
69
+ for new_text in output:
70
+ buffer += new_text
71
+ yield buffer
72
+
73
+
74
+ chatbot=gr.Chatbot(scale=1)
75
+ chat_input = gr.MultimodalTextbox(interactive=True, file_types=["image"], placeholder="Enter message or upload file...", show_label=False)
76
+ with gr.Blocks(fill_height=True, ) as demo:
77
+ gr.ChatInterface(
78
+ fn=bot_streaming,
79
+ title="LLaVA Llama-3-8B",
80
+ examples=[{"text": "What is on the flower?", "files": ["./bee.jpg"]},
81
+ {"text": "How to make this pastry?", "files": ["./baklava.png"]}],
82
+
83
+ stop_btn="Stop Generation",
84
+ multimodal=True,
85
+ textbox=chat_input,
86
+ chatbot=chatbot,
87
  )
88
 
 
89
  demo.queue(api_open=False)
90
+ demo.launch(show_api=False, share=False)