"""
    Copyright (c) 2015 Ad Schellevis
    All rights reserved.

    Redistribution and use in source and binary forms, with or without
    modification, are permitted provided that the following conditions are met:

    1. Redistributions of source code must retain the above copyright notice,
     this list of conditions and the following disclaimer.

    2. Redistributions in binary form must reproduce the above copyright
     notice, this list of conditions and the following disclaimer in the
     documentation and/or other materials provided with the distribution.

    THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES,
    INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY
    AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
    AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY,
    OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
    SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
    INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
    CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
    ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
    POSSIBILITY OF SUCH DAMAGE.

    --------------------------------------------------------------------------------------
    shared module for suricata scripts, handles the installed rules cache for easy access
"""

import contextlib
import glob
import os
import os.path
import shlex
import sqlite3

from lib import rule_source_directory


class RuleCache(object):
    """ Maintains a sqlite cache holding the metadata of the installed suricata rules

    Rule files (*.rules) are read from rule_source_directory, the cache database (rules.sqlite)
    is written to that same location.
    """

    def __init__(self):
        # suricata rule settings, source directory and cache (sqlite) file to use
        self.cachefile = '%srules.sqlite' % rule_source_directory
        # metadata fields collected per rule, these are also the columns of the rules table
        self._rule_fields = ['sid', 'msg', 'classtype', 'rev', 'gid', 'source', 'enabled', 'reference', 'action']
        # values to use for fields which are not set by the rule itself
        self._rule_defaults = {'classtype': '##none##'}

    @staticmethod
    def list_local():
        """ list the rule files found in the rule source directory
        :return: list of filenames matching <rule_source_directory>*.rules
        """
        return glob.glob('%s*.rules' % rule_source_directory)

    def list_rules(self, filename):
        """ generator function to list rule file content including metadata
        :param filename: path of the rule file to read
        :return: generator yielding a dict per line, {'rule': <line>, 'metadata': <dict or None>};
                 metadata is None for lines which do not contain a "msg:" option
        """
        # read all content at once, the file is closed again before the first record is handed out
        with open(filename) as data:
            content = data.read()

        source = os.path.basename(filename)
        for rule in content.split('\n'):
            rule_info_record = {'rule': rule, 'metadata': None}
            if 'msg:' in rule:
                rule_info_record['metadata'] = self._parse_rule(rule, source)
            yield rule_info_record

    def _parse_rule(self, rule, source):
        """ extract the metadata of a single rule
        :param rule: line from a rule file, must contain a "msg:" option
        :param source: name of the rule file the line was read from
        :return: dict containing a value for every field in self._rule_fields
        """
        stripped = rule.strip()
        # define basic record, rules which are commented out are considered disabled
        record = {'enabled': not stripped.startswith('#'), 'source': source}
        # the action is the first word of the rule, skip the comment marker(s) of a disabled rule
        # including the whitespace which may follow them ("# alert ..." as well as "#alert ...")
        record['action'] = stripped.lstrip('# \t').split()[0]

        # collect options starting at msg, the last character (closing parenthesis) is dropped
        rule_metadata = stripped[stripped.find('msg:'):-1]
        for field in rule_metadata.split(';'):
            fieldname, separator, fieldcontent = field.partition(':')
            fieldname = fieldname.strip()
            if not separator or fieldname not in self._rule_fields:
                # option without a value (e.g. nocase) or an option we don't collect
                continue

            fieldcontent = fieldcontent.strip()
            if fieldcontent.startswith('"'):
                # remove surrounding quotes
                content = fieldcontent[1:-1]
            else:
                content = fieldcontent

            if fieldname in record:
                # if same field repeats, put items in list
                if not isinstance(record[fieldname], list):
                    record[fieldname] = [record[fieldname]]
                record[fieldname].append(content)
            else:
                record[fieldname] = content

        # make sure all fields exist, use the configured default or None when not set
        for rule_field in self._rule_fields:
            if rule_field not in record:
                record[rule_field] = self._rule_defaults.get(rule_field)

        # perform type conversions, repeated fields are stored as one newline separated string
        for fieldname in record:
            if isinstance(record[fieldname], list):
                record[fieldname] = '\n'.join(record[fieldname])

        return record

    def is_changed(self):
        """ check if rules on disk are probably different from rules in cache
        :return: boolean, False only when both the newest modification time and the number of
                 rule files equal the values stored in the cache
        """
        if not os.path.exists(self.cachefile):
            return True

        all_rule_files = self.list_local()
        last_mtime = 0
        for filename in all_rule_files:
            file_mtime = os.stat(filename).st_mtime
            if file_mtime > last_mtime:
                last_mtime = file_mtime

        try:
            # closing() is needed, a sqlite3 connection used as context manager is not closed on exit
            with contextlib.closing(sqlite3.connect(self.cachefile)) as db:
                cur = db.cursor()
                cur.execute('SELECT max(timestamp), max(files) FROM stats')
                results = cur.fetchall()
        except sqlite3.DatabaseError:
            # if for some reason the cache is unreadable, report changed
            return True

        return not (last_mtime == results[0][0] and len(all_rule_files) == results[0][1])

    def create(self):
        """ create new cache, an existing cache file is removed first
        :return: None
        """
        if os.path.exists(self.cachefile):
            os.remove(self.cachefile)

        # insert using named parameters, one for every rule field
        insert_stmt = ('insert into rules(%(fieldnames)s) '
                       'values (%(fieldvalues)s)' % {'fieldnames': (','.join(self._rule_fields)),
                                                     'fieldvalues': ':' + (',:'.join(self._rule_fields))})

        with contextlib.closing(sqlite3.connect(self.cachefile)) as db:
            cur = db.cursor()
            cur.execute('CREATE TABLE stats (timestamp number, files number)')
            cur.execute("""CREATE TABLE rules (sid number, msg TEXT, classtype TEXT,
                                               rev INTEGER, gid INTEGER, reference TEXT,
                                               enabled BOOLEAN, action text, source TEXT)""")

            last_mtime = 0
            all_rule_files = self.list_local()
            for filename in all_rule_files:
                file_mtime = os.stat(filename).st_mtime
                if file_mtime > last_mtime:
                    last_mtime = file_mtime
                rules = [rule_info_record['metadata']
                         for rule_info_record in self.list_rules(filename=filename)
                         if rule_info_record['metadata'] is not None]
                cur.executemany(insert_stmt, rules)

            # register what the cache is based on, is_changed() compares against these values
            cur.execute('INSERT INTO stats (timestamp,files) VALUES (?,?) ', (last_mtime, len(all_rule_files)))
            db.commit()

    def search(self, limit, offset, filter_txt, sort_by):
        """ search installed rules
        :param limit: limit number of rows, 0 (or a non numeric value) returns all rows
        :param offset: limit offset, only used in combination with limit
        :param filter_txt: text to search, used format fieldname1,fieldname2/searchphrase include % to match on a part
                           (a phrase containing * is matched as substring, the * characters are removed).
                           Filters are separated by whitespace and all of them have to match,
                           unknown fieldnames are ignored.
        :param sort_by: order by, list of fields and possible asc/desc parameter
        :return: dict, 'rows' contains the selected rules and 'total_rows' the number of rules matching
                 the filter without limit; when there is no cache only an empty 'rows' is returned
        """
        result = {'rows': []}
        if not os.path.exists(self.cachefile):
            return result

        # construct query including filters
        sql = 'select * from rules '
        sql_filters = {}
        where_clauses = []
        for filtertag in shlex.split(filter_txt):
            fieldnames, _, searchcontent = filtertag.partition('/')
            if '*' in searchcontent:
                template = "cast(%s as text) like '%%' || :%s || '%%'"
            else:
                template = 'cast(%s as text) like :%s'

            conditions = []
            for fieldname in (item.lower().strip() for item in fieldnames.split(',')):
                # only known fieldnames end up in the statement, values are always bound
                if fieldname in self._rule_fields:
                    # unique parameter per condition, so the same field can be used in several filters
                    param = 'filter_%d' % len(sql_filters)
                    conditions.append(template % (fieldname, param))
                    sql_filters[param] = searchcontent.replace('*', '')
            if conditions:
                where_clauses.append('( %s )' % ' or '.join(conditions))

        if where_clauses:
            sql += ' where %s ' % ' and '.join(where_clauses)

        # apply sort order (if any)
        sql_sort = []
        for sort_field in sort_by.split(','):
            parts = sort_field.split()
            if parts and parts[0] in self._rule_fields:
                if parts[-1].lower() == 'desc':
                    sql_sort.append('%s desc' % parts[0])
                else:
                    sql_sort.append('%s asc' % parts[0])

        with contextlib.closing(sqlite3.connect(self.cachefile)) as db:
            cur = db.cursor()

            # count total number of rows
            cur.execute('select count(*) from (%s) a' % sql, sql_filters)
            result['total_rows'] = cur.fetchall()[0][0]

            if sql_sort:
                sql += ' order by %s' % (','.join(sql_sort))

            # limit and offset are only added to the statement when they consist of digits
            if str(limit) != '0' and str(limit).isdigit():
                sql += ' limit %s' % limit
                if str(offset) != '0' and str(offset).isdigit():
                    sql += ' offset %s' % offset

            # fetch results
            cur.execute(sql, sql_filters)
            columns = [column[0] for column in cur.description]
            for row in cur:
                result['rows'].append(dict(zip(columns, row)))

        return result

    def list_class_types(self):
        """ list the classtypes found in the cache
        :return: sorted list of installed classtypes, empty when there is no cache
        """
        if not os.path.exists(self.cachefile):
            return []

        with contextlib.closing(sqlite3.connect(self.cachefile)) as db:
            cur = db.cursor()
            cur.execute('SELECT DISTINCT classtype FROM rules')
            return sorted(record[0] for record in cur.fetchall())