lindsay-qu commited on
Commit
3ecbc1a
1 Parent(s): ea54126

Update core/refiner/simple_refiner.py

Browse files
Files changed (1) hide show
  1. core/refiner/simple_refiner.py +9 -6
core/refiner/simple_refiner.py CHANGED
@@ -8,20 +8,23 @@ class SimpleRefiner(BaseRefiner):
8
  ) -> None:
9
  BaseRefiner.__init__(self, sys_prompt=sys_prompt, model=model)
10
 
11
- def refine(self, message: str, memory, image_path=None) -> str:
12
  if memory is None:
13
  memory = []
14
  else:
15
  memory = memory.messages[1:]
16
- if image_path:
 
 
 
 
 
 
17
  context = [{"role": "system", "content": self.sys_prompt}] + memory + [{"role": "user", "content": [
18
  {"type": "text", "text": f"{message}"},
19
- {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{encode_image(image_path)}"}},
20
  ]}]
21
  else:
22
- context = [{"role": "system", "content": self.sys_prompt}] + memory + [{"role": "user", "content": [
23
- {"type": "text", "text": f"{message}"},
24
- ]}]
25
  response = self.model.respond(context)
26
 
27
  return response
 
8
  ) -> None:
9
  BaseRefiner.__init__(self, sys_prompt=sys_prompt, model=model)
10
 
11
+ def refine(self, message: str, memory, image_paths=None) -> str:
12
  if memory is None:
13
  memory = []
14
  else:
15
  memory = memory.messages[1:]
16
+ if image_paths:
17
+ if not isinstance(image_paths, list):
18
+ image_paths = [image_paths]
19
+ user_context = [{"role": "user", "content": [
20
+ {"type": "text", "text": f"{message}"},]}]
21
+ for image_path in image_paths:
22
+ user_context.append({"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{encode_image(image_path.name)}"}})
23
  context = [{"role": "system", "content": self.sys_prompt}] + memory + [{"role": "user", "content": [
24
  {"type": "text", "text": f"{message}"},
 
25
  ]}]
26
  else:
27
+ context = [{"role": "system", "content": self.sys_prompt}] + memory + user_context
 
 
28
  response = self.model.respond(context)
29
 
30
  return response