Princess3 commited on
Commit
8417cb6
1 Parent(s): 3504cb5

Update m5.py

Browse files
Files changed (1) hide show
  1. m5.py +14 -2
m5.py CHANGED
@@ -4,9 +4,21 @@ from collections import defaultdict
4
  from accelerate import Accelerator
5
  from transformers import AutoTokenizer, AutoModel
6
  from sklearn.metrics.pairwise import cosine_similarity
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
- # Set the TRANSFORMERS_CACHE environment variable to a writable directory
9
- os.environ['TRANSFORMERS_CACHE'] = '/app/cache'
10
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
11
 
12
  class DM(nn.Module):
 
4
  from accelerate import Accelerator
5
  from transformers import AutoTokenizer, AutoModel
6
  from sklearn.metrics.pairwise import cosine_similarity
7
+ import termcolor
8
+
9
+ # Set the cache directory path
10
+ cache_dir = '/app/cache'
11
+
12
+ # Create the directory if it doesn't exist
13
+ if not os.path.exists(cache_dir):
14
+ os.makedirs(cache_dir)
15
+
16
+ # Set the environment variable
17
+ os.environ['TRANSFORMERS_CACHE'] = cache_dir
18
+
19
+ # Verify the environment variable is set
20
+ print(f"TRANSFORMERS_CACHE is set to: {os.environ['TRANSFORMERS_CACHE']}")
21
 
 
 
22
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
23
 
24
  class DM(nn.Module):