Source code for needlestack.apis.serializers
from itertools import repeat
from typing import Any, Tuple, Optional, List, Union
import numpy as np
from needlestack.apis import tensors_pb2
from needlestack.apis import indices_pb2
from needlestack.exceptions import SerializationError, DeserializationError
TYPE_TO_ENUM = {
"float16": tensors_pb2.NDArray.FLOAT16,
"float32": tensors_pb2.NDArray.FLOAT32,
"float64": tensors_pb2.NDArray.FLOAT64,
"int8": tensors_pb2.NDArray.INT8,
"int16": tensors_pb2.NDArray.INT16,
"int32": tensors_pb2.NDArray.INT32,
"int64": tensors_pb2.NDArray.INT64,
}
ENUM_TO_TYPE = {v: k for k, v in TYPE_TO_ENUM.items()}
[docs]def ndarray_to_proto(
X: Any, dtype: Optional[str] = None, shape: Optional[Tuple] = None
) -> tensors_pb2.NDArray:
"""Transforms a Python n-dimension array into a protobuf
Args:
X: ND Array
dtype: Explicit datatype for number
shape: Explicit shape for nd array
"""
proto = tensors_pb2.NDArray()
if isinstance(X, list):
if dtype is None:
raise SerializationError("Serializing list needs dtype")
if shape is None:
raise SerializationError("Serializing list needs shape")
X = np.array(X, dtype=dtype)
if X.shape != shape:
raise SerializationError("Shape mismatch")
if isinstance(X, np.ndarray):
if dtype and X.dtype.name != dtype:
if dtype in TYPE_TO_ENUM:
X = X.astype(dtype)
else:
raise SerializationError(f"{dtype} dtype not supported")
dtype_enum = TYPE_TO_ENUM.get(X.dtype.name)
if dtype_enum is None:
raise SerializationError(f"{X.dtype.name} dtype not yet supported")
proto.dtype = dtype_enum
proto.shape.extend(X.shape)
proto.numpy_content = X.tobytes()
return proto
else:
raise SerializationError("Unsupported NDArray")
[docs]def proto_to_ndarray(proto: tensors_pb2.NDArray) -> np.ndarray:
"""Transform a protobuf into a numpy array
Args:
proto: Protobuf for nd array
"""
dtype = ENUM_TO_TYPE.get(proto.dtype)
if not proto.shape:
raise DeserializationError("Missing attribute shape to convert to ndarray")
if proto.numpy_content and dtype:
return np.frombuffer(proto.numpy_content, dtype=dtype).reshape(*proto.shape)
elif proto.float_val:
dtype = dtype or "float32"
return np.array(proto.float_val, dtype=dtype).reshape(*proto.shape)
elif proto.double_val:
dtype = dtype or "float64"
return np.array(proto.double_val, dtype=dtype).reshape(*proto.shape)
elif proto.int_val:
dtype = dtype or "int32"
return np.array(proto.int_val, dtype=dtype).reshape(*proto.shape)
elif proto.long_val:
dtype = dtype or "int64"
return np.array(proto.long_val, dtype=dtype).reshape(*proto.shape)
else:
raise DeserializationError("Missing value attribute to convert to ndarray")
[docs]def metadata_list_to_proto(
ids: List[str],
fields_list: List[Tuple],
fieldtypes: Optional[Tuple[str]] = None,
fieldnames: Optional[Tuple[str]] = None,
) -> List[indices_pb2.Metadata]:
"""Serialize a set of items with metadata fields
Args:
ids: List of ids for items
fields_list: List of tuple of field values
fieldtypes: Optional tuple of types for values
fieldname: Optional tuple of names for values
"""
return [
metadata_to_proto(id, fields, fieldtypes, fieldnames)
for id, fields in zip(ids, fields_list)
]
[docs]def metadata_to_proto(
id: str,
fields: Tuple,
fieldtypes: Optional[Tuple[str]] = None,
fieldnames: Optional[Tuple[str]] = None,
) -> indices_pb2.Metadata:
"""Serialize a set of metadata fields for some item.
Skips over None fields
Args:
id: ID for item
fields: Tuple of primative python values
fieldtypes: Optional tuple of types for values
fieldnames: Optional tuple of names for values
"""
_fieldtypes = fieldtypes or repeat(None, len(fields))
_fieldnames = fieldnames or repeat(None, len(fields))
metadata_fields = [
metadata_field_to_proto(field, fieldtype, fieldname)
for field, fieldtype, fieldname in zip(fields, _fieldtypes, _fieldnames)
if field is not None
]
return indices_pb2.Metadata(id=id, fields=metadata_fields)
TYPE_TO_FIELD_TYPE = {str: "string", float: "double", int: "long", bool: "bool"}
[docs]def metadata_field_to_proto(
field: Union[str, int, float, bool],
fieldtype: Optional[str] = None,
fieldname: Optional[str] = None,
) -> indices_pb2.MetadataField:
"""Serialize some python value to a metadata field proto
Args:
field: Primative python value
fieldtype: Explicit type to serialize the field
fieldname: Optional name for this metadata field
"""
proto = indices_pb2.MetadataField(name=fieldname)
fieldtype = fieldtype if fieldtype else TYPE_TO_FIELD_TYPE.get(type(field))
if fieldtype is None:
raise SerializationError(f"Fieldtype {type(field)} not serializable.")
if fieldtype == "string" and isinstance(field, str):
proto.string_val = field
elif fieldtype == "double" and isinstance(field, float):
proto.double_val = field
elif fieldtype == "float" and isinstance(field, float):
proto.float_val = field
elif fieldtype == "long" and isinstance(field, int):
proto.long_val = field
elif fieldtype == "int" and isinstance(field, int):
proto.int_val = field
elif fieldtype == "bool" and isinstance(field, bool):
proto.bool_val = field
else:
raise SerializationError(
f"Fieldtype {fieldtype} and primative {type(field)} not serializable."
)
return proto