torettomarui commited on
Commit
31264bb
·
verified ·
1 Parent(s): acc42f1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -14,7 +14,7 @@ model = LlavaQwModel.from_pretrained(
14
  model_name,
15
  torch_dtype=torch.bfloat16,
16
  trust_remote_code=True,
17
- ).to(torch.bfloat16).eval()#.cuda()
18
 
19
  def build_transform(input_size):
20
  MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
@@ -29,7 +29,7 @@ def build_transform(input_size):
29
  def preprocess_image(file_path, image_size=448):
30
  transform = build_transform(image_size)
31
  pixel_values = transform(file_path)
32
- return torch.stack([pixel_values]).to(torch.bfloat16)#.cuda()
33
 
34
  def generate_response(image, text):
35
  pixel_values = preprocess_image(image)
 
14
  model_name,
15
  torch_dtype=torch.bfloat16,
16
  trust_remote_code=True,
17
+ ).to(torch.bfloat16).eval().cuda()
18
 
19
  def build_transform(input_size):
20
  MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
 
29
  def preprocess_image(file_path, image_size=448):
30
  transform = build_transform(image_size)
31
  pixel_values = transform(file_path)
32
+ return torch.stack([pixel_values]).to(torch.bfloat16).cuda()
33
 
34
  def generate_response(image, text):
35
  pixel_values = preprocess_image(image)