torettomarui committed
Commit ce191fc · verified · 1 Parent(s): 6e4bedb

Update app.py

Files changed (1)
  1. app.py +32 -42
app.py CHANGED
@@ -1,54 +1,44 @@
 import gradio as gr
 from transformers import AutoModel, AutoTokenizer
-import torch
-import torchvision.transforms as T
-from torchvision.transforms.functional import InterpolationMode
-from Models.modeling_llavaqw import LlavaQwModel
-
-
-
-
-IMAGENET_MEAN = (0.485, 0.456, 0.406)
-IMAGENET_STD = (0.229, 0.224, 0.225)
-
+import torch
+import torchvision.transforms as T
+from torchvision.transforms.functional import InterpolationMode
+from Models.modeling_llavaqw import LlavaQwModel
+
+IMAGENET_MEAN = (0.485, 0.456, 0.406)
+IMAGENET_STD = (0.229, 0.224, 0.225)
 
 model_name = "torettomarui/Llava-qw"
-tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_fast=False)
-model = LlavaQwModel.from_pretrained(
-    model_name,
-    torch_dtype=torch.bfloat16,
-    trust_remote_code=True,
-).to(torch.bfloat16).eval()#.cuda()
-
-def build_transform(input_size):
-    MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
-    transform = T.Compose([
-        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
-        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
-        T.ToTensor(),
-        T.Normalize(mean=MEAN, std=STD)
-    ])
-    return transform
-
-
-def preprocess_image(file_path, image_size=448):
-
-    transform = build_transform(image_size)
-    pixel_values = transform(file_path)
-    return torch.stack([pixel_values]).to(torch.bfloat16)#.cuda()
+tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_fast=False)
+model = LlavaQwModel.from_pretrained(
+    model_name,
+    torch_dtype=torch.bfloat16,
+    trust_remote_code=True,
+).to(torch.bfloat16).eval().cuda()
+
+def build_transform(input_size):
+    MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
+    transform = T.Compose([
+        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
+        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
+        T.ToTensor(),
+        T.Normalize(mean=MEAN, std=STD)
+    ])
+    return transform
+
+def preprocess_image(file_path, image_size=448):
+    transform = build_transform(image_size)
+    pixel_values = transform(file_path)
+    return torch.stack([pixel_values]).to(torch.bfloat16).cuda()
 
 def generate_response(image, text):
-
-    pixel_values = preprocess_image(image)
-
-    generation_config = dict(max_new_tokens=2048, do_sample=False)
-
-    question = '<image>\n' + text
-
-    response = model.chat(tokenizer, pixel_values, question, generation_config)
-
+    pixel_values = preprocess_image(image)
+    generation_config = dict(max_new_tokens=2048, do_sample=False)
+    question = '<image>\n' + text
+    response = model.chat(tokenizer, pixel_values, question, generation_config)
     return response
 
+# Add example image and text
 examples = [
     ["./text.png", "图中的文字是什么?"],
 ]
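
The net effect of the commit is to uncomment the two `.cuda()` calls, so both the model weights and the stacked pixel tensor are moved to the GPU, and to drop the stray blank lines. As written, the app now requires CUDA hardware at startup. A device-agnostic variant (a sketch of an alternative, not what this commit does) would look like:

# Sketch: device-agnostic alternative to the hard-coded .cuda() calls.
# Falls back to CPU when no GPU is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = LlavaQwModel.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
).eval().to(device)

# ...and preprocess_image would then end with:
# return torch.stack([pixel_values]).to(torch.bfloat16).to(device)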
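
The hunk covers the whole shown portion of app.py, but the Gradio UI wiring itself is not part of the diff. For context, here is a minimal sketch of how `generate_response` and `examples` are typically hooked up; the component choices and labels are assumptions, not code from this repository. The bundled example asks, in Chinese, "What is the text in the image?"

# Hypothetical UI wiring, assumed rather than shown in the diff.
# type="pil" matches preprocess_image, whose transform expects a
# PIL image (it calls img.convert('RGB')).
demo = gr.Interface(
    fn=generate_response,
    inputs=[gr.Image(type="pil"), gr.Textbox(label="Question")],
    outputs=gr.Textbox(label="Answer"),
    examples=examples,
)

demo.launch()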