1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
|
# -*- coding: utf-8 -*-
'''
Very simple (performance oriented) declarative message codec.
Inspired by Pycrate and Scapy.
'''
# TRX Toolkit
#
# (C) 2021 by sysmocom - s.f.m.c. GmbH <info@sysmocom.de>
# Author: Vadim Yanitskiy <vyanitskiy@sysmocom.de>
#
# All Rights Reserved
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
from typing import Optional, Callable, Tuple, Any
import abc
class ProtocolError(Exception):
    ''' Raised when a protocol (field/message) definition itself is invalid. '''
class DecodeError(Exception):
    ''' Raised when a field or a message cannot be decoded from bytes. '''
class EncodeError(Exception):
    ''' Raised when a field or a message cannot be encoded into bytes. '''
class Codec(abc.ABC):
    ''' Abstract interface: anything that can be decoded from and encoded to bytes. '''

    @abc.abstractmethod
    def from_bytes(self, vals: dict, data: bytes) -> int:
        ''' Decode value(s) from the buffer 'data' into the dict 'vals'.

        Returns an int (the implementations in this file return the
        number of octets they consumed).
        '''

    @abc.abstractmethod
    def to_bytes(self, vals: dict) -> bytes:
        ''' Encode value(s) taken from the dict 'vals' and return the bytes. '''
class Field(Codec):
    ''' Base class representing one field in a Message.

    Presence and length handling lives here; derived types only provide
    _from_bytes() and _to_bytes().  The behaviour of an instance can be
    tuned by re-assigning its get_pres / get_len / get_val callables.
    '''

    # Default length in octets (0 means the whole buffer)
    DEF_LEN = 0 # type: int
    # Default values for the additional (type specific) parameters
    DEF_PARAMS = { } # type: dict

    # Callables assigned per instance by __init__():
    # Presence of a field during decoding and encoding
    ## get_pres: Callable[[dict], bool]
    # Length of a field for self.from_bytes()
    ## get_len: Callable[[dict, bytes], int]
    # Value of a field for self.to_bytes()
    ## get_val: Callable[[dict], Any]

    def __init__(self, name: str, **kw) -> None:
        ''' Create a field called 'name'.

        Recognised keyword arguments are 'len' (defaults to DEF_LEN) and
        every key listed in DEF_PARAMS; the latter end up in self.p.
        '''
        self.name = name
        self.len = kw.get('len', self.DEF_LEN)

        # By default a field is always present ...
        self.get_pres = lambda vals: True
        # ... and is encoded from the dict entry carrying its name
        self.get_val = lambda vals: vals[self.name]

        # Fixed length fields report self.len (looked up at call time),
        # flexible ones (len == 0) take whatever is left in the buffer
        if self.len == 0:
            self.get_len = lambda _, data: len(data)
        else:
            self.get_len = lambda vals, _: self.len

        # Additional parameters for derived field types
        self.p = { key : kw.get(key, default)
                   for key, default in self.DEF_PARAMS.items() }

    def from_bytes(self, vals: dict, data: bytes) -> int:
        ''' Decode this field from the beginning of 'data' into 'vals'.

        Returns the number of octets consumed (0 if the field is absent).
        Raises DecodeError if the buffer is shorter than the field.
        '''
        if self.get_pres(vals) is False:
            return 0

        nbytes = self.get_len(vals, data)
        if nbytes > len(data):
            raise DecodeError('Short read')

        self._from_bytes(vals, data[:nbytes])
        return nbytes

    def to_bytes(self, vals: dict) -> bytes:
        ''' Encode this field using the values in 'vals'.

        Returns b'' if the field is absent.  Raises EncodeError if a fixed
        length field (self.len > 0) encodes to a different length.
        '''
        if self.get_pres(vals) is False:
            return b''

        encoded = self._to_bytes(vals)
        if self.len > 0 and self.len != len(encoded):
            raise EncodeError('Field length mismatch')
        return encoded

    @abc.abstractmethod
    def _from_bytes(self, vals: dict, data: bytes) -> None:
        ''' Decode value(s) from 'data' (already cut to the field length). '''
        raise NotImplementedError

    @abc.abstractmethod
    def _to_bytes(self, vals: dict) -> bytes:
        ''' Encode value(s) from 'vals' into bytes. '''
        raise NotImplementedError
class Buf(Field):
    ''' A sequence of octets, stored and emitted as-is. '''

    def _from_bytes(self, vals: dict, data: bytes) -> None:
        # The raw slice is the value
        vals[self.name] = data

    def _to_bytes(self, vals: dict) -> bytes:
        # TODO: handle len(self.get_val()) < self.get_len()
        buf = self.get_val(vals)
        return buf
class Spare(Field):
    ''' Spare filling for RFU fields or padding.

    Nothing is stored on decoding; on encoding the 'filler' parameter is
    repeated get_len() times.
    '''

    # Default parameters
    DEF_PARAMS = {
        'filler' : b'\x00',
    }

    def _from_bytes(self, vals: dict, data: bytes) -> None:
        return None # The content is deliberately ignored

    def _to_bytes(self, vals: dict) -> bytes:
        filler = self.p['filler']
        # There is no buffer on the encoding path, hence the empty one
        return filler * self.get_len(vals, b'')
class Uint(Field):
    ''' An integer field: unsigned, big endian, one octet by default.

    Derived types change the length (DEF_LEN, in octets), the signedness
    (SIGN) and the byte order (BO).  The decoded value is
    raw * mult + offset; encoding applies the inverse using floor division.
    '''

    # Uint8 by default
    DEF_LEN = 1

    # Default parameters
    DEF_PARAMS = {
        'offset' : 0,
        'mult' : 1,
    }

    # Big endian, unsigned
    SIGN = False
    BO = 'big'

    def _from_bytes(self, vals: dict, data: bytes) -> None:
        raw = int.from_bytes(data, self.BO, signed=self.SIGN)
        # Scale first, then shift
        vals[self.name] = raw * self.p['mult'] + self.p['offset']

    def _to_bytes(self, vals: dict) -> bytes:
        # Undo the shift, then the scaling
        shifted = self.get_val(vals) - self.p['offset']
        raw = shifted // self.p['mult']
        return raw.to_bytes(self.len, self.BO, signed=self.SIGN)
class Uint16BE(Uint):
    ''' Unsigned integer, 16 bits, big endian. '''
    DEF_LEN = 2 # 16 bits
class Uint16LE(Uint16BE):
    ''' Unsigned integer, 16 bits, little endian. '''
    BO = 'little'
class Uint32BE(Uint):
    ''' Unsigned integer, 32 bits, big endian. '''
    DEF_LEN = 4 # 32 bits
class Uint32LE(Uint32BE):
    ''' Unsigned integer, 32 bits, little endian. '''
    BO = 'little'
class Int(Uint):
    ''' Signed counterpart of Uint (same length and byte order defaults). '''
    SIGN = True
class Int16BE(Int):
    ''' Signed integer, 16 bits, big endian. '''
    DEF_LEN = 2 # 16 bits
class Int16LE(Int16BE):
    ''' Signed integer, 16 bits, little endian. '''
    BO = 'little'
class Int32BE(Int):
    ''' Signed integer, 32 bits, big endian. '''
    DEF_LEN = 4 # 32 bits
class Int32LE(Int32BE):
    ''' Signed integer, 32 bits, little endian. '''
    BO = 'little'
class BitFieldSet(Field):
    ''' A set of bit-fields packed into a whole number of octets.

    The fields come from the 'set' keyword argument or, if it is absent,
    from STRUCT.  The first field occupies the most significant bits;
    with order 'little' (or 'lsb') the tuple is walked in reverse.
    Unless 'len' is given, the length is the sum of all bit lengths
    rounded up to whole octets.

    Raises ProtocolError if the fields are not given as a tuple or if
    they do not fit into the resulting length.
    '''

    # Default parameters
    DEF_PARAMS = {
        # Default field order (MSB first)
        'order' : 'big',
    }

    # To be defined by derived types
    STRUCT = () # type: Tuple['BitField', ...]

    def __init__(self, **kw) -> None:
        # Function-scope import: only this constructor needs the module
        import copy

        Field.__init__(self, self.__class__.__name__, **kw)

        fields = kw.get('set', self.STRUCT)
        if not isinstance(fields, tuple):
            raise ProtocolError('Expected a tuple')

        # LSB first is basically reversed order
        if self.p['order'] in ('little', 'lsb'):
            fields = fields[::-1]

        # Work on private (shallow) copies.  The offset and mask stored on
        # each field below depend on the order and length of *this* set, so
        # writing them to the shared STRUCT / caller-owned objects would
        # let another set built from the same fields clobber our layout.
        self._fields = tuple(copy.copy(f) for f in fields)

        # Calculate the overall field length (round up to whole octets)
        if self.len == 0:
            bl_sum = sum(f.bl for f in self._fields)
            self.len = (bl_sum + 7) // 8

        # Re-define self.get_len() since we always know the length
        self.get_len = lambda vals, data: self.len

        # Pre-calculate offset and mask for each field, MSB downwards
        offset = self.len * 8
        for f in self._fields:
            if f.bl > offset:
                raise ProtocolError(f, 'BitFieldSet overflow')
            offset -= f.bl
            f.offset = offset
            f.mask = 2 ** f.bl - 1

    def _from_bytes(self, vals: dict, data: bytes) -> None:
        # Intentionally using 'big' here: the field order is already
        # reflected in the pre-calculated offsets
        blob = int.from_bytes(data, byteorder='big')
        for f in self._fields:
            f.dec_val(vals, blob)

    def _to_bytes(self, vals: dict) -> bytes:
        blob = 0x00
        for f in self._fields: # TODO: use functools.reduce()?
            blob |= f.enc_val(vals)
        return blob.to_bytes(self.len, byteorder='big')
class BitField:
    ''' One field in a BitFieldSet.

    'bl' is the length in bits.  An optional fixed value ('val' keyword)
    is used instead of the dict entry on encoding and is verified on
    decoding.
    '''

    # Special fields for BitFieldSet (position and bit mask within the set)
    offset = 0 # type: int
    mask = 0 # type: int

    class Spare:
        ''' Spare filling in a BitFieldSet: encodes as zero, ignored on decoding. '''

        def __init__(self, bl: int) -> None:
            self.bl = bl
            self.name = None

        def enc_val(self, vals: dict) -> int:
            return 0

        def dec_val(self, vals: dict, blob: int) -> None:
            return None # Just ignore it

    def __init__(self, name: str, bl: int, **kw) -> None:
        # Ensure proper length
        if bl < 1:
            raise ProtocolError('Incorrect bit-field length')

        self.name = name
        self.bl = bl

        # (Optional) fixed value for encoding and decoding
        self.val = kw.get('val', None) # type: Optional[int]

    def enc_val(self, vals: dict) -> int:
        ''' Return this field's contribution to the integer blob of the set. '''
        # A fixed value takes precedence over the given dict
        val = vals[self.name] if self.val is None else self.val
        return (val & self.mask) << self.offset

    def dec_val(self, vals: dict, blob: int) -> None:
        ''' Extract this field from 'blob' into 'vals'.

        Raises DecodeError if a fixed value is set and does not match.
        '''
        decoded = (blob >> self.offset) & self.mask
        vals[self.name] = decoded # stored even if the check below fails

        if self.val is not None and decoded != self.val:
            raise DecodeError('Unexpected value %d, expected %d'
                              % (decoded, self.val))
class Envelope:
    ''' A group of related fields.

    The fields listed in STRUCT are decoded and encoded one after another.
    The content is kept in the dict self.c, which is also reachable through
    the item access operators of the instance.
    '''

    STRUCT = () # type: Tuple[Codec, ...]

    def __init__(self, check_len: bool = True):
        # TODO: ensure unique field names in self.STRUCT
        self.check_len = check_len
        self.c = { } # type: dict

    def __getitem__(self, key: str) -> Any:
        return self.c[key]

    def __setitem__(self, key: str, val: Any) -> None:
        self.c[key] = val

    def __delitem__(self, key: str) -> None:
        del self.c[key]

    def check(self, vals: dict) -> None:
        ''' Hook for validating the content before encoding and after decoding.

        Implementations raise exceptions (e.g. ValueError) if something is
        wrong.  There is no need to assert for every possible error (e.g.
        a negative value for a Uint field) if the field's to_bytes() method
        throws an exception anyway; only additional constraints belong here.
        '''

    def from_bytes(self, data: bytes) -> int:
        ''' Decode 'data' into self.c (dropping the old content first). '''
        self.c.clear()
        return self._from_bytes(self.c, data)

    def to_bytes(self) -> bytes:
        ''' Encode the content of self.c. '''
        return self._to_bytes(self.c)

    def _from_bytes(self, vals: dict, data: bytes, offset: int = 0) -> int:
        ''' Decode all fields from 'data' (starting at 'offset') into 'vals'.

        Returns the offset reached after the last field.  Raises DecodeError
        if a field fails to decode, or if self.check_len is set and that
        offset differs from len(data).
        '''
        try: # Fields throw exceptions
            for field in self.STRUCT:
                offset += field.from_bytes(vals, data[offset:])
        except Exception as exc:
            # Add contextual info
            raise DecodeError(self, field, offset) from exc

        if self.check_len and offset != len(data):
            tail = data[offset:]
            raise DecodeError(self, 'Unhandled tail octets: %s' % tail.hex())

        # Check the content after decoding (raises exceptions)
        self.check(vals)
        return offset

    def _to_bytes(self, vals: dict) -> bytes:
        ''' Encode all fields using 'vals' and return the concatenation.

        Raises EncodeError if a field fails to encode.
        '''
        # Check the content before encoding (raises exceptions)
        self.check(vals)

        chunks = []
        for field in self.STRUCT:
            try: # Fields throw exceptions
                chunk = field.to_bytes(vals)
            except Exception as exc:
                # Add contextual info
                raise EncodeError(self, field) from exc
            chunks.append(chunk)
        return b''.join(chunks)

    class F(Field):
        ''' Field wrapper: lets an Envelope act as a field of another message. '''

        def __init__(self, e: 'Envelope', name: str, **kw) -> None:
            Field.__init__(self, name, **kw)
            self.e = e

        def _from_bytes(self, vals: dict, data: bytes) -> None:
            # The nested content goes into a dict of its own
            nested = { } # type: dict
            vals[self.name] = nested
            self.e._from_bytes(nested, data)

        def _to_bytes(self, vals: dict) -> bytes:
            nested = self.get_val(vals)
            return self.e._to_bytes(nested)

    def f(self, name: str, **kw) -> Field:
        ''' Return a Field called 'name' wrapping this Envelope. '''
        wrapper = self.F(self, name, **kw)
        return wrapper
class Sequence:
    ''' A sequence of repeating elements (e.g. TLVs).

    The element is given by ITEM or by the 'item' keyword argument.
    Decoding yields a list of dicts (one per element); encoding takes
    such a list.

    NOTE: the constructor sets check_len = False on the item object, so
    that an element may be followed by further elements in the buffer.
    '''

    # The item of sequence
    ITEM = None # type: Optional[Envelope]

    def __init__(self, **kw) -> None:
        item = kw.get('item', self.ITEM)
        # Covers a missing item as well as an explicit item=None
        if item is None:
            raise ProtocolError('Missing Sequence item')
        self._item = item # type: Envelope
        self._item.check_len = False

    def from_bytes(self, data: bytes) -> list:
        ''' Decode elements from 'data' until the buffer is exhausted.

        Raises DecodeError if an element decodes without consuming any
        octets: the position would never advance and the loop would
        never terminate.
        '''
        proc = self._item._from_bytes
        vseq, offset = [], 0
        length = len(data)

        while offset < length:
            vals = { } # new item of sequence
            consumed = proc(vals, data[offset:])
            if consumed < 1:
                raise DecodeError(self, 'Sequence item consumed no octets', offset)
            vseq.append(vals)
            offset += consumed

        return vseq

    def to_bytes(self, vseq: list) -> bytes:
        ''' Encode every element of 'vseq' and return the concatenation. '''
        proc = self._item._to_bytes
        return b''.join([proc(v) for v in vseq])

    class F(Field):
        ''' Field wrapper: lets a Sequence act as a field of a message. '''

        def __init__(self, s: 'Sequence', name: str, **kw) -> None:
            Field.__init__(self, name, **kw)
            self.s = s

        def _from_bytes(self, vals: dict, data: bytes) -> None:
            vals[self.name] = self.s.from_bytes(data)

        def _to_bytes(self, vals: dict) -> bytes:
            return self.s.to_bytes(self.get_val(vals))

    def f(self, name: str, **kw) -> Field:
        ''' Return a Field called 'name' wrapping this Sequence. '''
        return self.F(self, name, **kw)
|