TobyYang7 committed
Commit afff347
1 Parent(s): daee25b

Update app.py

Files changed (1):
  app.py +60 -73
app.py CHANGED
@@ -1,20 +1,37 @@
+import time
 from threading import Thread
-from llava_llama3.serve.cli import chat_llava
-from llava_llama3.model.builder import load_pretrained_model
+
 import gradio as gr
 import torch
 from PIL import Image
+from transformers import AutoProcessor, LlavaForConditionalGeneration, TextIteratorStreamer, TextStreamer
+
+# import spaces
 import argparse
-import spaces
+
+from llava_llama3.model.builder import load_pretrained_model
+from llava_llama3.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
+from llava_llama3.conversation import conv_templates, SeparatorStyle
+from llava_llama3.utils import disable_torch_init
+from llava_llama3.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path
+from llava_llama3.serve.cli import chat_llava
+
+import requests
+from io import BytesIO
+import base64
 import os
-import time
+import glob
+import pandas as pd
+from tqdm import tqdm
+import json
 
 root_path = os.path.dirname(os.path.abspath(__file__))
-print(root_path)
+print(f'\033[92m{root_path}\033[0m')
+os.environ['GRADIO_TEMP_DIR'] = root_path
 
 parser = argparse.ArgumentParser()
-parser.add_argument("--model-path", type=str, default="TheFinAI/FinLLaVA")
-parser.add_argument("--device", type=str, default="cuda:0")
+parser.add_argument("--model-path", type=str, default="/mnt/nvme1n1/toby/LLaVA/checkpoints/0806_onlyllava_llava-finma-8B-v0.4-v8/checkpoint-2000")
+parser.add_argument("--device", type=str, default="cuda")
 parser.add_argument("--conv-mode", type=str, default="llama_3")
 parser.add_argument("--temperature", type=float, default=0.7)
 parser.add_argument("--max-new-tokens", type=int, default=512)
@@ -22,87 +39,56 @@ parser.add_argument("--load-8bit", action="store_true")
 parser.add_argument("--load-4bit", action="store_true")
 args = parser.parse_args()
 
-# load model
+# Load model
 tokenizer, llava_model, image_processor, context_len = load_pretrained_model(
     args.model_path,
     None,
     'llava_llama3',
     args.load_8bit,
     args.load_4bit,
-    device=args.device
-)
+    device=args.device)
 
-@spaces.GPU
 def bot_streaming(message, history):
     print(message)
-    image_path = None
-
-    # Check if there's an image in the current message
+    image_file = None
     if message["files"]:
-        # message["files"][-1] could be a dictionary or a string
-        if isinstance(message["files"][-1], dict):
-            image_path = message["files"][-1]["path"]
+        if type(message["files"][-1]) == dict:
+            image_file = message["files"][-1]["path"]
         else:
-            image_path = message["files"][-1]
+            image_file = message["files"][-1]
     else:
-        # If no image in the current message, look in the history for the last image path
         for hist in history:
-            if isinstance(hist[0], tuple):
-                image_path = hist[0][0]
-
-    # Error handling if no image path is found
-    if image_path is None:
-        raise gr.Error("You need to upload an image for LLaVA to work.")
+            if type(hist[0]) == tuple:
+                image_file = hist[0][0]
+
+    if image_file is None:
+        gr.Error("You need to upload an image for LLaVA to work.")
+        return
 
-    # If the image_path is a string, no need to load it into a PIL image
-    # Just use the path directly in the next steps
-    print(f"\033[91m{image_path}, {type(image_path)}\033[0m")
-
-    # Generate the prompt for the model
-    prompt = message['text']
-    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
-
-    # Set up the generation arguments, including the streamer
-    generation_kwargs = dict(
-        args=args,
-        image_file=image_path,
-        text=prompt,
-        tokenizer=tokenizer,
-        model=llava_model,
-        streamer=streamer,
-        image_processor=image_processor, # todo: input model name or path
-        context_len=context_len)
-
-    # Define the function to call `chat_llava` with the given arguments
-    def generate_output(generation_kwargs):
-        chat_llava(**generation_kwargs)
-
-    # Start the generation in a separate thread
-    thread = Thread(target=generate_output, kwargs=generation_kwargs)
+    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+    def generate():
+        print('\033[92mRunning chat\033[0m')
+        output = chat_llava(
+            args=args,
+            image_file=image_file,
+            text=message['text'],
+            tokenizer=tokenizer,
+            model=llava_model,
+            image_processor=image_processor,
+            context_len=context_len,
+            streamer=streamer)
+        return output
+
+    thread = Thread(target=generate)
     thread.start()
-
-    # Initialize a buffer to accumulate the generated text
+    # thread.join()
+
     buffer = ""
-
-    # Allow the generation to start
-    time.sleep(0.5)
-
-    # Iterate over the streamer to handle the incoming text in chunks
+    # output = generate()
     for new_text in streamer:
-        # Look for the end of text token and remove it
-        if "<|eot_id|>" in new_text:
-            new_text = new_text.split("<|eot_id|>")[0]
-
-        # Add the new text to the buffer
         buffer += new_text
-
-        # Remove the prompt from the generated text (if necessary)
-        generated_text_without_prompt = buffer[len(prompt):]
-
-        # Simulate processing time (optional)
+        generated_text_without_prompt = buffer
         time.sleep(0.06)
-
-        # Yield the current generated text for further processing or display
         yield generated_text_without_prompt
 
 chatbot = gr.Chatbot(scale=1)
@@ -110,10 +96,11 @@ chat_input = gr.MultimodalTextbox(interactive=True, file_types=["image"], placeh
 with gr.Blocks(fill_height=True) as demo:
     gr.ChatInterface(
         fn=bot_streaming,
-        title="FinLLaVA",
-        examples=[{"text": "What is on the flower?", "files": ["./bee.jpg"]},
-                  {"text": "How to make this pastry?", "files": ["./baklava.png"]},
-                  {"text":"What is this?","files":["http://images.cocodataset.org/val2017/000000039769.jpg"]}],
+        title="FinLLaVA Demo",
+        examples=[
+            {"text": "What is in this picture?", "files": ["http://images.cocodataset.org/val2017/000000039769.jpg"]},
+        ],
+        description="",
         stop_btn="Stop Generation",
         multimodal=True,
         textbox=chat_input,
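
Note on the change: the old bot_streaming built a TextStreamer and then iterated over it, but transformers' TextStreamer only prints to stdout and is not iterable; it also slept 0.5 s before reading and manually stripped the prompt and the <|eot_id|> marker from the buffer. The new version passes chat_llava a TextIteratorStreamer, which is the iterable variant: generation blocks on a worker thread while the Gradio generator drains decoded chunks from the streamer's queue. A minimal sketch of that producer/consumer pattern, with a generic gpt2 model standing in for the project's chat_llava helper (the model name and max_new_tokens value are illustrative assumptions, not from this commit):

from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")   # stand-in model, not FinLLaVA
model = AutoModelForCausalLM.from_pretrained("gpt2")

def stream_reply(prompt):
    # skip_prompt=True makes the streamer emit only newly generated text,
    # which is why the new app.py no longer slices the prompt off the buffer.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    inputs = tokenizer(prompt, return_tensors="pt")
    # generate() blocks until done, so it runs on a worker thread; the main
    # thread consumes decoded chunks from the streamer's internal queue.
    Thread(target=model.generate,
           kwargs=dict(**inputs, streamer=streamer, max_new_tokens=64)).start()
    buffer = ""
    for new_text in streamer:   # iteration ends when generation finishes
        buffer += new_text
        yield buffer            # Gradio re-renders the reply with each partial buffer

for partial in stream_reply("Hello,"):
    print(partial)

Because the streamer's iterator only terminates when the worker thread finishes generating, the commented-out thread.join() in the diff is unnecessary; yielding the growing buffer is exactly the contract gr.ChatInterface expects from a streaming fn.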