Spaces:
Sleeping
Sleeping
import warnings | |
from datetime import datetime | |
from typing import Any, Optional | |
from qdrant_client import QdrantClient | |
from qdrant_client.http.models import QueryResponse | |
from qdrant_client.models import FieldCondition, Filter, MatchValue, Range | |
from article_embedding.embed import StellaEmbedder | |
from article_embedding.utils import env_str | |
# Silence every FutureWarning for the rest of the process (this inserts a global
# warnings filter at import time). Presumably meant to hide deprecation noise from
# third-party dependencies — confirm that no project warnings should surface here.
warnings.simplefilter(action="ignore", category=FutureWarning)
def as_timestamp(date: datetime | str) -> float: | |
if isinstance(date, datetime): | |
return date.timestamp() | |
return datetime.strptime(date, "%Y-%m-%d").timestamp() | |
def make_date_condition(
    *, field: str = "published", date_from: datetime | str | None = None, date_to: datetime | str | None = None
) -> FieldCondition | None:
    """Build a Qdrant range condition over the timestamp stored in *field*.

    The range is half-open: ``date_from`` is an inclusive lower bound (``gte``)
    and ``date_to`` an exclusive upper bound (``lt``). Each bound is converted
    with ``as_timestamp``. A falsy bound (``None`` or ``""``) is omitted.

    Returns:
        A ``FieldCondition`` with a ``Range``, or ``None`` when neither bound
        was supplied.
    """
    bounds: dict[str, float] = {}
    # Lower bound first, then upper bound — same order as conversion happens.
    for operator, value in (("gte", date_from), ("lt", date_to)):
        if value:
            bounds[operator] = as_timestamp(value)
    if not bounds:
        return None
    return FieldCondition(key=field, range=Range(**bounds))
def make_topic_condition(topic_id: str) -> FieldCondition:
    """Build a Qdrant condition matching points whose ``topics[]`` payload contains *topic_id*."""
    topic_match = MatchValue(value=topic_id)
    return FieldCondition(key="topics[]", match=topic_match)
class Query:
    """Semantic search over one Qdrant collection using a shared ``StellaEmbedder``.

    The embedder is created once and cached on the class, so every ``Query``
    instance reuses the same model object.
    """

    # Lazily populated, class-level caches shared by all callers.
    _instance: Optional["Query"] = None
    _embedding_model_instance: Optional[StellaEmbedder] = None

    def __init__(self, index: str = "wsws", client: QdrantClient | None = None) -> None:
        """Create a query helper.

        Args:
            index: Name of the Qdrant collection searched by ``query``.
            client: Existing Qdrant client to use. When omitted, a new client is
                built from the value ``env_str("QDRANT_URL")`` returns.
        """
        self.embedding_model = Query.embedding_model_singleton()
        self.qdrant = QdrantClient(env_str("QDRANT_URL")) if client is None else client
        self.index = index

    @staticmethod
    def embedding_model_singleton() -> StellaEmbedder:
        """Return the shared ``StellaEmbedder``, constructing it on first call.

        Note: the lazy initialisation is not guarded by a lock.
        """
        # Declared static so it also works when called through an instance;
        # without the decorator, ``instance.embedding_model_singleton()`` raised
        # TypeError because the instance was passed as an unexpected argument.
        if Query._embedding_model_instance is None:
            Query._embedding_model_instance = StellaEmbedder()
        return Query._embedding_model_instance

    @staticmethod
    def singleton() -> "Query":
        """Return a shared ``Query`` built with default arguments, creating it on first call."""
        if Query._instance is None:
            Query._instance = Query()
        return Query._instance

    def embed(self, query: str) -> Any:
        """Embed a single string and return its vector (the first row of the model's output)."""
        return self.embedding_model.embed([query])[0]

    def query(
        self,
        query: str,
        query_filter: Filter | None = None,
        limit: int = 10,
    ) -> QueryResponse:
        """Embed *query* and search ``self.index`` with the resulting vector.

        Args:
            query: Free-text search string.
            query_filter: Optional Qdrant filter restricting the candidate points.
            limit: Maximum number of points to return.

        Returns:
            The ``QueryResponse`` from ``QdrantClient.query_points``.
        """
        # Reuse ``embed`` so single-string embedding lives in one place.
        vector = self.embed(query)
        return self.qdrant.query_points(self.index, query=vector, query_filter=query_filter, limit=limit)
if __name__ == "__main__":
    import gspread
    from dotenv import load_dotenv
    from gspread.utils import ValueInputOption

    # (inclusive start date, exclusive end date, query sentence) for each period.
    data = [
        ("2021-01-01", "2021-05-01", "The COVID winter wave, the emergence of the Delta variant and the January 6th coup"),
        (
            "2021-05-01",
            "2021-09-01",
            "The COVID vaccine rollout, Biden declaring independence from COVID while the Delta wave continues",
        ),
        (
            "2021-09-01",
            "2022-01-01",
            "The emergence of the COVID Omicron variant and the embrace of herd immunity by the ruling class",
        ),
    ]

    # Load .env before Query() reads QDRANT_URL.
    load_dotenv()
    query = Query()
    rows: list[list[str]] = []
    for date_from, date_to, sentence in data:
        # The date window is a hard constraint on every hit, so express it with `must`.
        result = query.query(
            sentence,
            query_filter=Filter(must=make_date_condition(date_from=date_from, date_to=date_to)),
        )
        rows.append([sentence])
        for point in result.points:
            doc = point.payload
            assert doc is not None
            score = f"{point.score * 100:.1f}%"
            url = f'https://www.wsws.org{doc["path"]}'
            print(f'{score} {url} - {doc["title"]}')
            # Sheets formulas escape a literal double quote by doubling it; an unescaped
            # quote in a title would terminate the string argument and break the formula.
            title = doc["title"].replace('"', '""')
            rows.append(
                [
                    score,
                    datetime.fromtimestamp(doc["published"]).strftime("%Y/%m/%d"),
                    ", ".join(doc["authors"]),
                    f'=hyperlink("{url}", "{title}")',
                ]
            )
        # Blank separator row between periods.
        rows.append([])

    gc = gspread.auth.oauth(credentials_filename=env_str("GOOGLE_CREDENTIALS"))
    sh = gc.open("COVID-19 Compilation")
    ws = sh.get_worksheet(0)
    # USER_ENTERED makes Sheets evaluate the =hyperlink(...) formulas.
    ws.append_rows(rows, value_input_option=ValueInputOption.user_entered)