hantech committed on
Commit
74f3fbc
•
1 Parent(s): bd22b5e

Delete vietocr/translate.py

Browse files
Files changed (1) hide show
  1. vietocr/translate.py +0 -62
vietocr/translate.py DELETED
@@ -1,62 +0,0 @@
1
- import torch
2
- import numpy as np
3
- import math
4
- from PIL import Image
5
- from torch.nn.functional import softmax
6
-
7
def translate(img, model, max_seq_length=128, sos_token=1, eos_token=2):
    """Greedy-decode a batch of images into token-id sequences.

    Args:
        img: Batch of images passed to ``model.cnn``. The original note
            describes the layout as BxCxHxW. ``len(img)`` is used as the
            batch size.
        model: Object exposing ``eval()``, ``cnn(img)``,
            ``transformer.forward_encoder(src)`` and
            ``transformer.forward_decoder(tgt_inp, memory)``.
        max_seq_length: Decoding bound. The loop condition is
            ``steps <= max_seq_length``, so up to ``max_seq_length + 1``
            tokens are generated after the start token (kept as-is for
            backward compatibility).
        sos_token: Id placed at the start of every sequence.
        eos_token: Id that marks a sequence as finished.

    Returns:
        ``np.ndarray`` with one row per batch item. Each row begins with
        ``sos_token`` followed by the greedily chosen token ids. Decoding
        stops once every row contains ``eos_token`` or the length bound is
        reached; rows that finished early keep receiving tokens until then.
    """
    model.eval()
    batch_size = len(img)
    # Build decoder inputs on the same device as the image batch so the
    # function also works when the model lives on an accelerator.
    device = getattr(img, "device", None)

    with torch.no_grad():
        src = model.cnn(img)
        memory = model.transformer.forward_encoder(src)

        # One inner list per time step, each holding one token per batch item.
        translated_sentence = [[sos_token] * batch_size]

        # Per-row "has emitted EOS" flags, updated incrementally instead of
        # re-scanning the whole history on every step. The start token counts
        # too, matching the original check over the full sequence.
        finished = np.full(batch_size, sos_token == eos_token, dtype=bool)

        steps = 0
        while steps <= max_seq_length and not finished.all():
            tgt_inp = torch.tensor(
                translated_sentence, dtype=torch.long, device=device
            )
            output, memory = model.transformer.forward_decoder(tgt_inp, memory)

            # Only the best token of the last time step is needed. argmax
            # works for any vocabulary size (topk with k=5 fails when the
            # vocabulary has fewer than 5 entries), and softmax is monotonic
            # so skipping it does not change the winner.
            # NOTE(review): assumes output is (batch, steps, vocab) — confirm.
            next_tokens = output[:, -1, :].argmax(dim=-1).tolist()

            translated_sentence.append(next_tokens)
            finished |= np.asarray(next_tokens) == eos_token
            steps += 1

    # Transpose from (steps, batch) to (batch, steps).
    return np.asarray(translated_sentence).T
36
-
37
def resize(w, h, expected_height, image_min_width, image_max_width, round_to=10):
    """Compute the target size for an image scaled to a fixed height.

    The width is scaled by ``expected_height / h`` and truncated to an
    integer, then rounded up to the next multiple of ``round_to``, and
    finally clamped to ``[image_min_width, image_max_width]``.

    Args:
        w: Original width.
        h: Original height (must be non-zero).
        expected_height: Height of the resized image.
        image_min_width: Lower bound for the returned width.
        image_max_width: Upper bound for the returned width; applied last,
            so it wins if the two bounds conflict.
        round_to: Width granularity; defaults to 10, the previously
            hard-coded value.

    Returns:
        Tuple ``(new_width, expected_height)``.
    """
    scaled_w = int(expected_height * float(w) / float(h))
    new_w = math.ceil(scaled_w / round_to) * round_to
    new_w = min(max(new_w, image_min_width), image_max_width)

    return new_w, expected_height
45
-
46
def process_image(image, image_height, image_min_width, image_max_width):
    """Convert an image to a scaled, channel-first array.

    The image is converted to RGB, resized to ``image_height`` with a width
    chosen by ``resize`` from the original aspect ratio and the given width
    bounds, moved from height/width/channel to channel/height/width order,
    and divided by 255.

    Args:
        image: Image object providing ``convert`` (a PIL image in practice).
        image_height: Target height.
        image_min_width: Minimum target width.
        image_max_width: Maximum target width.

    Returns:
        NumPy array in channel-first layout with pixel values divided by 255.
    """
    img = image.convert('RGB')

    w, h = img.size
    new_w, image_height = resize(w, h, image_height, image_min_width, image_max_width)

    # Pillow >= 9.1 exposes the filters on Image.Resampling; older releases
    # only have them as module-level constants on Image.
    resampling = getattr(Image, "Resampling", Image)
    img = img.resize((new_w, image_height), resampling.LANCZOS)

    img = np.asarray(img).transpose(2, 0, 1)
    img = img / 255
    return img
57
-
58
def process_input(image, image_height, image_min_width, image_max_width):
    """Turn one image into a single-item float tensor batch.

    Runs ``process_image`` on the image, prepends a leading axis of size 1,
    and wraps the result in a ``torch.FloatTensor``.
    """
    processed = process_image(image, image_height, image_min_width, image_max_width)
    batched = processed[np.newaxis, ...]
    return torch.FloatTensor(batched)