# geonmin-kim's picture
# Upload folder using huggingface_hub
# d6585f5
#
# Pyserini: Reproducible IR research with sparse and dense representations
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import re
from enum import Enum
from ..multithreading import Counters
from ..pyclass import autoclass, cast, JPaths
# Module-level logger used by the segment wrapper below.
logger = logging.getLogger(name=__name__)

# Handles to Anserini's base collection classes, loaded through ``autoclass``.
JFileSegment = autoclass('io.anserini.collection.FileSegment')
JSourceDocument = autoclass('io.anserini.collection.SourceDocument')
class JCollections(Enum):
    """Anserini collection classes selectable by name.

    Each member's value is the Java class returned by ``autoclass``. A member
    is looked up by name (``JCollections[name]``), and its value is called with
    a collection path to construct the collection.
    """

    AclAnthology = autoclass('io.anserini.collection.AclAnthology')
    CarCollection = autoclass('io.anserini.collection.CarCollection')
    Cord19AbstractCollection = autoclass('io.anserini.collection.Cord19AbstractCollection')
    ClueWeb09Collection = autoclass('io.anserini.collection.ClueWeb09Collection')
    ClueWeb12Collection = autoclass('io.anserini.collection.ClueWeb12Collection')
    HtmlCollection = autoclass('io.anserini.collection.HtmlCollection')
    JsonCollection = autoclass('io.anserini.collection.JsonCollection')
    NewYorkTimesCollection = autoclass('io.anserini.collection.NewYorkTimesCollection')
    TrecCollection = autoclass('io.anserini.collection.TrecCollection')
    TrecwebCollection = autoclass('io.anserini.collection.TrecwebCollection')
    TweetCollection = autoclass('io.anserini.collection.TweetCollection')
    WashingtonPostCollection = autoclass('io.anserini.collection.WashingtonPostCollection')
    WikipediaCollection = autoclass('io.anserini.collection.WikipediaCollection')
class Collection:
    """
    Iterable wrapper class for Anserini's DocumentCollection.

    Iterating over a ``Collection`` yields one :class:`FileSegment` per
    segment produced by the underlying Anserini collection iterator.

    Parameters
    ----------
    collection_class : str
        Name of collection class to instantiate; must be a member name of
        ``JCollections`` (e.g. ``'JsonCollection'``).
    collection_path : str
        Path to directory containing collection

    Raises
    ------
    ValueError
        If ``collection_class`` is not a known collection name, or the
        collection cannot be constructed for ``collection_path``.
    """

    def __init__(self, collection_class, collection_path):
        self.counters = Counters()
        self.collection_class = collection_class
        self.collection_path = JPaths.get(collection_path)
        self.object = self._get_collection()
        self.collection_iterator = self.object.iterator()

    def _get_collection(self):
        """Instantiate the collection class named by ``self.collection_class``.

        Any failure (an unknown enum name or an error raised while calling the
        class with the collection path) is reported as
        ``ValueError(collection_class)``. The ValueError is chained to the
        original exception so the root cause stays visible in tracebacks.
        """
        try:
            return JCollections[self.collection_class].value(self.collection_path)
        except Exception as err:
            # Catch Exception, not everything: KeyboardInterrupt / SystemExit
            # must propagate unchanged rather than be reported as a bad name.
            raise ValueError(self.collection_class) from err

    def __iter__(self):
        return self

    def __next__(self):
        """Return a :class:`FileSegment` for the next segment, or stop."""
        if not self.collection_iterator.hasNext():
            raise StopIteration
        segment = self.collection_iterator.next()
        return FileSegment(self, segment, segment.getSegmentPath())
class FileSegment:
    """
    Iterable wrapper class for Anserini's FileSegment.

    Iterating yields :class:`SourceDocument` wrappers. When the segment runs
    out, the wrapper does three things exactly once:

    - records an error on the parent collection's counters if the segment
      reports an error status;
    - records any skipped documents on those counters;
    - closes the underlying segment.

    Every later ``next()`` call raises ``StopIteration`` without repeating
    that work.

    Parameters
    ----------
    collection : Collection
        Parent collection of the file segment
    segment : JFileSegment
        FileSegment object to create wrapper from
    segment_path : str
        Path to file backing the file segment. When created by ``Collection``,
        this is the value of the segment's ``getSegmentPath()``, which is
        passed to ``relativize`` on the collection path.
    """

    def __init__(self, collection, segment, segment_path):
        self.collection = collection
        try:
            # Cast to the collection's nested ``$Segment`` class. If that
            # fails, fall back to the generic FileSegment type.
            self.object = cast(collection.object.getClass().getName() + '$Segment', segment)
        except Exception:
            logger.exception('Exception from casting FileSegment type...')
            self.object = cast('io.anserini.collection.FileSegment', segment)
        self.segment_iterator = self.object.iterator()
        self.segment_path = segment_path
        # Segment path relative to the collection root, with '/' and '\'
        # replaced by '-' to form a flat name for log messages.
        self.segment_name = re.sub(
            r'\\|\/', '-',
            collection.collection_path.relativize(segment_path).toString())
        # Set once end-of-segment bookkeeping has run, so it never repeats.
        self._exhausted = False

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next :class:`SourceDocument` or raise ``StopIteration``."""
        if self._exhausted:
            raise StopIteration
        # Reuse the iterator created in __init__ instead of asking the segment
        # for a new one twice per document.
        if self.segment_iterator.hasNext():
            return SourceDocument(self, self.segment_iterator.next())

        self._exhausted = True
        # log if iteration stopped by error
        if self.object.getErrorStatus():
            logger.error('%s: Error from segment iteration, stopping...', self.segment_name)
            self.collection.counters.errors.increment()
        # stop iteration and log skipped documents
        skipped = self.object.getSkippedCount()
        if skipped > 0:
            self.collection.counters.skips.increment(skipped)
            logger.warning('%s: %s documents skipped', self.segment_name, skipped)
        self.object.close()
        raise StopIteration
class SourceDocument:
    """
    Wrapper class for Anserini's SourceDocument.

    The document's id, indexable value, contents and raw text are read from
    the wrapped object once, at construction, and stored as plain attributes.

    Parameters
    ----------
    segment : FileSegment
        Parent segment of the source document
    document : io.anserini.collection.SourceDocument
        SourceDocument object to create wrapper from

    Raises
    ------
    TypeError
        If ``document`` is not an instance of ``JSourceDocument``.
    """

    def __init__(self, segment, document):
        is_source_document = isinstance(document, JSourceDocument)
        if not is_source_document:
            raise TypeError('Invalid JSourceDocument!')
        self.segment, self.object = segment, document

        wrapped = self.object
        # Accessors are read eagerly, in this fixed order.
        self.id = wrapped.id()
        self.indexable = wrapped.indexable()
        self.contents = wrapped.contents()
        self.raw = wrapped.raw()