Update Garage/AugmenterPipeline.py
Browse files
Garage/AugmenterPipeline.py
CHANGED
@@ -97,6 +97,8 @@ class Augmenter:
|
|
97 |
if mask.mode != 'L':
|
98 |
mask = mask.convert('L')
|
99 |
|
|
|
|
|
100 |
image_description = self._models["LLaVA"].generate_image_description(image)
|
101 |
prompt, new_object = self._models["LLaMA"].generate_prompt(current_object, image_description, new_objects_list)
|
102 |
# prompt = "cat"
|
|
|
97 |
if mask.mode != 'L':
|
98 |
mask = mask.convert('L')
|
99 |
|
100 |
+
self.to("cuda")
|
101 |
+
|
102 |
image_description = self._models["LLaVA"].generate_image_description(image)
|
103 |
prompt, new_object = self._models["LLaMA"].generate_prompt(current_object, image_description, new_objects_list)
|
104 |
# prompt = "cat"
|