import heapq
import logging
import random
from typing import Dict, List, Optional, Tuple

import grpc

from needlestack.apis import collections_pb2
from needlestack.apis import servicers_pb2
from needlestack.apis import servicers_pb2_grpc
from needlestack.balancers import calculate_add
from needlestack.balancers.greedy import GreedyAlgorithm
from needlestack.cluster_managers import ClusterManager
from needlestack.servicers.settings import BaseConfig
from needlestack.utilities.rpc import unhandled_exception_rpc

# Module-level logger, obtained under the fixed name "needlestack".
logger = logging.getLogger("needlestack")
class MergerServicer(servicers_pb2_grpc.MergerServicer):
    """A gRPC servicer to accept external requests, use searcher nodes, and
    merge results together.
    """

    def __init__(self, config: BaseConfig, cluster_manager: ClusterManager):
        """Store the settings and register this merger with the cluster.

        Args:
            config: Servicer settings. ``ssl_channel_credentials`` is read
                here and ``use_channel_ssl`` is read when stubs are built.
            cluster_manager: Provides node, collection, and searcher lookups;
                ``register_merger()`` is called on it immediately.
        """
        self.config = config
        self.cluster_manager = cluster_manager
        self.cluster_manager.register_merger()
        self.ssl_channel_credentials = self.config.ssl_channel_credentials

    @unhandled_exception_rpc(servicers_pb2.SearchResponse)
    def Search(self, request, context):
        """Fan a search out to the searcher nodes and merge their answers.

        One sub-request is built per searcher host (carrying the shards that
        were assigned to that host) and all of them are started before any
        result is awaited.

        Returns:
            * several responses: a new ``SearchResponse`` holding at most
              ``request.count`` items, merged in ascending distance order;
            * exactly one response: that response, unchanged;
            * no responses: an empty ``SearchResponse`` after setting the
              status code to ``UNKNOWN`` on ``context``.
        """
        hostports_shards = self.get_searcher_hostports(
            request.collection_name, list(request.shard_names)
        )

        # Start every sub-request first so the searchers work concurrently.
        futures = []
        for hostport, shard_names in hostports_shards:
            stub = self.get_searcher_stub(hostport)
            subrequest = servicers_pb2.SearchRequest(
                vector=request.vector,
                count=request.count,
                collection_name=request.collection_name,
                shard_names=shard_names,
            )
            futures.append(stub.Search.future(subrequest))

        subsearch_results = [future.result() for future in futures]

        if not subsearch_results:
            context.set_code(grpc.StatusCode.UNKNOWN)
            context.set_details("Empty responses from Search")
            return servicers_pb2.SearchResponse()

        if len(subsearch_results) == 1:
            # Nothing to merge; pass the searcher's response straight through.
            return subsearch_results[0]

        # heapq.merge only yields a sorted stream if every input is already
        # sorted by the same key.
        # NOTE(review): assumes each searcher returns items in ascending
        # distance order -- confirm against the searcher implementation.
        item_batches = [result.items for result in subsearch_results]
        merged_item_batches = heapq.merge(
            # Use float_distance when it is non-zero, else double_distance.
            *item_batches, key=lambda x: x.float_distance or x.double_distance
        )
        # zip() stops as soon as range() is exhausted, so the merged stream
        # is not consumed past the requested number of items.
        items = [item for _, item in zip(range(request.count), merged_item_batches)]
        return servicers_pb2.SearchResponse(items=items)

    @unhandled_exception_rpc(servicers_pb2.RetrieveResponse)
    def Retrieve(self, request, context):
        """Look an ID up on the searcher nodes and return the first hit.

        A ``RetrieveRequest`` is sent to every selected searcher host. The
        responses are inspected in the order the requests were issued and the
        first one whose ``item.metadata.id`` is truthy is returned. If none
        qualifies, ``NOT_FOUND`` is set on ``context`` and an empty
        ``RetrieveResponse`` is returned.
        """
        hostports_shards = self.get_searcher_hostports(
            request.collection_name, list(request.shard_names)
        )

        futures = []
        for hostport, shard_names in hostports_shards:
            stub = self.get_searcher_stub(hostport)
            subrequest = servicers_pb2.RetrieveRequest(
                id=request.id,
                collection_name=request.collection_name,
                shard_names=shard_names,
            )
            futures.append(stub.Retrieve.future(subrequest))

        for future in futures:
            result = future.result()
            if result.item.metadata.id:
                return result

        context.set_code(grpc.StatusCode.NOT_FOUND)
        context.set_details("ID not found in collection")
        return servicers_pb2.RetrieveResponse()

    @unhandled_exception_rpc(collections_pb2.CollectionsAddResponse)
    def CollectionsAdd(self, request, context):
        """Add new collections to the cluster.

        If any requested name is already known to the cluster manager,
        nothing is added: ``ALREADY_EXISTS`` is set on ``context`` and a
        response with ``success=False`` is returned.

        Otherwise ``calculate_add`` is run with a ``GreedyAlgorithm`` over
        the current nodes and collections. Unless ``request.noop`` is set,
        the result is stored through the cluster manager and every node is
        asked to reload. With ``noop`` the computed collections are returned
        with ``success=True`` and nothing is stored.
        """
        new_collections = request.collections
        current_collections = self.cluster_manager.list_collections()

        new_names = {c.name for c in new_collections}
        current_names = {c.name for c in current_collections}
        duplicate_names = new_names & current_names
        if duplicate_names:
            context.set_code(grpc.StatusCode.ALREADY_EXISTS)
            context.set_details(
                f"Collections {duplicate_names} already exists. No new collections added."
            )
            return collections_pb2.CollectionsAddResponse(success=False)

        nodes = self.cluster_manager.list_nodes()
        algorithm = GreedyAlgorithm()
        collections_to_add = calculate_add(
            nodes, current_collections, new_collections, algorithm
        )

        success = True
        if not request.noop:
            self.cluster_manager.add_collections(collections_to_add)
            success = self.collections_load()

        return collections_pb2.CollectionsAddResponse(
            collections=collections_to_add, success=success
        )

    @unhandled_exception_rpc(collections_pb2.CollectionsDeleteResponse)
    def CollectionsDelete(self, request, context):
        """Delete collections from the cluster by name.

        If any requested name is unknown to the cluster manager, nothing is
        deleted: ``NOT_FOUND`` is set on ``context`` and a default
        ``CollectionsDeleteResponse`` is returned.

        Otherwise, unless ``request.noop`` is set, the collections are
        removed through the cluster manager and every node is asked to
        reload. The response echoes the requested names.
        """
        collection_names = request.names
        success = True

        new_names = set(collection_names)
        current_names = {
            collection.name for collection in self.cluster_manager.list_collections()
        }
        missing_names = new_names - current_names
        if missing_names:
            context.set_code(grpc.StatusCode.NOT_FOUND)
            context.set_details(f"Collections {missing_names} do not exists")
            return collections_pb2.CollectionsDeleteResponse()

        if not request.noop:
            self.cluster_manager.delete_collections(collection_names)
            success = self.collections_load()

        return collections_pb2.CollectionsDeleteResponse(
            names=collection_names, success=success
        )

    @unhandled_exception_rpc(collections_pb2.CollectionsLoadResponse)
    def CollectionsLoad(self, request, context):
        """Ask every node to reload its collections and report the outcome."""
        success = self.collections_load()
        return collections_pb2.CollectionsLoadResponse(success=success)

    @unhandled_exception_rpc(collections_pb2.CollectionsListResponse)
    def CollectionsList(self, request, context):
        """Return the collections the cluster manager lists for ``request.names``."""
        collection_names = list(request.names)
        collections = self.cluster_manager.list_collections(collection_names)
        return collections_pb2.CollectionsListResponse(collections=collections)

    def collections_load(self) -> bool:
        """Send a ``CollectionsLoadRequest`` to every node in the cluster.

        All requests are started before any result is awaited. A node that
        raises ``grpc.RpcError`` is logged and counted as a failure; the
        remaining nodes are still awaited.

        Returns:
            ``True`` only if every node answered with ``success`` set and no
            call raised ``grpc.RpcError``.
        """
        success = True
        futures = []
        nodes = self.cluster_manager.list_nodes()
        for node in nodes:
            stub = self.get_searcher_stub(node.hostport)
            subrequest = collections_pb2.CollectionsLoadRequest()
            future = stub.CollectionsLoad.future(subrequest)
            # Keep the hostport next to the future so failures can name it.
            futures.append((node.hostport, future))

        for hostport, future in futures:
            try:
                result = future.result()
            except grpc.RpcError as e:
                success = False
                logger.error(
                    "Searcher %s failed CollectionsLoadRequest: %s", hostport, e
                )
            else:
                success = success and result.success
        return success

    def get_searcher_hostports(
        self, collection_name: str, shard_names: Optional[List[str]] = None
    ) -> List[Tuple[str, List[str]]]:
        """Pick one searcher host per shard and group the shards by host.

        Args:
            collection_name: Collection whose searchers are looked up.
            shard_names: Passed through to ``cluster_manager.get_searchers``
                together with ``collection_name``; defaults to ``None``.

        Returns:
            ``(hostport, [shard_name, ...])`` pairs, one per chosen host.
        """
        shard_hostports = self.cluster_manager.get_searchers(
            collection_name, shard_names
        )

        # Choose a host at random among the candidates of each shard. Going
        # through a dict keyed by shard name means a repeated shard name
        # keeps only its last pick.
        shard_to_host = {}
        for shard_name, hostports in shard_hostports:
            shard_to_host[shard_name] = random.choice(hostports)

        # Invert the mapping so each host is contacted once with all the
        # shards it was chosen for.
        host_to_shards: Dict[str, List[str]] = {}
        for shard_name, hostport in shard_to_host.items():
            host_to_shards.setdefault(hostport, []).append(shard_name)
        return list(host_to_shards.items())

    def get_searcher_stub(self, hostport: str) -> servicers_pb2_grpc.SearcherStub:
        """Build a ``SearcherStub`` connected to ``hostport``.

        A secure channel using ``self.ssl_channel_credentials`` is opened
        when ``config.use_channel_ssl`` is truthy, otherwise an insecure one.
        """
        # NOTE(review): a new channel is opened on every call and is never
        # explicitly closed here; consider reusing one channel per hostport.
        if self.config.use_channel_ssl:
            channel = grpc.secure_channel(hostport, self.ssl_channel_credentials)
        else:
            channel = grpc.insecure_channel(hostport)
        return servicers_pb2_grpc.SearcherStub(channel)