Gilvan committed on
Commit 575baf4
1 Parent(s): 5659fa7
Files changed (1)
  1. app.py +96 -0
app.py ADDED
@@ -0,0 +1,96 @@
+ import os
+
+ import gradio as gr
+ from PIL import Image
+ from transformers import AutoFeatureExtractor, GPT2Tokenizer, VisionEncoderDecoderModel
+
+ device = 'cpu'
+ access_token = os.getenv("auth_token")  # HF auth token (currently unused)
+
+ # Decoding settings shared by all caption-generation helpers
+ max_length = 100
+ num_beams = 4
+ gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
+
+ def predict_step(image_paths, tokenizer, feature_extractor, model):
+     """Generate captions for a batch of images given by file path."""
+     images = []
+     for image_path in image_paths:
+         i_image = Image.open(image_path)
+         if i_image.mode != "RGB":
+             i_image = i_image.convert(mode="RGB")
+
+         #i_image.resize((640, 480))
+
+         images.append(i_image)
+
+     pixel_values = feature_extractor(images=images, return_tensors="pt").pixel_values
+     pixel_values = pixel_values.to(device)
+
+     output_ids = model.generate(pixel_values, **gen_kwargs)
+
+     preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+     preds = [pred.strip() for pred in preds]
+     return preds
+
+ def predict_step_image(dataset_images, tokenizer, feature_extractor, model):
+     """Generate captions for an iterable of PIL images, one at a time."""
+     results = []
+     for i in dataset_images:
+         pixel_values = feature_extractor(images=i, return_tensors="pt").pixel_values
+         pixel_values = pixel_values.to(device)
+
+         output_ids = model.generate(pixel_values, **gen_kwargs)
+
+         preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+         preds = [pred.strip() for pred in preds]
+         results.append(preds)
+     return results
+
+ def predict_step_single_image(image, tokenizer, feature_extractor, model):
+     """Generate a caption for a single image; returns a nested list [[caption]]."""
+     results = []
+     pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
+     pixel_values = pixel_values.to(device)
+
+     output_ids = model.generate(pixel_values, **gen_kwargs)
+
+     preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+     preds = [pred.strip() for pred in preds]
+     results.append(preds)
+     return results
+
+ def predict_step_pixel(dataset_pixel_values, tokenizer, model):
+     """Generate captions for tensors that are already preprocessed pixel values."""
+     results = []
+     for pv in dataset_pixel_values:
+         pixel_values = pv.reshape([1, 3, 224, 224])
+         pixel_values = pixel_values.to(device)
+         output_ids = model.generate(pixel_values, **gen_kwargs)
+         preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+         results.append([pred.strip() for pred in preds][0])
+     return results
+
+ # Image model loading
+ def load_image2txt_model(image_model_name):
+     """Load the VisionEncoderDecoder checkpoint with its feature extractor and tokenizer."""
+     model = VisionEncoderDecoderModel.from_pretrained(image_model_name)
+     feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/swin-large-patch4-window7-224")
+
+     tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
+     tokenizer.pad_token = tokenizer.eos_token
+
+     model = model.to(device)
+     return tokenizer, feature_extractor, model
+
+ def inference_image_pipe(image_input):
+     """Gradio handler: caption the uploaded image with the fine-tuned checkpoint."""
+     image_model_name = "./checkpoint-21000"
+
+     # Note: the model is reloaded on every request; loading it once at module
+     # scope would make repeated inferences much faster.
+     tokenizer, feature_extractor, image_model = load_image2txt_model(image_model_name)
+     # predict_step_single_image returns [[caption]], so unwrap twice to get a string
+     text = predict_step_single_image(image_input, tokenizer, feature_extractor, image_model)[0][0]
+     return text
+
+ demo = gr.Interface(fn=inference_image_pipe,
+                     inputs=gr.Image(shape=(256, 256)),
+                     outputs="text",
+                     examples=["3212210S4492629-1.png", "3216497S4499373-1.png"],
+                     title="POC V0 - XRay Automatic Medical Report")
+
+
+ if __name__ == "__main__":
+     demo.launch(share=True)
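
For a quick smoke test outside the Gradio UI, the captioning path can be called directly. A minimal sketch, assuming the `./checkpoint-21000` weights and one of the example PNGs from the diff are present in the working directory:

from PIL import Image

from app import load_image2txt_model, predict_step_single_image

tokenizer, feature_extractor, model = load_image2txt_model("./checkpoint-21000")
image = Image.open("3212210S4492629-1.png").convert("RGB")
# predict_step_single_image returns [[caption]], so unwrap twice
caption = predict_step_single_image(image, tokenizer, feature_extractor, model)[0][0]
print(caption)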