1616 'uint32' : 'UINT32' ,
1717 'uint64' : 'UINT64' }
1818
19+ allowed_devices = {'CPU' , 'GPU' }
20+ allowed_backends = {'TF' , 'TFLITE' , 'TORCH' , 'ONNX' }
21+
1922
2023def numpy2blob (tensor : np .ndarray ) -> tuple :
21- """ Convert the numpy input from user to `Tensor` """
24+ """Convert the numpy input from user to `Tensor`. """
2225 try :
2326 dtype = dtype_dict [str (tensor .dtype )]
2427 except KeyError :
@@ -29,7 +32,7 @@ def numpy2blob(tensor: np.ndarray) -> tuple:
2932
3033
3134def blob2numpy (value : ByteString , shape : Union [list , tuple ], dtype : str ) -> np .ndarray :
32- """ Convert `BLOB` result from RedisAI to `np.ndarray` """
35+ """Convert `BLOB` result from RedisAI to `np.ndarray`. """
3336 mm = {
3437 'FLOAT' : 'float32' ,
3538 'DOUBLE' : 'float64'
@@ -40,6 +43,7 @@ def blob2numpy(value: ByteString, shape: Union[list, tuple], dtype: str) -> np.n
4043
4144
4245def list2dict (lst ):
46+ """Convert the list from RedisAI to a dict."""
4347 if len (lst ) % 2 != 0 :
4448 raise RuntimeError ("Can't unpack the list: {}" .format (lst ))
4549 out = {}
@@ -55,10 +59,8 @@ def list2dict(lst):
5559def recursive_bytetransform (arr : List [AnyStr ], target : Callable ) -> list :
5660 """
5761 Recurse value, replacing each element of b'' with the appropriate element.
58- Function returns the same array after inplace operation which updates `arr`
5962
60- :param target: Type of tensor | array
61- :param arr: The array with b'' numbers or recursive array of b''
63+ Function returns the same array after inplace operation which updates `arr`
6264 """
6365 for ix in range (len (arr )):
6466 obj = arr [ix ]
@@ -70,10 +72,16 @@ def recursive_bytetransform(arr: List[AnyStr], target: Callable) -> list:
7072
7173
7274def listify (inp : Union [str , Sequence [str ]]) -> Sequence [str ]:
75+ """Wrap the ``inp`` with a list if it's not a list already."""
7376 return (inp ,) if not isinstance (inp , (list , tuple )) else inp
7477
7578
76- def tensorget_postprocessor (as_numpy , meta_only , rai_result ):
79+ def tensorget_postprocessor (rai_result , as_numpy , meta_only ):
80+ """Process the tensorget output.
81+
82+ If ``as_numpy`` is True, it'll be converted to a numpy array. The required
83+ information such as datatype and shape must be in ``rai_result`` itself.
84+ """
7785 rai_result = list2dict (rai_result )
7886 if meta_only :
7987 return rai_result
0 commit comments