import heapq
from typing import List, Dict, Iterable, Optional
import numpy as np
from needlestack.apis import indices_pb2
from needlestack.apis import collections_pb2
from needlestack.collections.shard import Shard
from needlestack.exceptions import DimensionMismatchException
class Collection:
    """A logical collection made of shards where kNN queries can be performed.

    Attributes:
        name: Name of collection
        shards: Dictionary of shard names to shards
        replication_factor: Number of replicas per shard in the cluster
        enable_id_to_vector: Enable retrieving vector from id
        dimension: Dimensionality of the vectors (set by ``validate``)
    """

    # Annotations naming project types are quoted (forward references) so they
    # are not evaluated when the class body executes.
    name: str
    shards: "Dict[str, Shard]"
    replication_factor: int
    enable_id_to_vector: bool
    dimension: int

    @classmethod
    def from_proto(cls, proto: "collections_pb2.Collection") -> "Collection":
        """Build a new Collection from its protobuf representation.

        Args:
            proto: Serialized collection definition.

        Returns:
            A Collection populated via ``populate_from_proto``. This does not
            call ``load``.
        """
        collection = cls()
        collection.populate_from_proto(proto)
        return collection

    def populate_from_proto(self, proto: "collections_pb2.Collection"):
        """Reset this collection's settings and shards from ``proto``.

        Any previously registered shards are discarded. ``dimension`` is not
        set here; it is computed by ``validate`` (which ``load`` calls).

        Args:
            proto: Serialized collection definition.
        """
        self.name = proto.name
        self.replication_factor = proto.replication_factor
        self.enable_id_to_vector = proto.enable_id_to_vector
        self.shards = {}
        # Build every shard before registering any of them, so a failure in
        # Shard.from_proto leaves ``self.shards`` empty rather than partial.
        shards = [Shard.from_proto(shard_proto) for shard_proto in proto.shards]
        for shard in shards:
            # Shards inherit the collection-level id->vector setting.
            shard.enable_id_to_vector = self.enable_id_to_vector
            self.add_shard(shard)

    def merge_proto(self, proto):
        """Update collection-level settings from ``proto`` without touching shards.

        Only ``replication_factor`` and ``enable_id_to_vector`` are copied. The
        new ``enable_id_to_vector`` value is pushed to the shards on the next
        ``load`` call.
        """
        self.replication_factor = proto.replication_factor
        self.enable_id_to_vector = proto.enable_id_to_vector

    def load(self):
        """Load every shard and then validate the collection.

        Each shard first receives the current ``enable_id_to_vector`` setting.

        Raises:
            DimensionMismatchException: If loaded shards disagree on dimension.
        """
        for shard in self.shards.values():
            shard.enable_id_to_vector = self.enable_id_to_vector
            shard.load()
        self.validate()

    def update_available(self) -> bool:
        """Return True if any shard reports that an update is available."""
        return any(shard.update_available() for shard in self.shards.values())

    def validate(self):
        """Check that every shard shares one vector dimension and record it.

        On success ``self.dimension`` is set to the common dimension. A
        collection with no shards has nothing to compare. It passes, and
        ``dimension`` is left unchanged.

        Raises:
            DimensionMismatchException: If shards report differing dimensions.
        """
        shard_dimensions = {shard.index.dimension for shard in self.shards.values()}
        if not shard_dimensions:
            # No shards: avoid ``set.pop`` raising KeyError on an empty set.
            return
        if len(shard_dimensions) > 1:
            raise DimensionMismatchException(
                f"All shards in {self.name} Collection do not match dimensions"
            )
        self.dimension = shard_dimensions.pop()

    def add_shard(self, shard: "Shard"):
        """Register ``shard`` under its name, replacing any shard of that name."""
        self.shards[shard.name] = shard

    def drop_shard(self, name: str):
        """Remove the shard called ``name``; raises KeyError if it is absent."""
        del self.shards[name]

    def query(
        self, X: np.ndarray, k: int, shard_names: List[str]
    ) -> "Iterable[indices_pb2.SearchResultItem]":
        """Run a kNN query against the named shards and merge their results.

        Each shard is queried with the same ``k``. The per-shard results are
        merged lazily by ascending distance with ``heapq.merge``, which assumes
        each shard already yields its results in ascending distance order. The
        merged stream is not truncated to ``k`` here.

        Args:
            X: Query input, passed unchanged to each shard's ``query``.
            k: Number of neighbours requested from each shard.
            shard_names: Names of shards to query; an unknown name raises
                KeyError.

        Returns:
            An iterator over the merged search result items.
        """
        shard_results = [self.shards[name].query(X, k) for name in shard_names]
        # A distance is carried in either float_distance or double_distance.
        # NOTE(review): ``or`` assumes the unused field reads as 0 (proto3
        # scalar default), so a genuine 0.0 float distance falls through to a
        # 0 double_distance; confirm against the SearchResultItem schema.
        return heapq.merge(
            *shard_results,
            key=lambda item: item.float_distance or item.double_distance,
        )

    def retrieve(
        self, id: str, shard_names: List[str]
    ) -> "Optional[indices_pb2.RetrievalResultItem]":
        """Return the first non-None retrieval result for ``id``.

        Shards are tried in ``shard_names`` order. Returns None if no shard
        returns a result.
        """
        # ``id`` shadows the builtin but is kept for keyword-argument callers.
        for shard_name in shard_names:
            retrieval_item = self.shards[shard_name].retrieve(id)
            if retrieval_item is not None:
                return retrieval_item
        return None