Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	add multi-class
Browse files- app.py +28 -5
 - output.png +0 -0
 - requirements.txt +2 -1
 
    	
        app.py
    CHANGED
    
    | 
         @@ -17,10 +17,31 @@ nltk.download('averaged_perceptron_tagger') 
     | 
|
| 17 | 
         
             
            from nltk.tokenize import word_tokenize
         
     | 
| 18 | 
         
             
            import torchvision
         
     | 
| 19 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 20 | 
         
             
            args = default_argument_parser().parse_args()
         
     | 
| 21 | 
         
             
            cfg = setup(args)
         
     | 
| 22 | 
         | 
| 23 | 
         
            -
            multi_classes =  
     | 
| 24 | 
         | 
| 25 | 
         
             
            device = "cuda" if torch.cuda.is_available() else "cpu"
         
     | 
| 26 | 
         
             
            Ours, preprocess = models.load("CS-ViT-B/16", device=device, cfg=cfg, train_bool=False)
         
     | 
| 
         @@ -42,10 +63,12 @@ def run(sketch, caption, threshold, seed): 
     | 
|
| 42 | 
         | 
| 43 | 
         
             
                # set the condidate classes here
         
     | 
| 44 | 
         
             
                caption = caption.replace('\n',' ')
         
     | 
| 45 | 
         
            -
                 
     | 
| 46 | 
         
            -
                 
     | 
| 47 | 
         
            -
                 
     | 
| 48 | 
         
            -
                 
     | 
| 
         | 
|
| 
         | 
|
| 49 | 
         
             
                if len(classes) ==0 or multi_classes == False:
         
     | 
| 50 | 
         
             
                    classes = [caption]
         
     | 
| 51 | 
         | 
| 
         | 
|
| 17 | 
         
             
            from nltk.tokenize import word_tokenize
         
     | 
| 18 | 
         
             
            import torchvision
         
     | 
| 19 | 
         | 
| 20 | 
         
            +
            import spacy
         
     | 
| 21 | 
         
            +
             
     | 
| 22 | 
         
            +
            # download the model
         
     | 
| 23 | 
         
            +
            spacy.cli.download("en_core_web_sm")
         
     | 
| 24 | 
         
            +
             
     | 
| 25 | 
         
            +
            # Load spaCy model
         
     | 
| 26 | 
         
            +
            nlp = spacy.load("en_core_web_sm")
         
     | 
| 27 | 
         
            +
             
     | 
| 28 | 
         
            +
            def extract_objects(prompt):
         
     | 
| 29 | 
         
            +
                doc = nlp(prompt)
         
     | 
| 30 | 
         
            +
                # Extract object nouns (including proper nouns and compound nouns)
         
     | 
| 31 | 
         
            +
                objects = set()
         
     | 
| 32 | 
         
            +
                for token in doc:
         
     | 
| 33 | 
         
            +
                    # Check if the token is a noun or part of a named entity
         
     | 
| 34 | 
         
            +
                    if token.pos_ in {"NOUN", "PROPN"} or token.ent_type_:
         
     | 
| 35 | 
         
            +
                        objects.add(token.text)
         
     | 
| 36 | 
         
            +
                    # Check if the token is part of a compound noun
         
     | 
| 37 | 
         
            +
                    if token.dep_ in {"compound"}:
         
     | 
| 38 | 
         
            +
                        objects.add(token.head.text)
         
     | 
| 39 | 
         
            +
                return list(objects)
         
     | 
| 40 | 
         
            +
             
     | 
| 41 | 
         
             
            args = default_argument_parser().parse_args()
         
     | 
| 42 | 
         
             
            cfg = setup(args)
         
     | 
| 43 | 
         | 
| 44 | 
         
            +
            multi_classes = True
         
     | 
| 45 | 
         | 
| 46 | 
         
             
            device = "cuda" if torch.cuda.is_available() else "cpu"
         
     | 
| 47 | 
         
             
            Ours, preprocess = models.load("CS-ViT-B/16", device=device, cfg=cfg, train_bool=False)
         
     | 
| 
         | 
|
| 63 | 
         | 
| 64 | 
         
             
                # set the condidate classes here
         
     | 
| 65 | 
         
             
                caption = caption.replace('\n',' ')
         
     | 
| 66 | 
         
            +
                classes = extract_objects(caption)
         
     | 
| 67 | 
         
            +
                # translator = str.maketrans('', '', string.punctuation)
         
     | 
| 68 | 
         
            +
                # caption = caption.translate(translator).lower()
         
     | 
| 69 | 
         
            +
                # words = word_tokenize(caption)
         
     | 
| 70 | 
         
            +
                # classes = get_noun_phrase(words)
         
     | 
| 71 | 
         
            +
                # print(classes)
         
     | 
| 72 | 
         
             
                if len(classes) ==0 or multi_classes == False:
         
     | 
| 73 | 
         
             
                    classes = [caption]
         
     | 
| 74 | 
         | 
    	
        output.png
    CHANGED
    
    
												 
											 | 
										
												 
									 | 
									
								
    	
        requirements.txt
    CHANGED
    
    | 
         @@ -10,4 +10,5 @@ iopath 
     | 
|
| 10 | 
         
             
            ftfy
         
     | 
| 11 | 
         
             
            fvcore
         
     | 
| 12 | 
         
             
            regex
         
     | 
| 13 | 
         
            -
            nltk
         
     | 
| 
         | 
| 
         | 
|
| 10 | 
         
             
            ftfy
         
     | 
| 11 | 
         
             
            fvcore
         
     | 
| 12 | 
         
             
            regex
         
     | 
| 13 | 
         
            +
            nltk
         
     | 
| 14 | 
         
            +
            spacy
         
     |