HawkeyeHS commited on
Commit
b98fea5
1 Parent(s): 918f252
Files changed (1) hide show
  1. app.py +29 -19
app.py CHANGED
@@ -18,6 +18,25 @@ cors = CORS(app)
18
  model = AutoModelForImageClassification.from_pretrained('carbon225/vit-base-patch16-224-hentai')
19
  feature_extractor = AutoFeatureExtractor.from_pretrained('carbon225/vit-base-patch16-224-hentai')
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  @app.route("/", methods=["GET"])
22
  def default():
23
  return json.dumps({"Server": "Working"})
@@ -28,18 +47,23 @@ def extract_images():
28
  src=request.args.get("src")
29
  response = requests.get(src)
30
  soup = BeautifulSoup(response.content,'html.parser')
31
- img_urls=[]
32
 
33
  img_tags = soup.select('div img')
34
  for img_tag in img_tags:
35
  img_url = urljoin(src, img_tag['src'])
36
- img_urls.append(img_url)
37
- return json.dumps({"images":img_urls})
 
 
 
 
 
 
38
  except Exception as e:
39
  return e
40
 
41
  @app.route("/predict", methods=["GET"])
42
- def predict():
43
  try:
44
  src = request.args.get("src")
45
 
@@ -47,21 +71,7 @@ def predict():
47
  response = requests.get(src)
48
  response.raise_for_status()
49
 
50
- # Open and preprocess the image
51
- image = Image.open(BytesIO(response.content))
52
- image = image.resize((128, 128))
53
-
54
- # Extract features using the pre-trained feature extractor
55
- encoding = feature_extractor(images=image.convert("RGB"), return_tensors="pt")
56
-
57
- # Make a prediction using the pre-trained model
58
- with torch.no_grad():
59
- outputs = model(**encoding)
60
- logits = outputs.logits
61
-
62
- # Get the predicted class index and label
63
- predicted_class_idx = logits.argmax(-1).item()
64
- predicted_class_label = model.config.id2label[predicted_class_idx]
65
 
66
  # Return the predictions
67
  return json.dumps({"class": predicted_class_label})
 
18
  model = AutoModelForImageClassification.from_pretrained('carbon225/vit-base-patch16-224-hentai')
19
  feature_extractor = AutoFeatureExtractor.from_pretrained('carbon225/vit-base-patch16-224-hentai')
20
 
21
+ def predict(response):
22
+ # Open and preprocess the image
23
+ image = Image.open(BytesIO(response.content))
24
+ image = image.resize((128, 128))
25
+
26
+ # Extract features using the pre-trained feature extractor
27
+ encoding = feature_extractor(images=image.convert("RGB"), return_tensors="pt")
28
+
29
+ # Make a prediction using the pre-trained model
30
+ with torch.no_grad():
31
+ outputs = model(**encoding)
32
+ logits = outputs.logits
33
+
34
+ # Get the predicted class index and label
35
+ predicted_class_idx = logits.argmax(-1).item()
36
+ predicted_class_label = model.config.id2label[predicted_class_idx]
37
+
38
+ return predicted_class_label
39
+
40
  @app.route("/", methods=["GET"])
41
  def default():
42
  return json.dumps({"Server": "Working"})
 
47
  src=request.args.get("src")
48
  response = requests.get(src)
49
  soup = BeautifulSoup(response.content,'html.parser')
 
50
 
51
  img_tags = soup.select('div img')
52
  for img_tag in img_tags:
53
  img_url = urljoin(src, img_tag['src'])
54
+ response = requests.get(img_url)
55
+ response.raise_for_status()
56
+ predicted_class_label = predict(response)
57
+
58
+ if predicted_class_label=='explicit' or predicted_class_label=='suggestive':
59
+ return json.dumps({"class":predicted_class_label})
60
+
61
+ return json.dumps({"class":"safe"})
62
  except Exception as e:
63
  return e
64
 
65
  @app.route("/predict", methods=["GET"])
66
+ def predict_image():
67
  try:
68
  src = request.args.get("src")
69
 
 
71
  response = requests.get(src)
72
  response.raise_for_status()
73
 
74
+ predicted_class_label = predict(response)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
  # Return the predictions
77
  return json.dumps({"class": predicted_class_label})