Source code for ads.data_labeling.reader.dls_record_reader

#!/usr/bin/env python
# -*- coding: utf-8; -*-

# Copyright (c) 2021, 2022 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Generator, List

from ads.common import auth as authutil
from ads.common import oci_client
from ads.data_labeling.interface.reader import Reader
from oci.data_labeling_service_dataplane.models import AnnotationSummary, RecordSummary
from oci.exceptions import ServiceError
from oci.pagination import list_call_get_all_results


class ReadRecordsError(Exception):  # pragma: no cover
    """Error signalling that the records of a dataset could not be read.

    Parameters
    ----------
    dataset_id: str
        The OCID of the dataset whose records could not be read; it is
        embedded in the exception message.
    """

    def __init__(self, dataset_id: str):
        message = (
            f"Error occurred in attempt to read records of dataset {dataset_id}."
        )
        super().__init__(message)
class ReadAnnotationsError(Exception):  # pragma: no cover
    """Error signalling that the annotations of a dataset could not be read.

    Parameters
    ----------
    dataset_id: str
        The OCID of the dataset whose annotations could not be read; it is
        embedded in the exception message.
    """

    def __init__(self, dataset_id: str):
        message = (
            f"Error occurred in attempt to read annotations of dataset {dataset_id}."
        )
        super().__init__(message)
@dataclass
class OCIRecordSummary:
    """The class that represents a labeled record in ADS format.

    Attributes
    ----------
    record: RecordSummary
        OCI RecordSummary. Defaults to None.
    annotation: List[AnnotationSummary]
        List of OCI AnnotationSummary objects attached to the record.
        Defaults to None; `DLSRecordReader.read` also leaves it as None for
        records that have no annotations.
    """

    # Field name stays singular (`annotation`) for backward compatibility,
    # even though it holds a list.
    record: RecordSummary = None
    annotation: List[AnnotationSummary] = None
class DLSRecordReader(Reader):
    """DLS Record Reader Class that reads records from the cloud into ADS format."""

    def __init__(self, dataset_id: str, compartment_id: str, auth: dict = None):
        """Initiates a DLSRecordReader instance.

        Parameters
        ----------
        dataset_id: str
            The dataset OCID.
        compartment_id: str
            The compartment OCID of the dataset.
        auth: (dict, optional). Defaults to None.
            The default authentication is set using `ads.set_auth` API.
            If you need to override the default, use the
            `ads.common.auth.api_keys` or `ads.common.auth.resource_principal`
            to create appropriate authentication signer and kwargs required
            to instantiate IdentityClient object.

        Raises
        ------
        ValueError
            If `dataset_id` or `compartment_id` is empty.
        TypeError
            If `dataset_id` or `compartment_id` is not a string.
        """
        # Emptiness is checked before type so that None / "" report a
        # "must be specified" error rather than a type error.
        if not dataset_id:
            raise ValueError("The dataset OCID must be specified.")
        if not isinstance(dataset_id, str):
            raise TypeError("The dataset_id must be a string.")
        if not compartment_id:
            raise ValueError("The compartment OCID must be specified.")
        if not isinstance(compartment_id, str):
            raise TypeError("The compartment OCID must be a string.")

        auth = auth or authutil.default_signer()
        self.dataset_id = dataset_id
        self.compartment_id = compartment_id
        # Data Labeling Service data-plane client built from the auth kwargs.
        self.dls_dp_client = oci_client.OCIClientFactory(**auth).data_labeling_dp

    def _read_records(self):
        """Fetches all ACTIVE records of the dataset across all result pages.

        Returns
        -------
        The `.data` of the aggregated `list_records` response (iterated by
        `read` as record summaries exposing an `id`).

        Raises
        ------
        ReadRecordsError
            If the service call fails with `oci.exceptions.ServiceError`;
            the original error is chained as the cause.
        """
        try:
            return list_call_get_all_results(
                self.dls_dp_client.list_records,
                self.compartment_id,
                self.dataset_id,
                lifecycle_state="ACTIVE",
            ).data
        except ServiceError as err:
            raise ReadRecordsError(self.dataset_id) from err

    def _read_annotations(self):
        """Fetches all ACTIVE annotations of the dataset across all result pages.

        Returns
        -------
        The `.data` of the aggregated `list_annotations` response (iterated
        by `read` as annotation summaries exposing a `record_id`).

        Raises
        ------
        ReadAnnotationsError
            If the service call fails with `oci.exceptions.ServiceError`;
            the original error is chained as the cause.
        """
        try:
            return list_call_get_all_results(
                self.dls_dp_client.list_annotations,
                self.compartment_id,
                self.dataset_id,
                lifecycle_state="ACTIVE",
            ).data
        except ServiceError as err:
            raise ReadAnnotationsError(self.dataset_id) from err

    def read(self) -> Generator[OCIRecordSummary, Any, Any]:
        """Reads OCI records.

        All records and all annotations are fetched up front, annotations
        are grouped by their `record_id`, and each record is then paired
        with its annotations. Because this is a generator, the service calls
        happen on the first iteration, not when `read()` is called.

        Yields
        ------
        OCIRecordSummary
            The OCIRecordSummary instance. Its `annotation` is None when the
            record has no annotations.

        Raises
        ------
        ReadRecordsError
            If the records cannot be listed.
        ReadAnnotationsError
            If the annotations cannot be listed.
        """
        records = self._read_records()
        annotations = self._read_annotations()

        # One pass over the annotations to index them by record id, so the
        # pairing below is a dict lookup per record.
        annotations_map = defaultdict(list)
        for annotation in annotations:
            annotations_map[annotation.record_id].append(annotation)

        for record in records:
            # `.get` (not `[...]`) yields None rather than [] for records
            # without annotations and does not grow the defaultdict.
            yield OCIRecordSummary(record, annotations_map.get(record.id))