jervinjosh68 commited on
Commit
9eadd04
1 Parent(s): 2794d4e

added files

Browse files
Files changed (2) hide show
  1. app.py +18 -2
  2. requirements.txt +2 -1
app.py CHANGED
@@ -5,7 +5,16 @@ import torchvision.transforms as T
5
  from PIL import Image
6
  import numpy as np
7
  import gradio as gr
8
- model = AQC_NET(pretrain=True,num_label=5)
 
 
 
 
 
 
 
 
 
9
  def predict(image_name):
10
  model.eval()
11
 
@@ -48,5 +57,12 @@ def run_gradio():
48
  theme="huggingface",
49
  ).launch(debug=True, enable_queue=True)
50
 
51
- #print(predict("test_image.jpg"))
 
 
 
 
 
 
 
52
  run_gradio()
 
5
  from PIL import Image
6
  import numpy as np
7
  import gradio as gr
8
+ import requests
9
+ import os
10
+
11
+ def get_file(url,path,filename, chunk_size=128):
12
+ r = requests.get(url, stream=True)
13
+ with open(path, 'wb') as downloaded:
14
+ for chunk in r.iter_content(chunk_size=chunk_size):
15
+ downloaded.write(chunk)
16
+
17
+
18
  def predict(image_name):
19
  model.eval()
20
 
 
57
  theme="huggingface",
58
  ).launch(debug=True, enable_queue=True)
59
 
60
+ model = AQC_NET(pretrain=True,num_label=5)
61
+ if not os.path.exists('weight.pth'):
62
+ print("weight.pth does not exist. Downloading...")
63
+ get_file("https://github.com/Kaldr4/EEE-199/releases/download/v1/weight.pth", 'weight.pth',"weight.pth")
64
+ print("weight.pth downloaded")
65
+ else:
66
+ print('Specified file (weight.pth) already downloaded. Skipping this step.')
67
+ torch.load("weight.pth")
68
  run_gradio()
requirements.txt CHANGED
@@ -2,4 +2,5 @@ torch
2
  torchvision
3
  pillow
4
  numpy
5
- gradio
 
 
2
  torchvision
3
  pillow
4
  numpy
5
+ gradio
6
+ requests