# rulecache.py
"""
    Copyright (c) 2015 Ad Schellevis
    All rights reserved.

    Redistribution and use in source and binary forms, with or without
    modification, are permitted provided that the following conditions are met:

    1. Redistributions of source code must retain the above copyright notice,
     this list of conditions and the following disclaimer.

    2. Redistributions in binary form must reproduce the above copyright
     notice, this list of conditions and the following disclaimer in the
     documentation and/or other materials provided with the distribution.

    THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES,
    INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY
    AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
    AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY,
    OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
    SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
    INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
    CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
    ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
    POSSIBILITY OF SUCH DAMAGE.

    --------------------------------------------------------------------------------------

    Shared module for suricata scripts: handles the installed rules cache for easy access.
"""
import os
import os.path
import glob
import sqlite3
import shlex
from lib import rule_source_directory


class RuleCache(object):
    """ Cache of the installed suricata rules.

    Rule files (``*.rules`` in rule_source_directory) are parsed into an sqlite
    database (``rules.sqlite`` in the same directory), which can then be searched
    without reparsing the rule files.
    """

    def __init__(self):
        # suricata rule settings, source directory and sqlite cache file to use
        self.cachefile = '%srules.sqlite' % rule_source_directory
        # rule options collected per rule, also used as the columns of the rules table
        self._rule_fields = ['sid', 'msg', 'classtype', 'rev', 'gid', 'source', 'enabled', 'reference', 'action']
        # values for fields a rule does not define (fields not listed here default to None)
        self._rule_defaults = {'classtype': '##none##'}

    @staticmethod
    def list_local():
        """ list installed rule files
        :return: list of paths matching <rule_source_directory>*.rules
        """
        return glob.glob('%s*.rules' % rule_source_directory)

    @staticmethod
    def _newest_mtime(filenames):
        """ return the most recent modification time of filenames (0 when the list is empty) """
        last_mtime = 0
        for filename in filenames:
            last_mtime = max(last_mtime, os.stat(filename).st_mtime)
        return last_mtime

    def list_rules(self, filename):
        """ generator function to list rule file content including metadata
        :param filename: path of the rule file to read
        :return: generator of dicts {'rule': <line text>, 'metadata': <dict or None>}, one per
                 line of the file (including a trailing empty line when the file ends with a
                 newline). metadata is only set for lines containing 'msg:'.
        """
        # read the whole file up front, so the handle is closed even when the caller
        # does not exhaust this generator
        with open(filename) as rule_file:
            data = rule_file.read()
        source = os.path.basename(filename)

        for rule in data.split('\n'):
            rule_info_record = {'rule': rule, 'metadata': None}
            if rule.find('msg:') != -1:
                rule_info_record['metadata'] = self._parse_rule(rule, source)
            yield rule_info_record

    def _parse_rule(self, rule, source):
        """ parse a single (possibly commented out) rule line into a metadata record
        :param rule: rule text, expected to contain 'msg:'
        :param source: file name (without directory) the rule was read from
        :return: dict containing at least every field in self._rule_fields
        """
        stripped = rule.strip()
        record = {'source': source}
        if stripped.startswith('#'):
            # disabled rule, the action is the first word after the comment marker(s),
            # with or without whitespace in between ("#alert" / "# alert")
            record['enabled'] = False
            record['action'] = stripped.lstrip('#').split()[0]
        else:
            record['enabled'] = True
            record['action'] = stripped.split()[0]

        # rule options start at 'msg:', drop the closing parenthesis of the rule
        options = rule[rule.find('msg:'):].rstrip()
        if options.endswith(')'):
            options = options[:-1]

        for option in options.split(';'):
            fieldname, separator, fieldcontent = option.partition(':')
            fieldname = fieldname.strip()
            if not separator or fieldname not in self._rule_fields:
                # keyword without a value (e.g. "nocase") or an option we don't cache
                continue
            fieldcontent = fieldcontent.strip()
            if fieldcontent.startswith('"'):
                content = fieldcontent[1:-1]
            else:
                content = fieldcontent

            if fieldname in record:
                # if same field repeats, put items in list
                if not isinstance(record[fieldname], list):
                    record[fieldname] = [record[fieldname]]
                record[fieldname].append(content)
            else:
                record[fieldname] = content

        for rule_field in self._rule_fields:
            if rule_field not in record:
                record[rule_field] = self._rule_defaults.get(rule_field)

        # repeated fields are stored as newline separated text
        for fieldname, value in list(record.items()):
            if isinstance(value, list):
                record[fieldname] = '\n'.join(value)

        return record

    def is_changed(self):
        """ check if rules on disk are probably different from rules in cache
        compares the newest rule file mtime and the number of rule files with the values
        stored in the cache's stats table.
        :return: boolean, True when there is no (readable) cache or it looks outdated
        """
        if not os.path.exists(self.cachefile):
            return True

        all_rule_files = self.list_local()
        last_mtime = self._newest_mtime(all_rule_files)

        try:
            db = sqlite3.connect(self.cachefile)
            try:
                cur = db.cursor()
                cur.execute('SELECT max(timestamp), max(files) FROM stats')
                cached_mtime, cached_files = cur.fetchone()
            finally:
                db.close()
        except sqlite3.DatabaseError:
            # if for some reason the cache is unreadable, report changed
            return True

        return not (last_mtime == cached_mtime and len(all_rule_files) == cached_files)

    def create(self):
        """ create new cache, replacing an existing cache file
        :return: None
        """
        if os.path.exists(self.cachefile):
            os.remove(self.cachefile)

        db = sqlite3.connect(self.cachefile)
        try:
            cur = db.cursor()

            # if another process created the file, exit.
            cur.execute("select count(*) from sqlite_master where name = 'stats'")
            if cur.fetchone()[0] > 0:
                return None

            cur.execute("CREATE TABLE stats (timestamp number, files number)")
            cur.execute("""CREATE TABLE rules (sid number, msg TEXT, classtype TEXT,
                                               rev INTEGER, gid INTEGER, reference TEXT,
                                               enabled BOOLEAN, action text, source TEXT)""")

            all_rule_files = self.list_local()
            # stat before reading, so a file changed while parsing is detected as changed later
            last_mtime = self._newest_mtime(all_rule_files)
            insert_sql = 'insert into rules(%s) values (%s)' % (
                ','.join(self._rule_fields),
                ','.join(':' + fieldname for fieldname in self._rule_fields)
            )
            for filename in all_rule_files:
                rules = [rule_info_record['metadata'] for rule_info_record in self.list_rules(filename=filename)
                         if rule_info_record['metadata'] is not None]
                cur.executemany(insert_sql, rules)

            cur.execute('INSERT INTO stats (timestamp,files) VALUES (?,?) ', (last_mtime, len(all_rule_files)))
            db.commit()
        finally:
            db.close()

    def search(self, limit, offset, filter_txt, sort_by):
        """ search installed rules
        :param limit: limit number of rows (0 or non-numeric means no limit)
        :param offset: row offset, only applied together with a limit
        :param filter_txt: space separated (shell-style quoted) filter tags, each in the format
                           fieldname1,fieldname2/searchphrase. Fields within one tag are OR-ed,
                           tags are AND-ed. A '*' in the phrase makes it a substring match,
                           otherwise the phrase is used as SQL LIKE pattern (so % matches on a
                           part). Unknown field names are ignored.
        :param sort_by: order by, comma separated list of fields with optional asc/desc
        :return: dict with 'rows' (list of dicts) and, when the cache exists, 'total_rows'
        """
        result = {'rows': []}
        if not os.path.exists(self.cachefile):
            return result

        # construct query including filters, every search phrase gets its own named
        # parameter so the same field used in multiple tags doesn't overwrite another one
        sql = 'select * from rules '
        sql_filters = {}
        where_clauses = []
        for filtertag in shlex.split(filter_txt or ''):
            fieldnames, _, searchcontent = filtertag.partition('/')
            conditions = []
            for fieldname in fieldnames.split(','):
                fieldname = fieldname.lower().strip()
                if fieldname not in self._rule_fields:
                    # not a valid fieldname, skip. (whitelisting makes it safe to format into sql)
                    continue
                param = 'p%d' % len(sql_filters)
                if '*' in searchcontent:
                    conditions.append("cast({0} as text) like '%'|| :{1} || '%'".format(fieldname, param))
                else:
                    conditions.append("cast({0} as text) like :{1}".format(fieldname, param))
                sql_filters[param] = searchcontent.replace('*', '')
            if conditions:
                where_clauses.append('( %s )' % ' or '.join(conditions))
        if where_clauses:
            sql += ' where ' + ' and '.join(where_clauses)

        # apply sort order (if any)
        sql_sort = []
        for sort_field in (sort_by or '').split(','):
            sort_parts = sort_field.split()
            if sort_parts and sort_parts[0].lower() in self._rule_fields:
                direction = 'desc' if sort_parts[-1].lower() == 'desc' else 'asc'
                sql_sort.append('%s %s' % (sort_parts[0].lower(), direction))

        db = sqlite3.connect(self.cachefile)
        try:
            cur = db.cursor()

            # count total number of rows (before limit/offset)
            cur.execute('select count(*) from (%s) a' % sql, sql_filters)
            result['total_rows'] = cur.fetchone()[0]

            if sql_sort:
                sql += ' order by %s' % (','.join(sql_sort))

            if str(limit) != '0' and str(limit).isdigit():
                sql += ' limit %s' % limit
                if str(offset) != '0' and str(offset).isdigit():
                    sql += ' offset %s' % offset

            # fetch results
            cur.execute(sql, sql_filters)
            columns = [column[0] for column in cur.description]
            for row in cur:
                result['rows'].append(dict(zip(columns, row)))
        finally:
            db.close()

        return result

    def list_class_types(self):
        """
        :return: sorted list of distinct classtypes in the cache (empty when there is no cache)
        """
        if not os.path.exists(self.cachefile):
            return []

        db = sqlite3.connect(self.cachefile)
        try:
            cur = db.cursor()
            cur.execute('SELECT DISTINCT classtype FROM rules')
            return sorted(record[0] for record in cur.fetchall())
        finally:
            db.close()