inputs: regroup the decoding logic into the input handling objects
2 files changed, 48 insertions(+), 29 deletions(-)

M rework/helper.py
M rework/input.py
M rework/helper.py +11 -24
@@ 17,7 17,6 @@ import psutil
 from sqlalchemy.engine import url
 from sqlhelp import select
 from inireader import reader
-from dateutil.parser import isoparse, parse as defaultparse
 
 from rework.input import inputio
 

          
@@ 439,7 438,11 @@ def pack_inputs(spec, args):
 
 
 def nary_unpack(packedbytes):
-    packedbytes = zstd.decompress(packedbytes.tobytes())
+    try:
+        packedbytes = zstd.decompress(packedbytes.tobytes())
+    except zstd.Error:
+        raise TypeError('wrong input format')
+
     [sizes_size] = struct.unpack(
         '!L', packedbytes[:4]
     )

          
@@ 452,13 455,6 @@ def nary_unpack(packedbytes):
     return struct.unpack(fmt, packedbytes[payloadoffset:])
 
 
-def parsedatetime(strdt):
-    try:
-        return isoparse(strdt)
-    except ValueError:
-        return defaultparse(strdt)
-
-
 def unpack_inputs(spec, packedbytes):
     byteslist = nary_unpack(packedbytes)
     middle = len(byteslist) // 2

          
@@ 470,23 466,14 @@ def unpack_inputs(spec, packedbytes):
     output = dict(zip(keys, values))
 
     for field in spec:
-        name = field['name']
-        val = output.get(name)
+        inp = inputio.from_type(
+            field['type'], field['name'], field['required'], field['choices']
+        )
+        val = inp.binary_decode(output)
         if val is None:
             continue
-        ftype = field['type']
-        if ftype == 'number':
-            try:
-                val = int(val)
-            except ValueError:
-                val = float(val)
-            output[name] = val
-            continue
-        if ftype == 'string':
-            output[name] = val.decode('utf-8')
-            continue
-        if ftype == 'datetime':
-            output[name] = parsedatetime(val.decode('utf-8'))
+
+        output[inp.name] = val
 
     return output
 

          
M rework/input.py +37 -5
@@ 1,5 1,14 @@ 
 import json
 
+from dateutil.parser import isoparse, parse as defaultparse
+
+
+def parsedatetime(strdt):
+    try:
+        return isoparse(strdt)
+    except ValueError:
+        return defaultparse(strdt)
+
 
 class inputio:
     _fields = 'name', 'required', 'choices'

          
@@ 43,6 52,15 @@ class number(inputio):
         if val is not None:
             return str(val).encode('utf-8')
 
+    def binary_decode(self, args):
+        val = args.get(self.name)
+        if val is None:
+            return
+        try:
+            return int(val)
+        except ValueError:
+            return float(val)
+
 
 class string(inputio):
 

          
@@ 51,15 69,19 @@ class string(inputio):
         if val is not None:
             return val.encode('utf-8')
 
+    def binary_decode(self, args):
+        val = args.get(self.name)
+        if val is not None:
+            return val.decode('utf-8')
+
 
 class file(inputio):
 
     def binary_encode(self, args):
-        val = self.val(args)
-        if val is None:
-            return
-        assert isinstance(val, bytes) or val is None
-        return val
+        return self.val(args)
+
+    def binary_decode(self, args):
+        return args.get(self.name)
 
 
 class datetime(inputio):

          
@@ 73,3 95,13 @@ class datetime(inputio):
         else:
             val = val.isoformat().encode('utf-8')
         return val
+
+    def binary_decode(self, args):
+        val = args.get(self.name)
+        if val is None:
+            return
+        val = val.decode('utf-8')
+        try:
+            return isoparse(val)
+        except ValueError:
+            return defaultparse(val)