sooh-j commited on
Commit
79bde87
·
verified ·
1 Parent(s): 50da7fb

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +7 -5
handler.py CHANGED
@@ -40,10 +40,7 @@ class EndpointHandler():
40
  imageBase64 = inputs.get("image")
41
  question = inputs.get("question")
42
 
43
- # imageURL = inputs.get("image")
44
- # image = Image.open(requests.get(imageBase64, stream=True).raw)
45
-
46
- if 'http:' in imageBase64:
47
  image = Image.open(requests.get(imageBase64, stream=True).raw)
48
  else:
49
  image = Image.open(BytesIO(base64.b64decode(imageBase64.split(",")[0].encode())))
@@ -52,7 +49,12 @@ class EndpointHandler():
52
  processed = self.processor(images=image, text=prompt, return_tensors="pt").to(self.device)
53
 
54
  with torch.no_grad():
55
- out = self.model.generate(**processed, max_new_tokens=512).to(self.device)
 
 
 
 
 
56
 
57
  result = {}
58
  text_output = self.processor.decode(out[0], skip_special_tokens=True)
 
40
  imageBase64 = inputs.get("image")
41
  question = inputs.get("question")
42
 
43
+ if ('http:' in imageBase64) or ('https:' in imageBase64):
 
 
 
44
  image = Image.open(requests.get(imageBase64, stream=True).raw)
45
  else:
46
  image = Image.open(BytesIO(base64.b64decode(imageBase64.split(",")[0].encode())))
 
49
  processed = self.processor(images=image, text=prompt, return_tensors="pt").to(self.device)
50
 
51
  with torch.no_grad():
52
+ out = self.model.generate(**processed,
53
+ max_new_tokens=512,
54
+ temperature = 0.1,
55
+ do_sample=True,
56
+ repetition_penalty=1.2
57
+ ).to(self.device)
58
 
59
  result = {}
60
  text_output = self.processor.decode(out[0], skip_special_tokens=True)