MaziyarPanahi commited on
Commit
6c67d55
1 Parent(s): de7bed9
Files changed (1) hide show
  1. app.py +28 -33
app.py CHANGED
@@ -11,23 +11,18 @@ import torch
11
  import spaces
12
  import requests
13
 
14
- CSS ="""
15
- .container { display: flex; flex-direction: column; height: 500px; }
16
- #chatbot { flex-grow: 1; }
17
- """
18
-
19
  model_id = "xtuner/llava-llama-3-8b-v1_1-transformers"
20
 
21
- # processor = AutoProcessor.from_pretrained(model_id)
22
 
23
- # model = LlavaForConditionalGeneration.from_pretrained(
24
- # model_id,
25
- # torch_dtype=torch.float16,
26
- # low_cpu_mem_usage=True,
27
- # )
28
 
29
- # model.to("cuda:0")
30
- # model.generation_config.eos_token_id = 128009
31
 
32
  @spaces.GPU
33
  def bot_streaming(message, history):
@@ -41,34 +36,34 @@ def bot_streaming(message, history):
41
  if type(hist[0])==tuple:
42
  image = hist[0][0]
43
 
44
- # if image is None:
45
- # gr.Error("You need to upload an image for LLaVA to work.")
46
- # prompt=f"<|start_header_id|>user<|end_header_id|>\n\n<image>\n{message['text']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
47
- # print(f"prompt: {prompt}")
48
- # image = Image.open(image)
49
- # inputs = processor(prompt, image, return_tensors='pt').to(0, torch.float16)
50
 
51
- # streamer = TextIteratorStreamer(processor, **{"skip_special_tokens": True})
52
- # generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)
53
- # generated_text = ""
54
 
55
- # thread = Thread(target=model.generate, kwargs=generation_kwargs)
56
- # thread.start()
57
 
58
- # text_prompt =f"<|start_header_id|>user<|end_header_id|>\n\n{message['text']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
59
- # print(f"text_prompt: {text_prompt}")
60
 
61
- # buffer = ""
62
- # for new_text in streamer:
63
 
64
- # buffer += new_text
65
 
66
- # generated_text_without_prompt = buffer[len(text_prompt):]
67
- # time.sleep(0.04)
68
- # yield generated_text_without_prompt
69
 
70
 
71
- with gr.Blocks(css=CSS) as demo:
72
  chatbot = gr.ChatInterface(fn=bot_streaming, title="LLaVA Llama-3-8B", examples=[{"text": "What is on the flower?", "files":["./bee.jpg"]},
73
  {"text": "How to make this pastry?", "files":["./baklava.png"]}],
74
  description="Try [LLaVA Llama-3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers). Upload an image and start chatting about it, or simply try one of the examples below. If you don't upload an image, you will receive an error.",
 
11
  import spaces
12
  import requests
13
 
 
 
 
 
 
14
  model_id = "xtuner/llava-llama-3-8b-v1_1-transformers"
15
 
16
+ processor = AutoProcessor.from_pretrained(model_id)
17
 
18
+ model = LlavaForConditionalGeneration.from_pretrained(
19
+ model_id,
20
+ torch_dtype=torch.float16,
21
+ low_cpu_mem_usage=True,
22
+ )
23
 
24
+ model.to("cuda:0")
25
+ model.generation_config.eos_token_id = 128009
26
 
27
  @spaces.GPU
28
  def bot_streaming(message, history):
 
36
  if type(hist[0])==tuple:
37
  image = hist[0][0]
38
 
39
+ if image is None:
40
+ gr.Error("You need to upload an image for LLaVA to work.")
41
+ prompt=f"<|start_header_id|>user<|end_header_id|>\n\n<image>\n{message['text']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
42
+ print(f"prompt: {prompt}")
43
+ image = Image.open(image)
44
+ inputs = processor(prompt, image, return_tensors='pt').to(0, torch.float16)
45
 
46
+ streamer = TextIteratorStreamer(processor, **{"skip_special_tokens": True})
47
+ generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)
48
+ generated_text = ""
49
 
50
+ thread = Thread(target=model.generate, kwargs=generation_kwargs)
51
+ thread.start()
52
 
53
+ text_prompt =f"<|start_header_id|>user<|end_header_id|>\n\n{message['text']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
54
+ print(f"text_prompt: {text_prompt}")
55
 
56
+ buffer = ""
57
+ for new_text in streamer:
58
 
59
+ buffer += new_text
60
 
61
+ generated_text_without_prompt = buffer[len(text_prompt):]
62
+ time.sleep(0.04)
63
+ yield generated_text_without_prompt
64
 
65
 
66
+ with gr.Blocks as demo:
67
  chatbot = gr.ChatInterface(fn=bot_streaming, title="LLaVA Llama-3-8B", examples=[{"text": "What is on the flower?", "files":["./bee.jpg"]},
68
  {"text": "How to make this pastry?", "files":["./baklava.png"]}],
69
  description="Try [LLaVA Llama-3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers). Upload an image and start chatting about it, or simply try one of the examples below. If you don't upload an image, you will receive an error.",