VINAYAK MODI commited on
Commit
68e22b1
1 Parent(s): 4d3693e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +77 -38
app.py CHANGED
@@ -1,47 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
- import torch
3
- from torchvision.transforms import transforms
4
  from PIL import Image
5
- from transformers import AutoModelForSequenceClassification
6
- # Load the model and tokenizer
7
- model_name = "vm24bho/net_dfm_myimg"
8
- model = AutoModelForSequenceClassification.from_pretrained(model_name)
 
9
 
 
10
 
11
- # Define transformations for the input image
12
- transform = transforms.Compose([
13
- transforms.Resize((224, 224)),
14
- transforms.ToTensor(),
15
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
16
- ])
 
 
 
 
 
17
 
18
- def predict(image):
19
  # Preprocess the image
20
- image = transform(image).unsqueeze(0) # Add batch dimension
21
 
22
  # Perform inference
23
- outputs = model(image)
 
24
 
25
- # Get prediction
26
- prediction = torch.argmax(outputs.logits).item()
27
-
28
- return prediction
29
-
30
- def main():
31
- st.title("Image Detection: Real or Deepfake")
32
- uploaded_image = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
33
-
34
- if uploaded_image is not None:
35
- image = Image.open(uploaded_image)
36
- st.image(image, caption='Uploaded Image', use_column_width=True)
37
-
38
- # Make prediction
39
- if st.button("Detect"):
40
- prediction = predict(image)
41
- if prediction == 0:
42
- st.write("Prediction: Real")
43
- else:
44
- st.write("Prediction: Deepfake")
45
-
46
- if __name__ == "__main__":
47
- main()
 
1
+ # import streamlit as st
2
+ # import torch
3
+ # from torchvision.transforms import transforms
4
+ # from PIL import Image
5
+ # from transformers import AutoModelForSequenceClassification
6
+ # # Load the model and tokenizer
7
+ # model_name = "vm24bho/net_dfm_myimg"
8
+ # model = AutoModelForSequenceClassification.from_pretrained(model_name)
9
+
10
+
11
+ # # Define transformations for the input image
12
+ # transform = transforms.Compose([
13
+ # transforms.Resize((224, 224)),
14
+ # transforms.ToTensor(),
15
+ # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
16
+ # ])
17
+
18
+ # def predict(image):
19
+ # # Preprocess the image
20
+ # image = transform(image).unsqueeze(0) # Add batch dimension
21
+
22
+ # # Perform inference
23
+ # outputs = model(image)
24
+
25
+ # # Get prediction
26
+ # prediction = torch.argmax(outputs.logits).item()
27
+
28
+ # return prediction
29
+
30
+ # def main():
31
+ # st.title("Image Detection: Real or Deepfake")
32
+ # uploaded_image = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
33
+
34
+ # if uploaded_image is not None:
35
+ # image = Image.open(uploaded_image)
36
+ # st.image(image, caption='Uploaded Image', use_column_width=True)
37
+
38
+ # # Make prediction
39
+ # if st.button("Detect"):
40
+ # prediction = predict(image)
41
+ # if prediction == 0:
42
+ # st.write("Prediction: Real")
43
+ # else:
44
+ # st.write("Prediction: Deepfake")
45
+
46
+ # if __name__ == "__main__":
47
+ # main()
48
+
49
+
50
  import streamlit as st
51
+ from transformers import ViTForImageClassification, ViTFeatureExtractor
 
52
  from PIL import Image
53
+ import torch
54
+
55
+ # Load the model and feature extractor
56
+ model = ViTForImageClassification.from_pretrained("path/to/your/model")
57
+ feature_extractor = ViTFeatureExtractor.from_pretrained("path/to/your/model")
58
 
59
+ st.title("Deepfake Classification App")
60
 
61
+ # File uploader
62
+ uploaded_file = st.file_uploader("Choose an image...", type="jpg")
63
+
64
+ if uploaded_file is not None:
65
+ # Load the image
66
+ image = Image.open(uploaded_file)
67
+
68
+ # Display the image
69
+ st.image(image, caption='Uploaded Image', use_column_width=True)
70
+ st.write("")
71
+ st.write("Classifying...")
72
 
 
73
  # Preprocess the image
74
+ inputs = feature_extractor(images=image, return_tensors="pt")
75
 
76
  # Perform inference
77
+ with torch.no_grad():
78
+ outputs = model(**inputs)
79
 
80
+ # Get the predicted label
81
+ logits = outputs.logits
82
+ predicted_class_idx = logits.argmax(-1).item()
83
+ predicted_class_label = model.config.id2label[predicted_class_idx]
84
+
85
+ # Display the result
86
+ st.write(f"Prediction: {predicted_class_label}")