anvilarth commited on
Commit
b78a609
1 Parent(s): 9c51397

Update Garage/AugmenterPipeline.py

Browse files
Files changed (1) hide show
  1. Garage/AugmenterPipeline.py +2 -0
Garage/AugmenterPipeline.py CHANGED
@@ -97,6 +97,8 @@ class Augmenter:
97
  if mask.mode != 'L':
98
  mask = mask.convert('L')
99
 
 
 
100
  image_description = self._models["LLaVA"].generate_image_description(image)
101
  prompt, new_object = self._models["LLaMA"].generate_prompt(current_object, image_description, new_objects_list)
102
  # prompt = "cat"
 
97
  if mask.mode != 'L':
98
  mask = mask.convert('L')
99
 
100
+ self.to("cuda")
101
+
102
  image_description = self._models["LLaVA"].generate_image_description(image)
103
  prompt, new_object = self._models["LLaMA"].generate_prompt(current_object, image_description, new_objects_list)
104
  # prompt = "cat"