MS-YUN commited on
Commit
bc3b9d1
1 Parent(s): 85e92ff

트랜스포머"

Browse files
Files changed (1) hide show
  1. app.py +31 -4
app.py CHANGED
@@ -1,8 +1,35 @@
 
1
 
 
 
 
 
 
2
 
3
- def image_classifier(img):
4
- return {'cat': 0.3, 'dog': 0.7}
5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  import gradio as gr
7
- demo = gr.Interface(fn=image_classifier, inputs="image", outputs="label") # greet이라는 함수
8
- demo.launch(debug=True) # share=True : 다른 사람과 공유
 
 
 
 
 
 
1
+ # 모델로딩
2
 
3
+ # ImageNet-1k에 훈련된 모델과 특징 추출기 로드
4
+ from transformers import ViTImageProcessor, ViTForImageClassification
5
+ model_name = "google/vit-base-patch16-224"
6
+ model = ViTForImageClassification.from_pretrained(model_name)
7
+ image_processor = ViTImageProcessor.from_pretrained(model_name)
8
 
9
+ # 이미지 예측 분류함수
 
10
 
11
+ import torch
12
+ def classify_image(inp):
13
+ # 이미지를 특징 벡터로 변환
14
+ inputs = image_processor(images=inp, return_tensors="pt")
15
+ pixel_values = inputs["pixel_values"]
16
+
17
+ # 예측 수행
18
+ outputs = model(pixel_values)
19
+ logits = outputs.logits
20
+ predicted_index = torch.argmax(logits, 1)[0].item()
21
+
22
+ # 가장 확률이 높은 라벨 반환``
23
+ label = model.config.id2label[predicted_index]
24
+ return label
25
+
26
+ # Gradio 인터페이스 설정
27
+ from PIL import Image
28
  import gradio as gr
29
+ interface = gr.Interface(
30
+ fn=classify_image,
31
+ inputs=gr.components.Image(type="pil", label="Upload an Image"),
32
+ outputs="text",
33
+ live=True
34
+ )
35
+ interface.launch()