Boni98 commited on
Commit
3af8d8e
·
1 Parent(s): 5480fb2

Delete clip_chat.py

Browse files
Files changed (1) hide show
  1. clip_chat.py +0 -81
clip_chat.py DELETED
@@ -1,81 +0,0 @@
1
- import torch
2
- import clip
3
- from PIL import Image
4
- import glob
5
- import os
6
- from random import choice
7
-
8
-
9
- device = "cuda" if torch.cuda.is_available() else "cpu"
10
- # model, preprocess = clip.load("ViT-L/14@336px", device=device)
11
- model, preprocess = clip.load(clip.available_models()[-3], device=device)
12
- COCO = glob.glob(os.path.join(os.getcwd(), "images", "*"))
13
-
14
-
15
- def load_random_image():
16
- image_path = choice(COCO)
17
- image = Image.open(image_path)
18
- return image
19
-
20
-
21
- def next_image():
22
- global image_org, image
23
- image_org = load_random_image()
24
- image = preprocess(Image.fromarray(image_org)).unsqueeze(0).to(device)
25
-
26
-
27
- def calculate_logits(image_features, text_features):
28
- image_features = image_features / image_features.norm(dim=1, keepdim=True)
29
- text_features = text_features / text_features.norm(dim=1, keepdim=True)
30
-
31
- logit_scale = model.logit_scale.exp()
32
- return logit_scale * image_features @ text_features.t()
33
-
34
-
35
- last = -1
36
- best = -1
37
-
38
- goal = 30
39
-
40
- image_org = load_random_image()
41
- image = preprocess(image_org).unsqueeze(0).to(device)
42
- with torch.no_grad():
43
- image_features = model.encode_image(image)
44
-
45
-
46
- def answer(message):
47
- global last, best
48
-
49
- text = clip.tokenize([message]).to(device)
50
-
51
- with torch.no_grad():
52
- text_features = model.encode_text(text)
53
- logits_per_image, _ = model(image, text)
54
- logits = calculate_logits(image_features, text_features).cpu().numpy().flatten()[0]
55
-
56
- if last == -1:
57
- is_better = -1
58
- elif last > logits:
59
- is_better = 0
60
- elif last < logits:
61
- is_better = 1
62
- elif logits > goal:
63
- is_better = 2
64
- else:
65
- is_better = -1
66
-
67
- last = logits
68
- if logits > best:
69
- best = logits
70
- is_better = 3
71
-
72
- return logits, is_better
73
-
74
-
75
- def reset_everything():
76
- global last, best, goal, image, image_org
77
- last = -1
78
- best = -1
79
- goal = 21
80
- image_org = load_random_image()
81
- image = preprocess(image_org).unsqueeze(0).to(device)