ConradLax commited on
Commit
53a3910
·
1 Parent(s): fdcdc66

test: classification model

Browse files
Files changed (1) hide show
  1. main.py +114 -5
main.py CHANGED
@@ -11,8 +11,9 @@ pipe_flan = pipeline("text2text-generation", model="google/flan-t5-small")
11
 
12
  @app.get("/infer_t5")
13
  def t5(input):
14
- output = pipe_flan(input)
15
- return {"output": output[0]["generated_text"]}
 
16
 
17
 
18
  app.mount("/", StaticFiles(directory="static", html=True), name="static")
@@ -22,6 +23,114 @@ def index() -> FileResponse:
22
  return FileResponse(path="/app/static/index.html", media_type="text/html")
23
 
24
 
25
- #@app.get("/")
26
- #def read_root():
27
- # return {"Hello": "World!"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
  @app.get("/infer_t5")
13
  def t5(input):
14
+ # output = pipe_flan(input)
15
+ # return {"output": output[0]["generated_text"]}
16
+ return classify_acct_dtype_str("https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/city-streets.jpg")
17
 
18
 
19
  app.mount("/", StaticFiles(directory="static", html=True), name="static")
 
23
  return FileResponse(path="/app/static/index.html", media_type="text/html")
24
 
25
 
26
+
27
+
28
+
29
+
30
+
31
+ # Doc classifier model
32
+ classifier_doctype_processor = DonutProcessor.from_pretrained("calumpianojericho/donutclassifier_acctdocs_by_doctype")
33
+ classifier_doctype_model = VisionEncoderDecoderModel.from_pretrained("calumpianojericho/donutclassifier_acctdocs_by_doctype")
34
+
35
+
36
+ """### Inference Code"""
37
+
38
+ def inference(input, model, processor, threshold=1.0, task_prompt="", get_confidence=False):
39
+ device = "cuda" if torch.cuda.is_available() else "cpu"
40
+ model.to(device)
41
+ is_confident = True
42
+ decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids
43
+
44
+ pil_img=input
45
+
46
+ image = np.array(pil_img)
47
+ pixel_values = processor(image, return_tensors="pt").pixel_values
48
+
49
+ outputs = model.generate(
50
+ pixel_values.to(device),
51
+ decoder_input_ids=decoder_input_ids.to(device),
52
+ max_length=model.decoder.config.max_position_embeddings,
53
+ early_stopping=True,
54
+ pad_token_id=processor.tokenizer.pad_token_id,
55
+ eos_token_id= processor.tokenizer.eos_token_id,
56
+ use_cache=True,
57
+ num_beams=1,
58
+ bad_words_ids=[[processor.tokenizer.unk_token_id]],
59
+ return_dict_in_generate=True,
60
+ output_scores=True,
61
+ )
62
+
63
+ sequence = processor.batch_decode(outputs.sequences)[0]
64
+ sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
65
+ sequence = re.sub(r"<.*?>", "", sequence, count=1).strip() # remove first task start token
66
+
67
+ seq = processor.token2json(sequence)
68
+ if get_confidence:
69
+ return seq, pred_confidence(outputs.scores, threshold)
70
+
71
+ return seq
72
+
73
+ def pred_confidence(output_scores, threshold):
74
+ is_confident=True
75
+
76
+ for score in output_scores:
77
+ exp_scores = np.exp(score[0].cpu().numpy()) # scores are logits, we use the exp function so that all values are positive
78
+ sum_exp = np.sum(exp_scores) # taking the sum of the token scores
79
+ idx = np.argmax(exp_scores) # taking the index of the token with the highest score
80
+ prob_max = exp_scores[idx]/sum_exp # normalizing the token with the highest score wrt the sum of all scores. Returns probability
81
+ if prob_max < threshold:
82
+ is_confident = False
83
+ # print(prob_max)
84
+
85
+
86
+ return is_confident
87
+
88
+
89
+ CUDA_LAUNCH_BLOCKING=1
90
+ def parse_text(input, filename):
91
+ model = base_model
92
+ processor = base_processor
93
+ seq = inference(input, model, processor, task_prompt="<s_synthdog>")
94
+ return str(seq)
95
+
96
+ def doctype_classify(input, filename):
97
+ model = classifier_doctype_model
98
+ processor = classifier_doctype_processor
99
+ seq, is_confident = inference(input, model, processor, threshold=0.90, task_prompt="<s_classifier_acct>", get_confidence=True)
100
+ return seq.get('class'), is_confident
101
+
102
+ def account_classify(input, filename):
103
+ model = classifier_account_model
104
+ processor = classifier_account_processor
105
+ seq, is_confident = inference(input, model, processor, threshold=0.999, task_prompt="<s_classifier_acct>", get_confidence=True)
106
+ return seq.get('class'), is_confident
107
+
108
+ """## Text processing/string matcher code"""
109
+
110
+ import locale
111
+ locale.getpreferredencoding = lambda: "UTF-8"
112
+
113
+
114
+ """## Text processing/string matcher code"""
115
+
116
+ import locale
117
+ locale.getpreferredencoding = lambda: "UTF-8"
118
+
119
+
120
+ """## Classify Document Images"""
121
+
122
+ import numpy as np
123
+ import csv
124
+ import re
125
+ import os
126
+
127
+
128
+ import requests
129
+
130
+ def classify_acct_dtype_str(input_path):
131
+ response = requests.get("https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/city-streets.jpg")
132
+ dtype_inf, dtype_conf = doctype_classify(response, "city-streets.jpg")
133
+
134
+ return dtype_inf
135
+
136
+ classify_acct_dtype_str("https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/city-streets.jpg")