# Source code for needlestack.servicers.searcher

import logging
from typing import Dict

import grpc

from needlestack.apis import collections_pb2
from needlestack.apis import servicers_pb2
from needlestack.apis import servicers_pb2_grpc
from needlestack.apis import serializers
from needlestack.collections.collection import Collection
from needlestack.collections.shard import Shard
from needlestack.cluster_managers import ClusterManager
from needlestack.servicers.settings import BaseConfig
from needlestack.utilities.rpc import unhandled_exception_rpc


logger = logging.getLogger("needlestack")


class SearcherServicer(servicers_pb2_grpc.SearcherServicer):
    """A gRPC servicer to perform kNN queries on in-memory index structures"""

    # Loaded collections, keyed by collection name.
    collections: Dict[str, Collection]
    # Protobuf configs the loaded collections were last built from, keyed by
    # collection name. Used to detect config changes on reload.
    collection_protos: Dict[str, collections_pb2.Collection]

    def __init__(self, config: BaseConfig, cluster_manager: ClusterManager):
        """Register this searcher with the cluster manager and load collections.

        Args:
            config: Servicer configuration; stored on the instance.
            cluster_manager: Source of the local collection configs and sink
                for replica state updates.
        """
        self.config = config
        self.cluster_manager = cluster_manager
        self.collections = {}
        self.collection_protos = {}
        self.cluster_manager.register_searcher()
        self.load_collections()

    @unhandled_exception_rpc(servicers_pb2.SearchResponse)
    def Search(self, request, context):
        """Query a collection for the ``request.count`` nearest items.

        A 1-D query vector is reshaped to a single-row matrix. If the number
        of columns does not match the collection's dimension, the call is
        answered with ``INVALID_ARGUMENT`` and an empty response.
        """
        X = serializers.proto_to_ndarray(request.vector)
        k = request.count
        collection = self.get_collection(request.collection_name)

        if len(X.shape) == 1:
            X = X.reshape(1, -1)

        if collection.dimension != X.shape[1]:
            context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
            context.set_details(
                f"Collection {collection.name} expected matrix shaped ({collection.dimension}), got {X.shape}"
            )
            return servicers_pb2.SearchResponse()

        results = collection.query(X, k, request.shard_names)
        # Cap the response at k items even if the query returned more.
        items = [item for i, item in enumerate(results) if i < k]
        return servicers_pb2.SearchResponse(items=items)

    @unhandled_exception_rpc(servicers_pb2.RetrieveResponse)
    def Retrieve(self, request, context):
        """Look up a single item by id; respond empty when nothing is found."""
        collection = self.get_collection(request.collection_name)
        item = collection.retrieve(request.id, request.shard_names)
        if item is None:
            return servicers_pb2.RetrieveResponse()
        return servicers_pb2.RetrieveResponse(item=item)

    @unhandled_exception_rpc(collections_pb2.CollectionsLoadResponse)
    def CollectionsLoad(self, request, context):
        """Re-sync the loaded collections with the cluster manager's configs."""
        self.load_collections()
        return collections_pb2.CollectionsLoadResponse()

    def get_collection(self, name: str) -> Collection:
        """Return the loaded collection called ``name``.

        Raises:
            KeyError: If no collection with that name is loaded.
        """
        return self.collections[name]

    def load_collections(self):
        """Load collections from Zookeeper configs

        There are 4 states to handle for each collection:

        - A new collection needs to be loaded
        - An existing collection needs to be dropped
        - An existing collection added/dropped shards
        - No changes
        """
        collection_protos = self.cluster_manager.list_local_collections(
            include_state=False
        )

        current_collections = set(self.collection_protos)
        new_collections = {proto.name for proto in collection_protos}

        for proto in collection_protos:
            if proto.name in current_collections:
                self._modify_collection(proto)
            else:
                self._add_collection(proto)

        for name in current_collections:
            if name not in new_collections:
                self._drop_collection(name)

        # Reload any collection that reports an available update, flagging it
        # as BOOTING for the duration of the reload.
        for collection in self.collections.values():
            if collection.update_available():
                logger.debug(f"Update collection {collection.name}")
                self.cluster_manager.set_local_state(
                    collections_pb2.Replica.BOOTING, collection.name
                )
                collection.load()
                self.cluster_manager.set_local_state(
                    collections_pb2.Replica.ACTIVE, collection.name
                )

        # Remember the configs just applied so the next call can diff
        # against them.
        self.collection_protos = {proto.name: proto for proto in collection_protos}

    def _add_collection(self, proto: collections_pb2.Collection):
        """Build a collection from ``proto``, register it and load it."""
        logger.debug(f"Add collection {proto.name}")
        collection = Collection.from_proto(proto)
        self.cluster_manager.set_local_state(
            collections_pb2.Replica.BOOTING, collection.name
        )
        self.collections[collection.name] = collection
        collection.load()
        self.cluster_manager.set_local_state(
            collections_pb2.Replica.ACTIVE, collection.name
        )

    def _drop_collection(self, name: str):
        """Forget the loaded collection called ``name``."""
        logger.debug(f"Drop collection {name}")
        del self.collections[name]

    def _modify_collection(self, proto: collections_pb2.Collection):
        """Apply a changed config to an already-loaded collection.

        Does nothing when the serialized config is unchanged. Otherwise merges
        the new config into the collection, adds shards that are new, replaces
        shards whose serialized config changed, drops shards that disappeared,
        and then loads the collection.
        """
        old_proto = self.collection_protos[proto.name]
        if old_proto.SerializeToString() == proto.SerializeToString():
            return

        collection = self.get_collection(proto.name)
        collection.merge_proto(proto)

        old_shards = {shard.name: shard for shard in old_proto.shards}
        new_shards = {shard.name: shard for shard in proto.shards}

        # Names of shards flagged BOOTING below; each one must be flagged
        # ACTIVE again once the collection has been loaded.
        booting_shards = []

        for name, new_shard in new_shards.items():
            if name not in old_shards:
                logger.debug(f"Add collection shard {proto.name}/{name}")
                self.cluster_manager.set_local_state(
                    collections_pb2.Replica.BOOTING, collection.name, name
                )
                collection.add_shard(Shard.from_proto(new_shard))
                booting_shards.append(name)
            elif (
                new_shard.SerializeToString() != old_shards[name].SerializeToString()
            ):
                logger.debug(f"Update collection shard {proto.name}/{name}")
                self.cluster_manager.set_local_state(
                    collections_pb2.Replica.BOOTING, collection.name, name
                )
                collection.drop_shard(name)
                collection.add_shard(Shard.from_proto(new_shard))
                booting_shards.append(name)

        for name in old_shards:
            if name not in new_shards:
                logger.debug(f"Drop collection shard {proto.name}/{name}")
                collection.drop_shard(name)

        collection.load()

        # Activate exactly the shards that were put into BOOTING, rather than
        # whichever shard name a loop variable happened to be bound to last
        # (which could be a dropped shard, or unbound if there were no shards).
        for name in booting_shards:
            self.cluster_manager.set_local_state(
                collections_pb2.Replica.ACTIVE, collection.name, name
            )