paul hilders commited on
Commit
cca85c2
1 Parent(s): d80767e

Import spacy model

Browse files
Files changed (1) hide show
  1. app.py +12 -1
app.py CHANGED
@@ -8,6 +8,7 @@ import torch
8
  import CLIP.clip as clip
9
 
10
  import spacy
 
11
 
12
 
13
  from clip_grounding.utils.image import pad_to_square
@@ -24,7 +25,7 @@ clip.clip._MODELS = {
24
  device = "cuda" if torch.cuda.is_available() else "cpu"
25
  model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
26
 
27
- NER = spacy.load("en_core_web_sm")
28
 
29
  # Gradio Section:
30
  def run_demo(image, text):
@@ -43,6 +44,16 @@ def run_demo(image, text):
43
  for i, token in enumerate(text_tokens_decoded):
44
  highlighted_text.append((str(token), float(text_scores[i])))
45
 
 
 
 
 
 
 
 
 
 
 
46
  return overlapped, highlighted_text
47
 
48
  input_img = gr.inputs.Image(type='pil', label="Original Image")
 
8
  import CLIP.clip as clip
9
 
10
  import spacy
11
+ from spacy import displacy
12
 
13
 
14
  from clip_grounding.utils.image import pad_to_square
 
25
  device = "cuda" if torch.cuda.is_available() else "cpu"
26
  model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
27
 
28
+ nlp = spacy.load("en_core_web_sm")
29
 
30
  # Gradio Section:
31
  def run_demo(image, text):
 
44
  for i, token in enumerate(text_tokens_decoded):
45
  highlighted_text.append((str(token), float(text_scores[i])))
46
 
47
+ # Apply NER to extract named entities, and run the explainability method
48
+ # for each named entity.
49
+ highlighed_entities = []
50
+ for ent in nlp(text).ents:
51
+ ent_text = ent.text
52
+ ent_label = ent.label_
53
+ highlighed_entities.append((ent_text, ent_label))
54
+
55
+ print(highlighed_entities)
56
+
57
  return overlapped, highlighted_text
58
 
59
  input_img = gr.inputs.Image(type='pil', label="Original Image")