Source code for ads.data_labeling.reader.dls_record_reader
#!/usr/bin/env python# -*- coding: utf-8; -*-# Copyright (c) 2021, 2022 Oracle and/or its affiliates.# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/fromcollectionsimportdefaultdictfromdataclassesimportdataclassfromtypingimportAny,Generator,Listfromads.commonimportauthasauthutilfromads.commonimportoci_clientfromads.data_labeling.interface.readerimportReaderfromoci.data_labeling_service_dataplane.modelsimportAnnotationSummary,RecordSummaryfromoci.exceptionsimportServiceErrorfromoci.paginationimportlist_call_get_all_results
class ReadRecordsError(Exception):  # pragma: no cover
    """Raised when the records of a dataset cannot be read.

    Parameters
    ----------
    dataset_id: str
        The OCID of the dataset whose records failed to load; it is
        embedded in the exception message.
    """

    def __init__(self, dataset_id: str):
        message = (
            f"Error occurred in attempt to read records of dataset {dataset_id}."
        )
        super().__init__(message)
class ReadAnnotationsError(Exception):  # pragma: no cover
    """Raised when the annotations of a dataset cannot be read.

    Parameters
    ----------
    dataset_id: str
        The OCID of the dataset whose annotations failed to load; it is
        embedded in the exception message.
    """

    def __init__(self, dataset_id: str):
        message = (
            f"Error occurred in attempt to read annotations of dataset {dataset_id}."
        )
        super().__init__(message)
@dataclass
class OCIRecordSummary:
    """The class representing a labeled record in ADS format.

    Attributes
    ----------
    record: RecordSummary
        OCI RecordSummary. Defaults to None.
    annotation: List[AnnotationSummary]
        List of OCI AnnotationSummary objects associated with the record.
        Defaults to None.
    """

    # NOTE: the field is named ``annotation`` (singular) even though it holds
    # a list; the name is part of the public interface and is kept as-is.
    record: RecordSummary = None
    annotation: List[AnnotationSummary] = None
class DLSRecordReader(Reader):
    """DLS Record Reader Class that reads records from the cloud into ADS format."""

    def __init__(self, dataset_id: str, compartment_id: str, auth: dict = None):
        """Initiates a DLSRecordReader instance.

        Parameters
        ----------
        dataset_id: str
            The dataset OCID.
        compartment_id: str
            The compartment OCID of the dataset.
        auth: (dict, optional). Defaults to None.
            The default authentication is set using `ads.set_auth` API.
            If you need to override the default, use the
            `ads.common.auth.api_keys` or `ads.common.auth.resource_principal`
            to create appropriate authentication signer and kwargs required
            to instantiate the OCI client.

        Raises
        ------
        ValueError
            If `dataset_id` or `compartment_id` is empty.
        TypeError
            If `dataset_id` or `compartment_id` is not a string.
        """
        # Emptiness is checked before the type, so a falsy non-string value
        # (e.g. None) is reported as ValueError rather than TypeError.
        if not dataset_id:
            raise ValueError("The dataset OCID must be specified.")
        if not isinstance(dataset_id, str):
            raise TypeError("The dataset_id must be a string.")
        if not compartment_id:
            raise ValueError("The compartment OCID must be specified.")
        if not isinstance(compartment_id, str):
            raise TypeError("The compartment OCID must be a string.")

        auth = auth or authutil.default_signer()

        self.dataset_id = dataset_id
        self.compartment_id = compartment_id
        self.dls_dp_client = oci_client.OCIClientFactory(**auth).data_labeling_dp

    def _read_records(self):
        """Lists the ACTIVE records of the dataset.

        Returns
        -------
        The `data` attribute of the `list_call_get_all_results` response for
        `list_records`.

        Raises
        ------
        ReadRecordsError
            If the service call fails with a `ServiceError`.
        """
        try:
            return list_call_get_all_results(
                self.dls_dp_client.list_records,
                self.compartment_id,
                self.dataset_id,
                lifecycle_state="ACTIVE",
            ).data
        except ServiceError as err:
            # Chain explicitly so the underlying service failure is kept as
            # the cause of the domain-level error.
            raise ReadRecordsError(self.dataset_id) from err

    def _read_annotations(self):
        """Lists the ACTIVE annotations of the dataset.

        Returns
        -------
        The `data` attribute of the `list_call_get_all_results` response for
        `list_annotations`.

        Raises
        ------
        ReadAnnotationsError
            If the service call fails with a `ServiceError`.
        """
        try:
            return list_call_get_all_results(
                self.dls_dp_client.list_annotations,
                self.compartment_id,
                self.dataset_id,
                lifecycle_state="ACTIVE",
            ).data
        except ServiceError as err:
            # Chain explicitly so the underlying service failure is kept as
            # the cause of the domain-level error.
            raise ReadAnnotationsError(self.dataset_id) from err

    def read(self) -> Generator[OCIRecordSummary, Any, Any]:
        """Reads OCI records.

        Records and annotations are both fetched in full when iteration
        starts; annotations are then grouped by `record_id` and paired with
        their record.

        Yields
        ------
        OCIRecordSummary
            The OCIRecordSummary instance.

        Raises
        ------
        ReadRecordsError
            If the records cannot be read.
        ReadAnnotationsError
            If the annotations cannot be read.
        """
        records = self._read_records()
        annotations = self._read_annotations()

        # Group annotations by the record they belong to.
        annotations_map = defaultdict(list)
        for annotation in annotations:
            annotations_map[annotation.record_id].append(annotation)

        for record in records:
            # `.get()` does not trigger the defaultdict factory, so a record
            # without annotations is yielded with `annotation=None`.
            yield OCIRecordSummary(record, annotations_map.get(record.id))