Update vit_model_test.py
Browse files- vit_model_test.py +13 -25
 
    	
        vit_model_test.py
    CHANGED
    
    | 
         @@ -3,39 +3,26 @@ import torch.nn as nn 
     | 
|
| 3 | 
         
             
            from torch.utils.data import Dataset, DataLoader
         
     | 
| 4 | 
         
             
            from torchvision import transforms
         
     | 
| 5 | 
         
             
            from transformers import ViTForImageClassification
         
     | 
| 6 | 
         
            -
            from PIL import Image
         
     | 
| 7 | 
         
             
            import os
         
     | 
| 8 | 
         
             
            import pandas as pd
         
     | 
| 9 | 
         
             
            from sklearn.model_selection import train_test_split
         
     | 
| 10 | 
         
             
            from sklearn.metrics import accuracy_score, precision_score, confusion_matrix, f1_score, average_precision_score, recall_score
         
     | 
| 11 | 
         
             
            import matplotlib.pyplot as plt
         
     | 
| 12 | 
         
             
            import seaborn as sns
         
     | 
| 13 | 
         
            -
            import cv2  # 住驻专讬讬转 OpenCV 诇讛爪讙转 讛讜讬讚讗讜
         
     | 
| 14 | 
         
            -
            from vit_model_traning import labeling, CustomDataset
         
     | 
| 15 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 16 | 
         | 
| 17 | 
         
             
            def shuffle_and_split_data(dataframe, test_size=0.2, random_state=59):
         
     | 
| 18 | 
         
             
                shuffled_df = dataframe.sample(frac=1, random_state=random_state).reset_index(drop=True)
         
     | 
| 19 | 
         
             
                train_df, val_df = train_test_split(shuffled_df, test_size=test_size, random_state=random_state)
         
     | 
| 20 | 
         
             
                return train_df, val_df    
         
     | 
| 21 | 
         | 
| 22 | 
         
            -
             
     | 
| 23 | 
         
            -
            def play_animation(video_path):
         
     | 
| 24 | 
         
            -
                cap = cv2.VideoCapture(video_path)
         
     | 
| 25 | 
         
            -
             
     | 
| 26 | 
         
            -
                while cap.isOpened():
         
     | 
| 27 | 
         
            -
                    ret, frame = cap.read()
         
     | 
| 28 | 
         
            -
                    if not ret:
         
     | 
| 29 | 
         
            -
                        break
         
     | 
| 30 | 
         
            -
                    cv2.imshow('Processing Animation', frame)
         
     | 
| 31 | 
         
            -
             
     | 
| 32 | 
         
            -
                    # Press 'q' to exit early
         
     | 
| 33 | 
         
            -
                    if cv2.waitKey(25) & 0xFF == ord('q'):
         
     | 
| 34 | 
         
            -
                        break
         
     | 
| 35 | 
         
            -
             
     | 
| 36 | 
         
            -
                cap.release()
         
     | 
| 37 | 
         
            -
                cv2.destroyAllWindows()
         
     | 
| 38 | 
         
            -
             
     | 
| 39 | 
         
             
            if __name__ == "__main__":
         
     | 
| 40 | 
         
             
                # Check for GPU availability
         
     | 
| 41 | 
         
             
                device = torch.device('cuda')
         
     | 
| 
         @@ -67,8 +54,12 @@ if __name__ == "__main__": 
     | 
|
| 67 | 
         
             
                true_labels = []
         
     | 
| 68 | 
         
             
                predicted_labels = []
         
     | 
| 69 | 
         | 
| 70 | 
         
            -
                #  
     | 
| 71 | 
         
            -
                 
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 72 | 
         | 
| 73 | 
         
             
                with torch.no_grad():
         
     | 
| 74 | 
         
             
                    for images, labels in test_loader:
         
     | 
| 
         @@ -100,6 +91,3 @@ if __name__ == "__main__": 
     | 
|
| 100 | 
         
             
                plt.ylabel('True Labels')
         
     | 
| 101 | 
         
             
                plt.title('Confusion Matrix')
         
     | 
| 102 | 
         
             
                plt.show()
         
     | 
| 103 | 
         
            -
             
     | 
| 104 | 
         
            -
                # Play animation again if needed
         
     | 
| 105 | 
         
            -
                # play_animation('path_to_your_animation.mp4')
         
     | 
| 
         | 
|
| 3 | 
         
             
            from torch.utils.data import Dataset, DataLoader
         
     | 
| 4 | 
         
             
            from torchvision import transforms
         
     | 
| 5 | 
         
             
            from transformers import ViTForImageClassification
         
     | 
| 
         | 
|
| 6 | 
         
             
            import os
         
     | 
| 7 | 
         
             
            import pandas as pd
         
     | 
| 8 | 
         
             
            from sklearn.model_selection import train_test_split
         
     | 
| 9 | 
         
             
            from sklearn.metrics import accuracy_score, precision_score, confusion_matrix, f1_score, average_precision_score, recall_score
         
     | 
| 10 | 
         
             
            import matplotlib.pyplot as plt
         
     | 
| 11 | 
         
             
            import seaborn as sns
         
     | 
| 
         | 
|
| 
         | 
|
| 12 | 
         | 
| 13 | 
         
            +
            # 驻讜谞拽爪讬讛 诇讛爪讙转 住专讟讜谉
         
     | 
| 14 | 
         
            +
            def display_video(video_url):
         
     | 
| 15 | 
         
            +
                video_html = f'''
         
     | 
| 16 | 
         
            +
                <iframe width="560" height="315" src="{video_url}" frameborder="0" allowfullscreen></iframe>
         
     | 
| 17 | 
         
            +
                '''
         
     | 
| 18 | 
         
            +
                # 讛谞讞 讗转 讛-HTML 讘讚砖讘讜专讚 砖诇讱
         
     | 
| 19 | 
         
            +
                return video_html
         
     | 
| 20 | 
         | 
| 21 | 
         
             
            def shuffle_and_split_data(dataframe, test_size=0.2, random_state=59):
         
     | 
| 22 | 
         
             
                shuffled_df = dataframe.sample(frac=1, random_state=random_state).reset_index(drop=True)
         
     | 
| 23 | 
         
             
                train_df, val_df = train_test_split(shuffled_df, test_size=test_size, random_state=random_state)
         
     | 
| 24 | 
         
             
                return train_df, val_df    
         
     | 
| 25 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 26 | 
         
             
            if __name__ == "__main__":
         
     | 
| 27 | 
         
             
                # Check for GPU availability
         
     | 
| 28 | 
         
             
                device = torch.device('cuda')
         
     | 
| 
         | 
|
| 54 | 
         
             
                true_labels = []
         
     | 
| 55 | 
         
             
                predicted_labels = []
         
     | 
| 56 | 
         | 
| 57 | 
         
            +
                # 拽讬砖讜专 诇住专讟讜谉
         
     | 
| 58 | 
         
            +
                video_url = 'https://youtube.com/shorts/vGRq060nPYU?feature=share'  # 讛讞诇讬驻讬 讻讗谉 注诐 讛-URL 砖诇 讛住专讟讜谉 砖诇讱
         
     | 
| 59 | 
         
            +
                video_html = display_video(video_url)
         
     | 
| 60 | 
         
            +
             
     | 
| 61 | 
         
            +
                # 讛专讗讬 讗转 讛住专讟讜谉 诇驻谞讬 讛讞讬讝讜讬
         
     | 
| 62 | 
         
            +
                print(video_html)  # 讛爪讙 讗转 讛-HTML 讘讚砖讘讜专讚 砖诇讱
         
     | 
| 63 | 
         | 
| 64 | 
         
             
                with torch.no_grad():
         
     | 
| 65 | 
         
             
                    for images, labels in test_loader:
         
     | 
| 
         | 
|
| 91 | 
         
             
                plt.ylabel('True Labels')
         
     | 
| 92 | 
         
             
                plt.title('Confusion Matrix')
         
     | 
| 93 | 
         
             
                plt.show()
         
     | 
| 
         | 
|
| 
         | 
|
| 
         |