aboutsummaryrefslogtreecommitdiffstats
path: root/test/subprocesstest.py
blob: 9acf7d8c9d7cef6a8d74f7d7abbc6de17aa1f274 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
#
# Wireshark tests
# By Gerald Combs <gerald@wireshark.org>
#
# Ported from a set of Bash scripts which were copyright 2005 Ulf Lamping
#
# SPDX-License-Identifier: GPL-2.0-or-later
#
'''Subprocess test case superclass'''

import difflib
import io
import os
import os.path
import re
import subprocess
import sys
import unittest

# To do:
# - Add a subprocesstest.SkipUnlessCapture decorator?
# - Try to catch crashes? See the comments below in waitProcess.

# How long (in seconds) LoggingPopen.wait_and_log() waits for a process.
process_timeout = 5 * 60

def cat_dhcp_command(mode):
    '''Create a shell command string for dumping dhcp.pcap to stdout.

    The command runs util_dump_dhcp_pcap.py (located next to this file) with
    mode as its argument, using the current interpreter when
    sys.executable is known. Both the interpreter path and the script path
    are double-quoted so the command still works when either lives in a
    directory whose name contains spaces.
    '''
    # XXX Do this in Python in a thread?
    this_dir = os.path.dirname(__file__)
    script_path = os.path.join(this_dir, 'util_dump_dhcp_pcap.py')
    # mode is deliberately left unquoted: it is appended as-is after the script.
    sd_cmd = '"{}" {}'.format(script_path, mode)
    if sys.executable:
        sd_cmd = '"{}" '.format(sys.executable) + sd_cmd
    return sd_cmd

def cat_cap_file_command(cap_files):
    '''Return a shell command string that writes capture file(s) to stdout.

    cap_files may be a single path string or an iterable of path strings.
    Each path is wrapped in double quotes. On Windows the files are copied
    to the console device; elsewhere they are passed to `cat`.
    '''
    # XXX Do this in Python in a thread?
    paths = [cap_files] if isinstance(cap_files, str) else cap_files
    quoted_paths = ' '.join(['"{}"'.format(path) for path in paths])
    if not sys.platform.startswith('win32'):
        return 'cat {}'.format(quoted_paths)
    # https://docs.microsoft.com/en-us/previous-versions/windows/it-pro/windows-xp/bb491026(v=technet.10)
    # says that the `type` command "displays the contents of a text
    # file." Copy to the console instead.
    return 'copy {} CON'.format(quoted_paths)

class LoggingPopen(subprocess.Popen):
    '''Run a process using subprocess.Popen. Capture and log its output.

    Stdout and stderr are captured to memory and decoded as UTF-8, and CRLF
    line endings are normalized to LF. The program command and output are
    written to log_fd, if one was given.

    Extra keyword arguments (consumed here, not passed to Popen):
    log_fd -- text stream that receives the command and its output.
    max_lines -- if positive, only the first and last max_lines lines of
        stdout are written to the log. stdout_str is never trimmed.
    '''

    # Seconds to wait for leftover output after killing a timed-out process.
    _drain_timeout = 5

    def __init__(self, proc_args, *args, **kwargs):
        self.log_fd = kwargs.pop('log_fd', None)
        self.max_lines = kwargs.pop('max_lines', None)
        kwargs['stdout'] = subprocess.PIPE
        kwargs['stderr'] = subprocess.PIPE
        # Make sure communicate() gives us bytes.
        kwargs['universal_newlines'] = False
        self.cmd_str = 'command ' + repr(proc_args)
        super().__init__(proc_args, *args, **kwargs)
        self.stdout_str = ''
        self.stderr_str = ''

    @staticmethod
    def trim_output(out_log, max_lines):
        '''Return out_log with all but the first and last max_lines lines
        replaced by a single "<<< trimmed N lines of output >>>" marker.

        Output of at most 2 * max_lines + 1 lines is returned unchanged.
        '''
        lines = out_log.splitlines(True)
        if not len(lines) > max_lines * 2 + 1:
            return out_log
        header = lines[:max_lines]
        body = lines[max_lines:-max_lines]
        body = "<<< trimmed {} lines of output >>>\n".format(len(body))
        footer = lines[-max_lines:]
        return ''.join(header) + body + ''.join(footer)

    def _log_output(self, out_data, err_data):
        '''Write captured stdout/stderr bytes to log_fd between marker lines.

        Does nothing when no log_fd was supplied.'''
        if self.log_fd is None:
            return
        out_log = out_data.decode('UTF-8', 'replace')
        if self.max_lines and self.max_lines > 0:
            out_log = self.trim_output(out_log, self.max_lines)
        err_log = err_data.decode('UTF-8', 'replace')
        self.log_fd.flush()
        self.log_fd.write('-- Begin stdout for {} --\n'.format(self.cmd_str))
        self.log_fd.write(out_log)
        self.log_fd.write('-- End stdout for {} --\n'.format(self.cmd_str))
        self.log_fd.write('-- Begin stderr for {} --\n'.format(self.cmd_str))
        self.log_fd.write(err_log)
        self.log_fd.write('-- End stderr for {} --\n'.format(self.cmd_str))
        self.log_fd.flush()

    def wait_and_log(self, timeout=None):
        '''Wait for the process to finish and log its output.

        timeout -- seconds to wait; defaults to the module's process_timeout.

        On success, stdout_str and stderr_str hold the strictly decoded
        output with CRLF normalized to LF.

        Raises subprocess.TimeoutExpired if the process does not finish in
        time. In that case the process is killed and whatever it wrote is
        logged before the exception propagates.
        Raises UnicodeDecodeError if the output is not valid UTF-8.
        '''
        if timeout is None:
            timeout = process_timeout
        try:
            out_data, err_data = self.communicate(timeout=timeout)
        except subprocess.TimeoutExpired:
            # communicate() leaves the child running on timeout. Kill it and
            # collect the remaining output so the log shows how far it got.
            self.kill()
            try:
                out_data, err_data = self.communicate(timeout=self._drain_timeout)
            except subprocess.TimeoutExpired as drain_err:
                # The pipes can stay open if a grandchild (e.g. started via
                # shell=True) survived the kill; log what was read so far.
                out_data = drain_err.output or b''
                err_data = drain_err.stderr or b''
            self._log_output(out_data, err_data)
            # Re-raise the original timeout.
            raise
        self._log_output(out_data, err_data)
        # Make sure our output is the same everywhere.
        # Throwing a UnicodeDecodeError exception here is arguably a good thing.
        self.stdout_str = out_data.decode('UTF-8', 'strict').replace('\r\n', '\n')
        self.stderr_str = err_data.decode('UTF-8', 'strict').replace('\r\n', '\n')

    def stop_process(self, kill=False):
        '''Stop the process immediately (SIGKILL-style if kill, else terminate).'''
        if kill:
            super().kill()
        else:
            super().terminate()

    def terminate(self):
        '''Terminate the process. Do not log its output.'''
        # XXX Currently unused.
        self.stop_process(kill=False)

    def kill(self):
        '''Kill the process. Do not log its output.'''
        self.stop_process(kill=True)

class SubprocessTestCase(unittest.TestCase):
    '''Run a program and gather its stdout and stderr.'''

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.exit_ok = 0

        # See ws_exit_codes.h
        self.exit_command_line = 1
        self.exit_invalid_interface = 2
        self.exit_invalid_file_error = 3
        self.exit_invalid_filter_error = 4
        self.exit_invalid_capability = 5
        self.exit_iface_no_link_types = 6
        self.exit_iface_has_no_timestamp_types = 7
        self.exit_init_failed = 8
        self.exit_open_error = 9

        self.exit_code = None
        self.log_fname = None
        self.log_fd = None
        self.processes = []
        self.cleanup_files = []
        self.dump_files = []

    def log_fd_write_bytes(self, log_data):
        '''Write log_data to the log file.

        The log file is opened in text mode, so bytes-like data is decoded
        as UTF-8 (undecodable bytes are replaced) before it is written.
        '''
        if isinstance(log_data, (bytes, bytearray)):
            log_data = log_data.decode('UTF-8', 'replace')
        self.log_fd.write(log_data)

    def filename_from_id(self, filename):
        '''Generate a filename prefixed with our test ID.

        The name is registered (once) in cleanup_files so tearDown removes it.'''
        id_filename = self.id() + '.' + filename
        if id_filename not in self.cleanup_files:
            self.cleanup_files.append(id_filename)
        return id_filename

    def kill_processes(self):
        '''Kill any processes we've opened so far.'''
        for proc in self.processes:
            try:
                proc.kill()
            except Exception:
                # Best effort: the process may already have exited.
                pass

    def setUp(self):
        """
        Set up a single test. Opens a log file and adds it to the cleanup list.
        """
        self.processes = []
        # filename_from_id() also registers the log file for cleanup.
        self.log_fname = self.filename_from_id('log')
        # Our command line utilities generate UTF-8. The log file encoding
        # needs to match that.
        # XXX newline='\n' works for now, but we might have to do more work
        # to handle line endings in the future.
        self.log_fd = io.open(self.log_fname, 'w', encoding='UTF-8', newline='\n')

    def _last_test_failed(self):
        """Check whether this test (or one of its subtests) errored or failed."""
        # The test outcome is not available via the public unittest API, so
        # check a private property, "_outcome", set by unittest.TestCase.run.
        # It remains None when running in debug mode (`pytest --pdb`).
        if not self._outcome:
            return False
        if hasattr(self._outcome, 'errors'):
            # Python 3.4 - 3.10: errors are collected on the outcome; feed
            # them into a fresh result object so they can be inspected.
            result = self.defaultTestResult()
            self._feedErrorsToResult(result, self._outcome.errors)
        else:
            # Python 3.11+: errors go straight to the result passed to run(),
            # which is shared by the whole run and may hold earlier tests.
            result = self._outcome.result
        problems = list(getattr(result, 'errors', ())) + list(getattr(result, 'failures', ()))
        for test_case, exc_info in problems:
            # Subtest entries carry their owning test case in .test_case.
            owner = getattr(test_case, 'test_case', test_case)
            if owner is self and exc_info:
                return True
        # No errors occurred for this test.
        return False

    def tearDown(self):
        """
        Tears down a single test. Kills stray processes and closes the log file.
        On errors, display the log contents. On success, remove temporary files.
        """
        self.kill_processes()
        self.log_fd.close()
        if self._last_test_failed():
            self.dump_files.append(self.log_fname)
            # Leave some evidence behind.
            self.cleanup_files = []
            print('\nProcess output for {}:'.format(self.id()))
            with io.open(self.log_fname, 'r', encoding='UTF-8', errors='backslashreplace') as log_fd:
                for line in log_fd:
                    sys.stdout.write(line)
        for filename in self.cleanup_files:
            try:
                os.unlink(filename)
            except OSError:
                pass
        self.cleanup_files = []

    def getCaptureInfo(self, capinfos_args=None, cap_file=None):
        '''Run capinfos on a capture file and log its output.

        capinfos_args must be a sequence.
        Default cap_file is <test id>.testout.pcap.'''
        # XXX convert users to use a new fixture instead of this function.
        cmd_capinfos = self._fixture_request.getfixturevalue('cmd_capinfos')
        if not cap_file:
            cap_file = self.filename_from_id('testout.pcap')
        self.log_fd.write('\nOutput of {0} {1}:\n'.format(cmd_capinfos, cap_file))
        capinfos_cmd = [cmd_capinfos]
        if capinfos_args is not None:
            capinfos_cmd += capinfos_args
        capinfos_cmd.append(cap_file)
        capinfos_data = subprocess.check_output(capinfos_cmd)
        capinfos_stdout = capinfos_data.decode('UTF-8', 'replace')
        self.log_fd.write(capinfos_stdout)
        return capinfos_stdout

    def checkPacketCount(self, num_packets, cap_file=None):
        '''Make sure a capture file contains a specific number of packets.'''
        capinfos_testout = self.getCaptureInfo(cap_file=cap_file)
        # (?!\d) stops e.g. 1 from matching a reported count of 10 or 100.
        count_pat = r'Number of packets:\s+{}(?!\d)'.format(re.escape(str(num_packets)))
        got_num_packets = re.search(count_pat, capinfos_testout) is not None
        self.assertTrue(got_num_packets, 'Failed to capture exactly {} packets'.format(num_packets))

    def countOutput(self, search_pat=None, count_stdout=True, count_stderr=False, proc=None):
        '''Returns the number of output lines (search_pat=None), otherwise returns a match count.

        proc defaults to the most recently started process.'''
        self.assertTrue(count_stdout or count_stderr, 'No output to count.')

        if proc is None:
            proc = self.processes[-1]

        # Split each stream on its own so that stdout without a trailing
        # newline is not merged with the first line of stderr.
        lines = []
        if count_stdout:
            lines += proc.stdout_str.splitlines()
        if count_stderr:
            lines += proc.stderr_str.splitlines()

        if search_pat is None:
            return len(lines)

        search_re = re.compile(search_pat)
        return sum(1 for line in lines if search_re.search(line))

    def grepOutput(self, search_pat, proc=None):
        '''Return True if search_pat matches any stdout or stderr line.'''
        return self.countOutput(search_pat, count_stderr=True, proc=proc) > 0

    def diffOutput(self, blob_a, blob_b, *args, **kwargs):
        '''Check for differences between blob_a and blob_b. Return False and log a unified diff if they differ.

        blob_a and blob_b must be UTF-8 strings. Extra arguments are passed
        to difflib.unified_diff.'''
        lines_a = blob_a.splitlines()
        lines_b = blob_b.splitlines()
        diff_lines = difflib.unified_diff(lines_a, lines_b, *args, **kwargs)
        # Header lines end with unified_diff's lineterm ('\n' by default),
        # content lines from splitlines() have none: terminate each exactly once.
        diff = ''.join(line if line.endswith('\n') else line + '\n' for line in diff_lines)
        if not diff:
            return True
        self.log_fd.flush()
        self.log_fd.write('-- Begin diff output --\n')
        self.log_fd.write(diff)
        self.log_fd.write('-- End diff output --\n')
        return False

    def startProcess(self, proc_args, stdin=None, env=None, shell=False, cwd=None, max_lines=None):
        '''Start a process in the background. Returns a subprocess.Popen object.

        You typically wait for it using waitProcess() or assertWaitProcess().'''
        if env is None:
            # Apply default test environment if no override is provided.
            env = getattr(self, 'injected_test_env', None)
            # Not all tests need test_env, but those that use runProcess or
            # startProcess must either pass an explicit environment or load the
            # fixture (via a test method parameter or class decorator).
            assert not (env is None and hasattr(self, '_fixture_request')), \
                "Decorate class with @fixtures.mark_usefixtures('test_env')"
        proc = LoggingPopen(proc_args, stdin=stdin, env=env, shell=shell, log_fd=self.log_fd, cwd=cwd, max_lines=max_lines)
        self.processes.append(proc)
        return proc

    def waitProcess(self, process):
        '''Wait for a process to finish.'''
        process.wait_and_log()
        # XXX The shell version ran processes using a script called run_and_catch_crashes
        # which looked for core dumps and printed stack traces if found. We might want
        # to do something similar here. This may not be easy on modern Ubuntu systems,
        # which default to using Apport: https://wiki.ubuntu.com/Apport

    def assertWaitProcess(self, process, expected_return=0):
        '''Wait for a process to finish and check its exit code.'''
        process.wait_and_log()
        self.assertEqual(process.returncode, expected_return)

    def runProcess(self, args, env=None, shell=False, cwd=None, max_lines=None):
        '''Start a process and wait for it to finish.'''
        process = self.startProcess(args, env=env, shell=shell, cwd=cwd, max_lines=max_lines)
        process.wait_and_log()
        return process

    def assertRun(self, args, env=None, shell=False, expected_return=0, cwd=None, max_lines=None):
        '''Start a process and wait for it to finish. Check its return code.'''
        process = self.runProcess(args, env=env, shell=shell, cwd=cwd, max_lines=max_lines)
        self.assertEqual(process.returncode, expected_return)
        return process