baixintech_zhangyiming_prod
import components
1c962af
raw
history blame contribute delete
865 Bytes
import gradio as gr
import gradio.components as grc
import torch
from torchvision import transforms
from transformers import ViTForImageClassification
# Hugging Face Hub checkpoint for the view-angle image classifier.
model_path = "Inf009/view-angle"
model = ViTForImageClassification.from_pretrained(model_path)
# Switch to evaluation mode: the model is only used for inference here.
model.eval()
# Preprocessing applied to every input image before it reaches the model:
# resize to exactly 224x224, then convert to a float tensor in [0, 1].
# A CenterCrop(224) after this square Resize would be a no-op, so none is used.
# NOTE(review): there is no Normalize() step — confirm this matches how the
# checkpoint was fine-tuned.
val_transforms = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
)
def predict_view_angle(image):
    """Classify the camera view angle of a single image.

    Args:
        image: A PIL image (the Gradio input is declared with
            ``type="pil"``). It is converted to RGB first so that RGBA,
            palette or greyscale images still produce the 3-channel tensor
            the ViT model is fed.

    Returns:
        The label whose logit is highest: ``"45度俯视"`` (45-degree
        overhead view), ``"俯视"`` (overhead view) or ``"正视"`` (front view).
    """
    # Label order must match the classifier's output indices.
    predict_tags = ["45度俯视", "俯视", "正视"]
    pixels = val_transforms(image.convert("RGB"))
    # Inference only: skip autograd graph construction for the forward pass.
    with torch.no_grad():
        logits = model(pixels.unsqueeze(0)).logits.squeeze(0)
    # sigmoid is monotonic, so the arg-max of the raw logits is the same label
    # as the top sigmoid score, but without float32 saturation ties
    # (e.g. sigmoid(20) == sigmoid(25) == 1.0) that could pick the wrong one.
    return predict_tags[int(logits.argmax())]
# Web UI: accept one uploaded image (delivered as PIL) and show the
# predicted view-angle label as text, then start the server.
_image_input = grc.Image(type="pil")
_label_output = grc.Textbox()
app = gr.Interface(
    fn=predict_view_angle,
    inputs=_image_input,
    outputs=_label_output,
)
app.launch()