# audit_assistant/utils.py
import json
import dataclasses
from uuid import UUID
from typing import Any
from datetime import datetime, date
import configparser
from torch import cuda
from qdrant_client.http import models as rest
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
def get_config(fp):
    """Read an INI-style configuration file and return the parsed config.

    Args:
        fp: Path to the configuration file.

    Returns:
        configparser.ConfigParser: The populated parser.

    Raises:
        FileNotFoundError: If ``fp`` does not exist.
    """
    config = configparser.ConfigParser()
    # Use a context manager so the file handle is always closed; the original
    # passed a bare open(fp) and leaked the handle.
    with open(fp) as f:
        config.read_file(f)
    return config
def get_embeddings_model(config):
    """Build a HuggingFaceEmbeddings model from the [retriever] config section.

    Args:
        config: A ConfigParser holding a ``retriever`` section with
            ``MODEL`` (model name) and ``NORMALIZE`` ("0"/"1") keys.

    Returns:
        HuggingFaceEmbeddings: The configured embeddings model.
    """
    # Prefer the GPU when one is visible to torch; otherwise run on CPU.
    device = "cuda" if cuda.is_available() else "cpu"
    # NORMALIZE is stored as a "0"/"1" string, hence the int() -> bool() dance.
    should_normalize = bool(int(config.get("retriever", "NORMALIZE")))
    return HuggingFaceEmbeddings(
        show_progress=True,
        model_name=config.get("retriever", "MODEL"),
        model_kwargs={"device": device},
        encode_kwargs={
            "normalize_embeddings": should_normalize,
            "batch_size": 100,
        },
    )
# Create a search filter for Qdrant
def create_filter(
    reports: list = None, sources: str = None, subtype: str = None, year: str = None
):
    """Create a Qdrant payload filter for document retrieval.

    When ``reports`` is non-empty, filter on exact filenames only; otherwise
    filter on the document source and subtype.

    Args:
        reports: List of report filenames to match. Defaults to no reports.
            (Previously a mutable default ``[]``, which is shared across calls.)
        sources: Value matched against ``metadata.source``.
        subtype: Value(s) matched against ``metadata.filename``.
            NOTE(review): passed to MatchAny(any=...), which expects an
            iterable — a plain str would be treated as a sequence of
            characters; confirm callers pass a list here.
        year: Accepted but currently unused — the year condition below is
            intentionally disabled.

    Returns:
        rest.Filter: The assembled Qdrant filter.
    """
    if reports is None:
        reports = []
    if len(reports) == 0:
        print(f"defining filter for sources:{sources}, subtype:{subtype}")
        # `search_filter` rather than `filter` to avoid shadowing the builtin.
        search_filter = rest.Filter(
            must=[
                rest.FieldCondition(
                    key="metadata.source", match=rest.MatchValue(value=sources)
                ),
                rest.FieldCondition(
                    key="metadata.filename", match=rest.MatchAny(any=subtype)
                ),
                # Year filtering is deliberately disabled for now:
                # rest.FieldCondition(
                #     key="metadata.year",
                #     match=rest.MatchAny(any=year)
            ]
        )
    else:
        print(f"defining filter for allreports:{reports}")
        search_filter = rest.Filter(
            must=[
                rest.FieldCondition(
                    key="metadata.filename", match=rest.MatchAny(any=reports)
                )
            ]
        )
    return search_filter
def load_json(fp):
    """Deserialize the JSON file at ``fp`` and return the resulting object.

    Args:
        fp: Path to a JSON file.

    Returns:
        The deserialized Python object (typically a dict or list of docs).
    """
    with open(fp) as handle:
        return json.load(handle)
def get_timestamp():
    """Return the current local time as a compact ``YYYYMMDDHHMMSS`` string.

    Returns:
        str: 14-digit timestamp, e.g. ``"20240131235959"``.
    """
    # Bug fix: the file does `from datetime import datetime`, so the class
    # itself is in scope — the original `datetime.datetime.now()` raised
    # AttributeError at call time.
    now = datetime.now()
    return now.strftime("%Y%m%d%H%M%S")
# A custom class to help with recursive serialization.
# This approach avoids modifying the original object.
class _RecursiveSerializer(json.JSONEncoder):
"""A custom JSONEncoder that handles complex types by converting them to dicts or strings."""
def default(self, obj):
# Prefer the pydantic method if it exists for the most robust serialization.
if hasattr(obj, 'model_dump'):
return obj.model_dump()
# Handle dataclasses
if dataclasses.is_dataclass(obj):
return dataclasses.asdict(obj)
# Handle other non-serializable but common types.
if isinstance(obj, (datetime, date, UUID)):
return str(obj)
# Fallback for general objects with a __dict__
if hasattr(obj, '__dict__'):
return obj.__dict__
# Default fallback to JSONEncoder's behavior
return super().default(obj)
def to_json_string(obj: Any, **kwargs) -> str:
    """Serialize any Python object into a JSON-formatted string.

    Handles, via :class:`_RecursiveSerializer`:
        - Standard JSON-native types (dict, list, str, int, float, bool, None).
        - Pydantic models (serialized with ``model_dump()``).
        - Dataclasses (serialized with ``dataclasses.asdict()``).
        - ``datetime``, ``date``, and ``UUID`` values (stringified).
        - Other objects exposing a ``__dict__``.

    Args:
        obj (Any): The object to serialize.
        **kwargs: Extra keyword arguments forwarded to ``json.dumps``
            (e.g. ``indent=2``, ``sort_keys=True``).

    Returns:
        str: The JSON-formatted string.

    Example:
        >>> from dataclasses import dataclass
        >>> @dataclass
        ... class Product:
        ...     id: int
        ...     name: str
        >>> to_json_string(Product(id=1, name="Laptop"))
        '{"id": 1, "name": "Laptop"}'
    """
    return json.dumps(obj, cls=_RecursiveSerializer, **kwargs)