Source code for needlestack.cluster_managers.zookeeper

import logging
import signal
from copy import deepcopy
from typing import List, Optional

import kazoo
from kazoo.client import KazooClient, KazooState
from kazoo.recipe.cache import TreeCache
from kazoo.retry import KazooRetry

from needlestack.apis import collections_pb2
from needlestack.cluster_managers import ClusterManager


logger = logging.getLogger("needlestack")


[docs]class ZookeeperClusterManager(ClusterManager): """ A cluster manager that manages one cluster's state and configurations with a Zookeeper ensemble via kazoo. Below is the structure of the znodes: /needlestack /<CLUSTER_NAME_1> /live_nodes /<HOSTPORT_1> /<HOSTPORT_2> /<HOSTPORT_3> /<HOSTPORT_4> ... /collections /<COLLECTION_NAME_1> /shards /<SHARD_NAME_1> /replicas /<HOSTPORT_2> /<HOSTPORT_4> /<SHARD_NAME_2> /replicas /<HOSTPORT_1> /<HOSTPORT_3> /<COLLECTION_NAME_2> ... """ cluster_name: str hostport: str zk: KazooClient cache: TreeCache def __init__( self, cluster_name: str, hostport: str, hosts: List[str], zookeeper_root: str ): self.cluster_name = cluster_name self.hostport = hostport self.zookeeper_root = zookeeper_root self.zk = KazooClient(hosts=hosts) self.zk.add_listener(self.zk_listener) self.cache = TreeCache(self.zk, self.base_znode) @property def base_znode(self): return f"{self.zookeeper_root}/{self.cluster_name}" @property def live_nodes_znode(self): return f"{self.base_znode}/live_nodes" @property def this_node_znode(self): return f"{self.base_znode}/live_nodes/{self.hostport}" @property def collections_znode(self): return f"{self.base_znode}/collections"
[docs] def collection_znode(self, collection_name: str) -> str: return f"{self.collections_znode}/{collection_name}"
[docs] def shard_znode(self, collection_name: str, shard_name: str = None) -> str: znode = f"{self.collections_znode}/{collection_name}/shards" if shard_name: znode += "/" + shard_name return znode
[docs] def replica_znode( self, collection_name: str, shard_name: str, hostport: str = None ) -> str: shard_znode = self.shard_znode(collection_name, shard_name) znode = f"{shard_znode}/replicas" if hostport: znode += "/" + hostport return znode
[docs] def startup(self): self.zk.start() self.cache.start() signal.signal(signal.SIGINT, self.signal_listener) signal.signal(signal.SIGTERM, self.signal_listener) self.zk.ensure_path(self.live_nodes_znode) self.zk.ensure_path(self.collections_znode)
[docs] def shutdown(self): self.cache.close() self.zk.stop() self.zk.close()
[docs] def cleanup(self): logger.info(f"Removing ZNodes via cleanup") transaction = self.zk.transaction() for collection in self.list_local_collections(): for shard in collection.shards: for replica in shard.replicas: znode = self.replica_znode( collection.name, shard.name, replica.hostport ) transaction.delete(znode) self.commit_transaction(transaction)
[docs] def register_merger(self): pass
[docs] def register_searcher(self): try: retrier = KazooRetry(max_tries=5, delay=1, backoff=2, max_delay=20) retrier(self.zk.create, self.this_node_znode, ephemeral=True, makepath=True) logger.info(f"Created ephemeral ZNode {self.this_node_znode}") except kazoo.retry.RetryFailedError: logger.error( f"Max retries reached for creating ephemeral ZNode {self.this_node_znode}" ) except kazoo.retry.InterruptedError: logger.error( f"Retries interrupted for creating ephemeral ZNode {self.this_node_znode}" )
[docs] def set_state(self, state, collection_name=None, shard_name=None, hostport=None): transaction = self.zk.transaction() collections = [collection_name] if collection_name else None for collection in self._list_collections( collections, hostport=hostport, load_replica=True ): logger.info( f"Set {collection.name}/shards ZNodes to {collections_pb2.Replica.State.Name(state)}" ) for shard in collection.shards: for replica in shard.replicas: znode = self.replica_znode( collection.name, shard.name, replica.node.hostport ) replica.state = state transaction.set_data(znode, replica.SerializeToString()) return self.commit_transaction(transaction)
[docs] def set_local_state(self, state, collection_name=None, shard_name=None): return self.set_state(state, collection_name, shard_name, self.hostport)
[docs] def signal_listener(self, signum, frame): self.shutdown()
[docs] def zk_listener(self, state): if state == KazooState.LOST: logger.warn("Connection to Zookeeper lost") elif state == KazooState.SUSPENDED: logger.warn("Connection to Zookeeper disconnected") else: logger.info("Connection to Zookeeper established")
[docs] def add_collections(self, collections): """Configure a list of collections into Zookeeper """ transaction = self.zk.transaction() for collection in collections: collection_copy = deepcopy(collection) collection_copy.ClearField("shards") collection_znode = self.collection_znode(collection.name) transaction.create(collection_znode, collection_copy.SerializeToString()) transaction.create(self.shard_znode(collection.name)) for shard in collection.shards: shard_copy = deepcopy(shard) shard_copy.ClearField("replicas") shard_znode = self.shard_znode(collection.name, shard.name) transaction.create(shard_znode, shard_copy.SerializeToString()) transaction.create(self.replica_znode(collection.name, shard.name)) for replica in shard.replicas: replica_copy = deepcopy(replica) replica_copy.state = collections_pb2.Replica.BOOTING replica_znode = self.replica_znode( collection.name, shard.name, replica.node.hostport ) transaction.create(replica_znode, replica_copy.SerializeToString()) if self.commit_transaction(transaction): return collections else: return []
[docs] def delete_collections(self, collection_names): transaction = self.zk.transaction() for collection_name in collection_names: shards_znode = self.shard_znode(collection_name) for shard_name in self.zk.get_children(shards_znode): replicas_znode = self.replica_znode(collection_name, shard_name) for replica_name in self.zk.get_children(replicas_znode): replica_znode = self.replica_znode( collection_name, shard_name, replica_name ) transaction.delete(replica_znode) transaction.delete(replicas_znode) transaction.delete(self.shard_znode(collection_name, shard_name)) transaction.delete(shards_znode) transaction.delete(self.collection_znode(collection_name)) if self.commit_transaction(transaction): return collection_names else: return []
[docs] def list_nodes(self): live_nodes = self.zk.get_children(self.live_nodes_znode) nodes = [collections_pb2.Node(hostport=node) for node in live_nodes] return nodes
[docs] def list_collections(self, collection_names=None, include_state=True): return self._list_collections(collection_names, load_replica=include_state)
[docs] def list_local_collections(self, include_state=True): return self._list_collections( hostport=self.hostport, load_replica=include_state )
def _list_collections( self, collection_names: Optional[List[str]] = None, hostport: Optional[str] = None, load_replica: Optional[bool] = True, ) -> List[collections_pb2.Collection]: collections = [] collection_names = collection_names or self.zk.get_children( self.collections_znode ) for collection_name in collection_names: shards = [] shards_znode = self.shard_znode(collection_name) for shard_name in self.zk.get_children(shards_znode): replicas = [] replicas_znode = self.replica_znode(collection_name, shard_name) for replica_hostport in self.zk.get_children(replicas_znode): if hostport == replica_hostport or hostport is None: replica_znode = self.replica_znode( collection_name, shard_name, replica_hostport ) if load_replica: replica_data, _ = self.zk.get(replica_znode) replica_proto = collections_pb2.Replica.FromString( replica_data ) else: replica_proto = collections_pb2.Replica() replicas.append(replica_proto) if replicas: shard_znode = self.shard_znode(collection_name, shard_name) shard_data, _ = self.zk.get(shard_znode) shard_proto = collections_pb2.Shard.FromString(shard_data) shard_proto.replicas.extend(replicas) shards.append(shard_proto) if shards: collection_znode = self.collection_znode(collection_name) collection_data, _ = self.zk.get(collection_znode) collection_proto = collections_pb2.Collection.FromString( collection_data ) collection_proto.shards.extend(shards) collections.append(collection_proto) return collections
[docs] def get_searchers(self, collection_name, shard_names=None): if not shard_names: shards_znode = self.shard_znode(collection_name) shard_names = self.cache.get_children(shards_znode, []) shard_hostports = [] for shard_name in shard_names: hostports = self._get_searchers_for_shard( collection_name, shard_name, active=True ) if hostports: shard_hostports.append((shard_name, hostports)) else: logger.error( f"No active Searcher node for {collection_name}/{shard_name}." ) return shard_hostports
def _get_searchers_for_shard( self, collection_name: str, shard_name: str, active: bool = True ) -> List[str]: replicas_znode = self.replica_znode(collection_name, shard_name) hostports = self.cache.get_children(replicas_znode, []) if active: active_hostports = [] for hostport in hostports: replica_znode = self.replica_znode( collection_name, shard_name, hostport ) node = self.cache.get_data(replica_znode) if node: replica = collections_pb2.Replica.FromString(node.data) if replica.state == collections_pb2.Replica.ACTIVE: active_hostports.append(hostport) hostports = active_hostports return hostports
[docs] def commit_transaction(self, transaction: kazoo.client.TransactionRequest) -> bool: """Commit a transaction and log the first exception after rollbacks""" for result, operation in zip(transaction.commit(), transaction.operations): if isinstance(result, kazoo.exceptions.RolledBackError): continue elif isinstance(result, Exception): logger.error( f"{result.__class__.__name__} in Kazoo transaction: {operation}" ) return False return True