SushantGautam commited on
Commit
eac2634
1 Parent(s): 0f5ba98

Update script.py

Browse files
Files changed (1) hide show
  1. script.py +2 -106
script.py CHANGED
@@ -1,107 +1,3 @@
1
- import subprocess
2
  import sys
3
- import os, shutil
4
- import os, zipfile
5
- from torchvision import transforms
6
- from torchvision.datasets import ImageFolder
7
- from torchmetrics.image.fid import FrechetInceptionDistance
8
- from torch.utils.data import DataLoader
9
- from PIL import Image
10
- import torch
11
- from tqdm import tqdm
12
- import pandas as pd
13
-
14
- print("/tmp/data: ", ", ".join(os.listdir("/tmp/data")) if os.listdir("/tmp/data") else "The directory /tmp/data is empty.")
15
- print("/tmp/: ",", ".join(os.listdir("/tmp/")) if os.listdir("/tmp/") else "The directory /tmp/ is empty.")
16
- print("/tmp/model: ",", ".join(os.listdir("/tmp/model")) if os.listdir("/tmp/model") else "The directory /tmp/model is empty.")
17
-
18
- # print("/tmp/model/script.py: ", open('/tmp/model/script.py').read())
19
- print("/tmp/model/params.json: ", open('/tmp/model/params.json').read())
20
-
21
-
22
-
23
- os.makedirs('/tmp/data/hub/checkpoints/', exist_ok=True)
24
- os.environ['TORCH_HOME'] = '/tmp/data'
25
- shutil.move('/tmp/data/weights-inception-2015-12-05-6726825d.pth', '/tmp/data/hub/checkpoints/')
26
-
27
- print(f"Unzipping single")
28
- with zipfile.ZipFile("/tmp/data/real-images-single.zip", 'r') as zip_ref: zip_ref.extractall("/tmp/data")
29
- print(f"Unzipping multi")
30
- with zipfile.ZipFile("/tmp/data/real-images-multi.zip", 'r') as zip_ref: zip_ref.extractall("/tmp/data")
31
-
32
- # Directories with images
33
- fdir2 = "/tmp/model/generated" # Directory with fake images
34
-
35
- # Real data directories
36
-
37
- real_dirs = {
38
- "Single": "/tmp/data/real-images-single",
39
- "Multiple": "/tmp/data/real-images-multi",
40
- "Both": "/tmp/data/real-images-both"
41
- }
42
- ## generate real-images-both directory
43
- print("Preparing test dataset. . .")
44
- target_directory= '/tmp/data/real-images-both/images'
45
- os.makedirs(target_directory, exist_ok=True)
46
-
47
- for directory in ["/tmp/data/real-images-single", "/tmp/data/real-images-multi"]:
48
- print("for :"+ directory)
49
- images_dir = os.path.join(directory, "images")
50
- os.makedirs(images_dir, exist_ok=True)
51
- subprocess.run(f"mv {directory}/*.* {images_dir}/", shell=True, check=True)
52
- subprocess.run(f"ln {images_dir}/*.* {target_directory}/", shell=True, check=True)
53
-
54
- print("link generated images")
55
- # Organize directories for generated images
56
- dest_generated = "/tmp/model/generated/images"
57
- os.makedirs(dest_generated, exist_ok=True)
58
- subprocess.run(f"ln {fdir2}/*.* {dest_generated}/", shell=True, check=True)
59
-
60
-
61
- print("Init FID")
62
- # Initialize FID metric
63
- fid = FrechetInceptionDistance(feature=2048)
64
-
65
- # Define image transformations
66
- transform = transforms.Compose([
67
- transforms.Resize((299, 299)),
68
- transforms.ToTensor(),
69
- ])
70
-
71
- # Prepare fake dataset and loader
72
- generated_dataset = ImageFolder(root=fdir2, transform=transform)
73
- generated_loader = DataLoader(generated_dataset, batch_size=50, shuffle=False)
74
-
75
- # Compute FID scores for each real data directory
76
- fid_scores = {}
77
- for key, fdir1 in real_dirs.items():
78
- print("for set: ", fdir1)
79
-
80
- # Organize real images
81
- dest_real = os.path.join(fdir1, "images")
82
-
83
- # Create dataset and loader for real images
84
- real_dataset = ImageFolder(root=fdir1, transform=transform)
85
- real_loader = DataLoader(real_dataset, batch_size=50, shuffle=False)
86
-
87
- # Reset FID metric for current folder
88
- fid.reset()
89
-
90
- # Process real images
91
- for images, _ in tqdm(real_loader, desc=f"Processing real images: {key}"):
92
- images = (images * 255).to(torch.uint8)
93
- fid.update(images, real=True)
94
-
95
- # Process generated images
96
- for images, _ in tqdm(generated_loader, desc=f"Processing fake images: {key}"):
97
- images = (images * 255).to(torch.uint8)
98
- fid.update(images, real=False)
99
-
100
- # Compute FID score
101
- fid_score = fid.compute()
102
- fid_scores[f"FID_Score_{key}"] = fid_score.item()
103
- print(f"FID Score ({key}):", fid_score.item())
104
-
105
- # Save FID scores to CSV
106
- df = pd.DataFrame([fid_scores])
107
- df.to_csv("submission.csv", index=False)
 
 
1
  import sys
2
+ sys.path.insert(0, "/tmp/data")
3
+ import script