Spaces:
				
			
			
	
			
			
					
		Running
		
			on 
			
			CPU Upgrade
	
	
	
			
			
	
	
	
	
		
		
					
		Running
		
			on 
			
			CPU Upgrade
	Commit 
							
							·
						
						3408aae
	
1
								Parent(s):
							
							5d6ca81
								
rename
Browse files
    	
        main.py
    CHANGED
    
    | 
         @@ -9,9 +9,13 @@ from httpx import AsyncClient 
     | 
|
| 9 | 
         
             
            from huggingface_hub import DatasetCard
         
     | 
| 10 | 
         
             
            from pydantic import BaseModel
         
     | 
| 11 | 
         
             
            from starlette.responses import RedirectResponse
         
     | 
| 12 | 
         
            -
            from starlette.status import  
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 13 | 
         | 
| 14 | 
         
            -
            from  
     | 
| 15 | 
         | 
| 16 | 
         
             
            # Set up logging
         
     | 
| 17 | 
         
             
            logging.basicConfig(
         
     | 
| 
         @@ -97,6 +101,14 @@ class DatasetCardNotFoundError(HTTPException): 
     | 
|
| 97 | 
         
             
                    )
         
     | 
| 98 | 
         | 
| 99 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 100 | 
         
             
            @app.get("/similar", response_model=QueryResponse)
         
     | 
| 101 | 
         
             
            @cache(ttl="1h")
         
     | 
| 102 | 
         
             
            async def api_query_dataset(dataset_id: str, n: int = Query(default=10, ge=1, le=100)):
         
     | 
| 
         @@ -115,7 +127,9 @@ async def api_query_dataset(dataset_id: str, n: int = Query(default=10, ge=1, le 
     | 
|
| 115 | 
         
             
                            collection.upsert(ids=[dataset_id], embeddings=embeddings[0])
         
     | 
| 116 | 
         
             
                            logger.info(f"Dataset {dataset_id} added to collection")
         
     | 
| 117 | 
         
             
                            result = collection.get(ids=[dataset_id], include=["embeddings"])
         
     | 
| 118 | 
         
            -
             
     | 
| 
         | 
|
| 
         | 
|
| 119 | 
         
             
                            raise
         
     | 
| 120 | 
         
             
                        except Exception as e:
         
     | 
| 121 | 
         
             
                            logger.error(
         
     | 
| 
         @@ -157,6 +171,44 @@ async def api_query_dataset(dataset_id: str, n: int = Query(default=10, ge=1, le 
     | 
|
| 157 | 
         
             
                    ) from e
         
     | 
| 158 | 
         | 
| 159 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 160 | 
         
             
            if __name__ == "__main__":
         
     | 
| 161 | 
         
             
                import uvicorn
         
     | 
| 162 | 
         | 
| 
         | 
|
| 9 | 
         
             
            from huggingface_hub import DatasetCard
         
     | 
| 10 | 
         
             
            from pydantic import BaseModel
         
     | 
| 11 | 
         
             
            from starlette.responses import RedirectResponse
         
     | 
| 12 | 
         
            +
            from starlette.status import (
         
     | 
| 13 | 
         
            +
                HTTP_404_NOT_FOUND,
         
     | 
| 14 | 
         
            +
                HTTP_500_INTERNAL_SERVER_ERROR,
         
     | 
| 15 | 
         
            +
                HTTP_403_FORBIDDEN,
         
     | 
| 16 | 
         
            +
            )
         
     | 
| 17 | 
         | 
| 18 | 
         
            +
            from load_card_data import get_embedding_function, get_save_path, refresh_data
         
     | 
| 19 | 
         | 
| 20 | 
         
             
            # Set up logging
         
     | 
| 21 | 
         
             
            logging.basicConfig(
         
     | 
| 
         | 
|
| 101 | 
         
             
                    )
         
     | 
| 102 | 
         | 
| 103 | 
         | 
| 104 | 
         
            +
            class DatasetNotForAllAudiencesError(HTTPException):
         
     | 
| 105 | 
         
            +
                def __init__(self, dataset_id: str):
         
     | 
| 106 | 
         
            +
                    super().__init__(
         
     | 
| 107 | 
         
            +
                        status_code=HTTP_403_FORBIDDEN,
         
     | 
| 108 | 
         
            +
                        detail=f"Dataset {dataset_id} is not for all audiences and not supported in this service.",
         
     | 
| 109 | 
         
            +
                    )
         
     | 
| 110 | 
         
            +
             
     | 
| 111 | 
         
            +
             
     | 
| 112 | 
         
             
            @app.get("/similar", response_model=QueryResponse)
         
     | 
| 113 | 
         
             
            @cache(ttl="1h")
         
     | 
| 114 | 
         
             
            async def api_query_dataset(dataset_id: str, n: int = Query(default=10, ge=1, le=100)):
         
     | 
| 
         | 
|
| 127 | 
         
             
                            collection.upsert(ids=[dataset_id], embeddings=embeddings[0])
         
     | 
| 128 | 
         
             
                            logger.info(f"Dataset {dataset_id} added to collection")
         
     | 
| 129 | 
         
             
                            result = collection.get(ids=[dataset_id], include=["embeddings"])
         
     | 
| 130 | 
         
            +
                            if result.get("not-for-all-audiences"):
         
     | 
| 131 | 
         
            +
                                raise DatasetNotForAllAudiencesError(dataset_id)
         
     | 
| 132 | 
         
            +
                        except (DatasetCardNotFoundError, DatasetNotForAllAudiencesError):
         
     | 
| 133 | 
         
             
                            raise
         
     | 
| 134 | 
         
             
                        except Exception as e:
         
     | 
| 135 | 
         
             
                            logger.error(
         
     | 
| 
         | 
|
| 171 | 
         
             
                    ) from e
         
     | 
| 172 | 
         | 
| 173 | 
         | 
| 174 | 
         
            +
            @app.post("/similar_by_text", response_model=QueryResponse)
         
     | 
| 175 | 
         
            +
            @cache(ttl="1h")
         
     | 
| 176 | 
         
            +
            async def api_query_by_text(query: str, n: int = Query(default=10, ge=1, le=100)):
         
     | 
| 177 | 
         
            +
                try:
         
     | 
| 178 | 
         
            +
                    logger.info(f"Querying datasets by text: {query}")
         
     | 
| 179 | 
         
            +
                    collection = client.get_collection(
         
     | 
| 180 | 
         
            +
                        name="dataset_cards", embedding_function=get_embedding_function()
         
     | 
| 181 | 
         
            +
                    )
         
     | 
| 182 | 
         
            +
                    print(query)
         
     | 
| 183 | 
         
            +
                    query_result = collection.query(
         
     | 
| 184 | 
         
            +
                        query_texts=query, n_results=n, include=["distances"]
         
     | 
| 185 | 
         
            +
                    )
         
     | 
| 186 | 
         
            +
                    print(query_result)
         
     | 
| 187 | 
         
            +
             
     | 
| 188 | 
         
            +
                    if not query_result["ids"]:
         
     | 
| 189 | 
         
            +
                        logger.info(f"No similar datasets found for query: {query}")
         
     | 
| 190 | 
         
            +
                        raise HTTPException(
         
     | 
| 191 | 
         
            +
                            status_code=HTTP_404_NOT_FOUND, detail="No similar datasets found."
         
     | 
| 192 | 
         
            +
                        )
         
     | 
| 193 | 
         
            +
             
     | 
| 194 | 
         
            +
                    # Prepare the response
         
     | 
| 195 | 
         
            +
                    results = [
         
     | 
| 196 | 
         
            +
                        QueryResult(dataset_id=str(id), similarity=float(1 - distance))
         
     | 
| 197 | 
         
            +
                        for id, distance in zip(
         
     | 
| 198 | 
         
            +
                            query_result["ids"][0], query_result["distances"][0]
         
     | 
| 199 | 
         
            +
                        )
         
     | 
| 200 | 
         
            +
                    ]
         
     | 
| 201 | 
         
            +
                    logger.info(f"Found {len(results)} similar datasets for query: {query}")
         
     | 
| 202 | 
         
            +
                    return QueryResponse(results=results)
         
     | 
| 203 | 
         
            +
             
     | 
| 204 | 
         
            +
                except Exception as e:
         
     | 
| 205 | 
         
            +
                    logger.error(f"Error querying datasets by text {query}: {str(e)}")
         
     | 
| 206 | 
         
            +
                    raise HTTPException(
         
     | 
| 207 | 
         
            +
                        status_code=HTTP_500_INTERNAL_SERVER_ERROR,
         
     | 
| 208 | 
         
            +
                        detail="An unexpected error occurred.",
         
     | 
| 209 | 
         
            +
                    ) from e
         
     | 
| 210 | 
         
            +
             
     | 
| 211 | 
         
            +
             
     | 
| 212 | 
         
             
            if __name__ == "__main__":
         
     | 
| 213 | 
         
             
                import uvicorn
         
     | 
| 214 | 
         |