Spaces:
Runtime error
Runtime error
| from typing import Optional, Tuple | |
| import sqlalchemy | |
| from pgvector.sqlalchemy import Vector | |
| from sqlalchemy.dialects.postgresql import JSON, UUID | |
| from sqlalchemy.orm import Session, relationship | |
| from langchain.vectorstores.pgvector import BaseModel | |
class CollectionStore(BaseModel):
    """Collection store.

    ORM model for the ``langchain_pg_collection`` table. A collection has a
    name, optional JSON metadata, and a set of related ``EmbeddingStore`` rows.
    """

    __tablename__ = "langchain_pg_collection"

    name = sqlalchemy.Column(sqlalchemy.String)
    cmetadata = sqlalchemy.Column(JSON)

    # passive_deletes=True leaves removal of child rows to the database; the
    # foreign key on EmbeddingStore.collection_id is declared with
    # ondelete="CASCADE".
    embeddings = relationship(
        "EmbeddingStore",
        back_populates="collection",
        passive_deletes=True,
    )

    # Both lookups below are alternate constructors that receive the class as
    # ``cls``; they must be classmethods so that
    # ``CollectionStore.get_by_name(session, name)`` binds arguments correctly.
    @classmethod
    def get_by_name(cls, session: Session, name: str) -> Optional["CollectionStore"]:
        """Return the first collection whose name equals *name*.

        Args:
            session: Open SQLAlchemy session used to run the query.
            name: Collection name to look up.

        Returns:
            The first matching collection, or ``None`` when there is no match.
        """
        return session.query(cls).filter(cls.name == name).first()  # type: ignore

    @classmethod
    def get_or_create(
        cls,
        session: Session,
        name: str,
        cmetadata: Optional[dict] = None,
    ) -> Tuple["CollectionStore", bool]:
        """
        Get or create a collection.

        Args:
            session: Open SQLAlchemy session; it is committed when a new
                collection is inserted.
            name: Name of the collection to fetch or create.
            cmetadata: Metadata stored on a newly created collection. Ignored
                when a collection with this name already exists.

        Returns [Collection, bool] where the bool is True if the collection was created.
        """
        existing = cls.get_by_name(session, name)
        if existing:
            return existing, False

        collection = cls(name=name, cmetadata=cmetadata)
        session.add(collection)
        session.commit()
        return collection, True
class EmbeddingStore(BaseModel):
    """Embedding store.

    ORM model for the ``langchain_pg_embedding`` table. Each row holds one
    embedding vector together with its optional source document, JSON
    metadata, and a user-supplied identifier, and belongs to a
    ``CollectionStore``.
    """

    __tablename__ = "langchain_pg_embedding"

    # Reference to the owning collection's ``uuid`` column (presumably
    # declared on BaseModel -- it is not defined in this file). The foreign
    # key is declared with ON DELETE CASCADE.
    collection_id = sqlalchemy.Column(
        UUID(as_uuid=True),
        sqlalchemy.ForeignKey(
            CollectionStore.__tablename__ + ".uuid", ondelete="CASCADE"
        ),
    )
    collection = relationship(CollectionStore, back_populates="embeddings")

    # Vector(None): no dimension is specified for the vector column.
    embedding: Vector = sqlalchemy.Column(Vector(None))
    document = sqlalchemy.Column(sqlalchemy.String, nullable=True)
    cmetadata = sqlalchemy.Column(JSON, nullable=True)

    # custom_id : any user defined id
    custom_id = sqlalchemy.Column(sqlalchemy.String, nullable=True)