Source code for ads.database.connection

#!/usr/bin/env python
# -*- coding: utf-8; -*-

# Copyright (c) 2020, 2022 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

import ast
import json
import oci
import os
import pathlib
import re
import shutil

from oci.exceptions import ServiceError
from tempfile import NamedTemporaryFile
from zipfile import ZipFile

from oci.secrets import SecretsClient
from oci.config import from_file

from ads.common.decorator.runtime_dependency import (
    runtime_dependency,
    OptionalDependency,
)
from ads.common import utils
from ads.common import auth as authutil
from ads.common import oci_client as oc
from ads.vault.vault import Vault


class Connector:
    """Database connection built from merged credential sources.

    Credentials may come from keyword arguments, an OCI Vault secret and/or a
    local repository entry. They are merged with keyword arguments taking
    precedence over the Vault secret, which takes precedence over the local
    repository, and the result is used to build a SQLAlchemy engine through
    ``OracleConnector``.
    """

    def __init__(
        self,
        secret_id: str = None,
        key: str = None,
        repository_path: str = None,
        **kwargs,
    ):
        """
        Validate that a connection could be made for the given set of connection
        parameters, and construct a Connector object provided that the validation
        is successful.

        Parameters
        ----------
        secret_id: str, optional
            The ocid of the secret to retrieve from Oracle Cloud Infrastructure Vault.
        key: str, optional
            The key to find the database directory.
        repository_path: str, optional
            The local database information store, default to ~/.database unless
            specified otherwise. Only consulted (and required to exist) when
            ``key`` is given.
        kwargs: dict, optional
            Name-value pairs that are to be added to the list of connection parameters.
            For example, database_name="mydb", database_type="oracle", username = "root",
            password = "example-password". These take precedence over values from the
            Vault secret and the local repository.

        Raises
        ------
        ValueError
            If ``secret_id`` is malformed, ``key`` is not a valid directory name,
            a required local path does not exist, or ``database_type`` is missing
            or unsupported.

        Returns
        -------
        A Connector object.
        """
        # Lower number = higher priority. Sources are merged from the lowest
        # priority to the highest so higher-priority values overwrite the rest.
        prio_dict = {}

        # Priority 1: explicit keyword arguments.
        if kwargs:
            prio_dict[1] = dict(kwargs)

        # Priority 2: credentials stored in an OCI Vault secret.
        if secret_id:
            if re.match(r"^ocid[0-9]?\.vaultsecret.*", secret_id) is None:
                raise ValueError(f"{secret_id} is not a valid secret id.")
            auth = authutil.default_signer()
            self.secret_client = oc.OCIClientFactory(**auth).secret
            secret_bundle = self.secret_client.get_secret_bundle(secret_id)
            # literal_eval only evaluates Python literals, not arbitrary expressions.
            prio_dict[2] = ast.literal_eval(
                Vault._secret_to_dict(secret_bundle.data.secret_bundle_content.content)
            )

        # Priority 3: credentials from the local repository.
        if key:
            # The repository only matters when reading an entry from it; checking
            # it unconditionally broke kwargs/Vault-only connections whenever the
            # default ~/.database directory did not exist.
            repository_path = _get_repository_path(repository_path=repository_path)
            if not os.path.exists(repository_path):
                raise ValueError(f"{repository_path} does not exist.")
            if _not_valid_key(key=key):
                raise ValueError(f"{key} is not a valid directory name.")
            db_path = _get_db_path(repository_path=repository_path, key=key)
            if not os.path.exists(db_path):
                raise ValueError(f"{db_path} does not exist.")
            prio_dict[3] = get_repository(key=key, repository_path=repository_path)

        config = {}
        for priority in sorted(prio_dict, reverse=True):
            config.update(prio_dict[priority])
        self.config = config

        # check database types
        valid_database_types = ["oracle"]
        if "database_type" not in self.config:
            raise ValueError(
                f"The database_type needs to be specified. "
                f"Valid database types are {valid_database_types}"
            )
        if self.config["database_type"] not in valid_database_types:
            raise ValueError(
                f"{self.config['database_type']} is not a valid database type. "
                f"Valid database types are {valid_database_types}"
            )
        if self.config["database_type"] == "oracle":
            connector = OracleConnector(self.config)
            self.uri = connector.uri
            self.engine = connector.engine

    def connect(self):
        """Open and return a new connection from the underlying engine."""
        return self.engine.connect()

    def __enter__(self):
        """Open a connection (stored as ``db_connection``) and return ``self``."""
        self.db_connection = self.engine.connect()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Close the connection opened by ``__enter__``.

        Returns False so any exception raised inside the ``with`` block
        propagates; a truthy return value would silently suppress it.
        """
        db_connection = getattr(self, "db_connection", None)
        if db_connection is not None:
            db_connection.close()
        return False

    def __repr__(self):  # pragma: no cover
        return str(self.engine.url)
class OracleConnector:
    """Builds a SQLAlchemy engine for an Oracle database from a config mapping."""

    @runtime_dependency(module="sqlalchemy", install_from=OptionalDependency.DATA)
    def __init__(self, oracle_connection_config):
        """Validate the required credentials and create the engine.

        Parameters
        ----------
        oracle_connection_config: dict
            Connection settings; must contain ``username``, ``password`` and
            ``database_name``.

        Raises
        ------
        ValueError
            For the first required key (checked in the order above) that is absent.
        """
        self.config = oracle_connection_config

        # Fail fast before attempting any connection work.
        for required in ("username", "password", "database_name"):
            if required not in self.config:
                raise ValueError(f"{required} is a required parameter to connect.")

        self.uri = _create_connection_str(config=self.config)
        self.engine = sqlalchemy.create_engine(self.uri)
def update_repository(
    value: dict, key: str, replace: bool = True, repository_path: str = None
) -> dict:
    """
    Saves value into local database store.

    Parameters
    ----------
    value: dict
        The values to store locally.
    key: str
        The key to find the local database directory.
    replace: bool, default to True
        If True, overwrite the stored values with ``value``. If False, merge
        ``value`` with the values already stored for ``key`` (via ``_update``);
        when nothing is stored yet, ``value`` is saved as-is.
    repository_path: str, optional
        The local database store, default to ~/.database unless specified otherwise.

    Raises
    ------
    ValueError
        If ``key`` is not a valid directory name.

    Returns
    -------
    A dictionary of all values in the repository for the given key.
    """
    if _not_valid_key(key=key):
        raise ValueError(f"{key} is not a valid directory name.")

    # make local database directory
    repository_path = _get_repository_path(repository_path=repository_path)
    pathlib.Path(repository_path).mkdir(parents=True, exist_ok=True)
    db_path = _get_db_path(repository_path=repository_path, key=key)
    pathlib.Path(db_path).mkdir(parents=True, exist_ok=True)

    db_config_path = os.path.join(db_path, "config.json")
    # Merging needs an existing file: a brand-new key has nothing to merge with,
    # so store the new value directly instead of raising FileNotFoundError.
    if not replace and os.path.exists(db_config_path):
        value = _update(value, db_config_path)
    with open(db_config_path, "w") as fp:
        json.dump(value, fp)
    return value
def _update(new_value, db_config_path): # update existing key's values if found different and add new key-value pairs with open(db_config_path) as f: old_value = json.load(f) prio_dict = {1: old_value, 2: new_value} return {**prio_dict[2], **prio_dict[1]}
def get_repository(key: str, repository_path: str = None) -> dict:
    """
    Load every value stored for ``key`` in the local database store.

    Parameters
    ----------
    key: str
        Name of the database directory inside the store.
    repository_path: str, optional
        Location of the local database store; ~/.database when omitted.

    Raises
    ------
    ValueError
        If ``key`` is not a valid directory name, or the store / key
        directory does not exist.

    Returns
    -------
    The dictionary parsed from ``<repository_path>/<key>/config.json``.
    """
    if _not_valid_key(key=key):
        raise ValueError(f"{key} is not a valid directory name.")

    repository_path = _get_repository_path(repository_path=repository_path)
    db_path = _get_db_path(repository_path=repository_path, key=key)

    # The store is checked before the key directory so the error names the
    # outermost path that is missing.
    for required_path in (repository_path, db_path):
        if not os.path.exists(required_path):
            raise ValueError(f"{required_path} does not exist.")

    with open(os.path.join(db_path, "config.json")) as config_file:
        return json.load(config_file)
def import_wallet(wallet_path: str, key: str, repository_path: str = None) -> None:
    """
    Saves wallet to local database store.

    Unzip the wallet zip file into the key's directory, point sqlnet.ora's
    wallet DIRECTORY at that directory, set the process-wide ``TNS_ADMIN``
    environment variable to it, and record ``database_type = "oracle"`` in the
    key's config.json (creating the file if needed).

    Parameters
    ----------
    wallet_path: str
        The local path to the downloaded wallet zip file.
    key: str
        The key to find the database directory.
    repository_path: str, optional
        The local database store, default to ~/.database unless specified otherwise.

    Raises
    ------
    ValueError
        If ``key`` is not a valid directory name, or the key directory or the
        wallet file does not exist.
    """
    if _not_valid_key(key=key):
        raise ValueError(f"{key} is not a valid directory name.")

    # checking paths are valid
    repository_path = _get_repository_path(repository_path=repository_path)
    db_path = _get_db_path(repository_path=repository_path, key=key)
    if not os.path.exists(db_path):
        raise ValueError(f"{db_path} does not exist.")
    if not os.path.exists(wallet_path):
        raise ValueError(f"{wallet_path} does not exist.")

    # Extract straight into db_path. The previous home-relative join only
    # coincided with db_path for absolute repository paths, so relative ones
    # extracted somewhere _update_sqlnet never looked.
    with ZipFile(wallet_path, "r") as wallet_zip:
        wallet_zip.extractall(db_path)

    # Add TNS_ADMIN to the environment
    os.environ["TNS_ADMIN"] = db_path

    # Update the sqlnet.ora
    _update_sqlnet(db_path=db_path)

    # Force database_type to oracle while keeping any other stored settings.
    db_config_path = os.path.join(db_path, "config.json")
    config = {}
    if os.path.exists(db_config_path):
        with open(db_config_path) as f:
            config = json.load(f)
    config["database_type"] = "oracle"
    with open(db_config_path, "w") as fp:
        json.dump(config, fp)
def _update_sqlnet(db_path): sqlnet_path = os.path.join(db_path, "sqlnet.ora") sqlnet_original_path = os.path.join(db_path, "sqlnet.ora.original") sqlnet_backup_path = os.path.join(db_path, "sqlnet.ora.backup") if not os.path.exists(sqlnet_original_path): shutil.copy(sqlnet_path, sqlnet_original_path) if os.path.exists(sqlnet_path): shutil.copy(sqlnet_path, sqlnet_backup_path) sqlnet_re = re.compile( '(WALLET_LOCATION\s*=.*METHOD_DATA\s*=.*DIRECTORY\s*=\s*")(.*)(".*)', re.IGNORECASE, ) tmp = NamedTemporaryFile() with open(sqlnet_path, "rt") as sqlnet: for line in sqlnet: tmp.write( bytearray( sqlnet_re.subn(r"\1{}\3".format(db_path), line)[0], encoding="utf-8" ) ) tmp.flush() shutil.copy(tmp.name, sqlnet_path) tmp.close() def _get_repository_path(repository_path): return ( os.path.join(os.path.expanduser("~"), ".database") if repository_path is None else repository_path ) def _get_db_path(repository_path, key): return os.path.join(repository_path, key) def _not_valid_key(key): return bool(re.search("[../.]", key)) def _create_connection_str(config): return ( "oracle+cx_oracle://" + config["username"] + ":" + config["password"] + "@" + config["database_name"] )