aboutsummaryrefslogtreecommitdiffstats
path: root/test/test.py
blob: 78421bd911f647788508eb9fb4bffa38c0a86983 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Wireshark tests
# By Gerald Combs <gerald@wireshark.org>
#
# Ported from a set of Bash scripts which were copyright 2005 Ulf Lamping
#
# SPDX-License-Identifier: GPL-2.0-or-later
#
'''Main test script'''

# To do:
# - Avoid printing Python tracebacks when we assert? It looks like we'd need
#   to override unittest.TextTestResult.addFailure().
# - Switch to Python 3 only? [Windows, Linux, macOS] x [Python 2, Python 3]
#   is painful.
# - Remove BIN_PATH/hosts via config.tearDownHostFiles + case_name_resolution.tearDownClass?


import argparse
import config
import os.path
import sys
import unittest

def find_test_ids(suite, all_ids):
    if hasattr(suite, '__iter__'):
        for s in suite:
            find_test_ids(s, all_ids)
    else:
        all_ids.append(suite.id())

def dump_failed_output(suite):
    if hasattr(suite, '__iter__'):
        for s in suite:
            dump_failures = getattr(s, 'dump_failures', None)
            if dump_failures:
                dump_failures()
            else:
                dump_failed_output(s)

def main():
    parser = argparse.ArgumentParser(description='Wireshark unit tests')
    cap_group = parser.add_mutually_exclusive_group()
    cap_group.add_argument('-e', '--enable-capture', action='store_true', help='Enable capture tests')
    cap_group.add_argument('-E', '--disable-capture', action='store_true', help='Disable capture tests')
    cap_group.add_argument('-i', '--capture-interface', nargs=1, default=None, help='Capture interface index or name')
    parser.add_argument('-p', '--program-path', nargs=1, default=os.path.curdir, help='Path to Wireshark executables.')
    list_group = parser.add_mutually_exclusive_group()
    list_group.add_argument('-l', '--list', action='store_true', help='List tests. One of "all" or a full or partial test name.')
    list_group.add_argument('--list-suites', action='store_true', help='List all suites.')
    list_group.add_argument('--list-cases', action='store_true', help='List all suites and cases.')
    parser.add_argument('-v', '--verbose', action='store_const', const=2, default=1, help='Verbose tests.')
    parser.add_argument('tests_to_run', nargs='*', metavar='test', default=['all'], help='Tests to run. One of "all" or a full or partial test name. Default is "all".')
    args = parser.parse_args()

    if args.enable_capture:
        config.setCanCapture(True)
    elif args.disable_capture:
        config.setCanCapture(False)

    if args.capture_interface:
        config.setCaptureInterface(args.capture_interface[0])

    all_tests = unittest.defaultTestLoader.discover(os.path.dirname(__file__), pattern='suite_*')

    all_ids = []
    find_test_ids(all_tests, all_ids)

    run_ids = []
    for tid in all_ids:
        for ttr in args.tests_to_run:
            ttrl = ttr.lower()
            if ttrl == 'all':
                run_ids = all_ids
                break
            if ttrl in tid.lower():
                run_ids.append(tid)

    if not run_ids:
        print('No tests found. You asked for:\n  ' + '\n  '.join(args.tests_to_run))
        parser.print_usage()
        sys.exit(1)

    if args.list:
        print('\n'.join(run_ids))
        sys.exit(0)

    if args.list_suites:
        suites = set()
        for rid in run_ids:
            rparts = rid.split('.')
            suites |= {rparts[0]}
        print('\n'.join(list(suites)))
        sys.exit(0)

    if args.list_cases:
        cases = set()
        for rid in run_ids:
            rparts = rid.split('.')
            cases |= {'.'.join(rparts[:2])}
        print('\n'.join(list(cases)))
        sys.exit(0)

    program_path = args.program_path[0]
    if not config.setProgramPath(program_path):
        print('One or more required executables not found at {}\n'.format(program_path))
        parser.print_usage()
        sys.exit(1)

    run_suite = unittest.defaultTestLoader.loadTestsFromNames(run_ids)
    runner = unittest.TextTestRunner(verbosity=args.verbose)
    test_result = runner.run(run_suite)

    dump_failed_output(run_suite)

    if test_result.errors:
        sys.exit(2)

    if test_result.failures:
        sys.exit(1)

if __name__ == '__main__':
    main()