"""
    Copyright (c) 2015 Ad Schellevis
    All rights reserved.

    Redistribution and use in source and binary forms, with or without
    modification, are permitted provided that the following conditions are met:

    1. Redistributions of source code must retain the above copyright notice,
     this list of conditions and the following disclaimer.

    2. Redistributions in binary form must reproduce the above copyright
     notice, this list of conditions and the following disclaimer in the
     documentation and/or other materials provided with the distribution.

    THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES,
    INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY
    AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
    AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY,
    OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
    SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
    INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
    CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
    ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
    POSSIBILITY OF SUCH DAMAGE.

    --------------------------------------------------------------------------------------

    shared module for suricata scripts, handles the installed rules cache for easy access
"""

import os
import os.path
import glob
import sqlite3
import shlex
import fcntl
from ConfigParser import ConfigParser
from lib import rule_source_directory


class RuleCache(object):
    """ Cache of installed suricata rules.

    Parses the installed *.rules files into a sqlite3 database so rule
    metadata can be searched and paged cheaply, and tracks local per-rule
    overrides read from the OPNsense rules.config file.
    """

    def __init__(self):
        # suricata rule settings, source directory and sqlite cache file to use
        self.cachefile = '%srules.sqlite' % rule_source_directory
        # metadata fields extracted from each rule and stored in the cache
        self._rule_fields = ['sid', 'msg', 'classtype', 'rev', 'gid', 'source', 'enabled', 'reference', 'action']
        # default values for fields which may be absent from a rule
        self._rule_defaults = {'classtype': '##none##'}

    @staticmethod
    def list_local():
        """ list all installed rule files
        :return: list of full paths to *.rules files in the rule source directory
        """
        all_rule_files = []
        for filename in glob.glob('%s*.rules' % rule_source_directory):
            all_rule_files.append(filename)

        return all_rule_files

    @staticmethod
    def list_local_changes():
        """ parse the OPNsense rule config (rules.config) for local rule overrides
        :return: dict keyed by sid; every value contains the override items plus the config file mtime
        """
        rule_config_fn = ('%s../rules.config' % rule_source_directory)
        rule_updates = {}
        if os.path.exists(rule_config_fn):
            # stat only after the existence check; os.stat raises on a missing file
            rule_config_mtime = os.stat(rule_config_fn).st_mtime
            cnf = ConfigParser()
            cnf.read(rule_config_fn)
            for section in cnf.sections():
                # override sections are named rule_<sid>
                if section[0:5] == 'rule_':
                    sid = section[5:]
                    rule_updates[sid] = {'mtime': rule_config_mtime}
                    for rule_item in cnf.items(section):
                        rule_updates[sid][rule_item[0]] = rule_item[1]
        return rule_updates

    def list_rules(self, filename):
        """ generator function to list rule file content including metadata
        :param filename: full path of the rule file to parse
        :return: yields dicts {'rule': <raw line>, 'metadata': <parsed fields or None>}
        """
        # read the whole file up front so the handle is closed before yielding
        with open(filename) as rule_fd:
            file_data = rule_fd.read()
        for rule in file_data.split('\n'):
            rule_info_record = {'rule': rule, 'metadata': None}
            # only lines carrying a msg: option are considered actual rules
            if rule.find('msg:') != -1:
                # define basic record
                record = {'enabled': True, 'source': filename.split('/')[-1]}
                if rule.strip()[0] == '#':
                    # commented-out rule; strip leading #'s to recover the action
                    record['enabled'] = False
                    record['action'] = rule.strip()[1:].split(' ')[0].replace('#', '')
                else:
                    record['action'] = rule.strip().split(' ')[0]

                # rule options start at msg: ; the trailing ) is dropped
                rule_metadata = rule[rule.find('msg:'):-1]
                for field in rule_metadata.split(';'):
                    fieldname = field[0:field.find(':')].strip()
                    fieldcontent = field[field.find(':') + 1:].strip()
                    if fieldname in self._rule_fields:
                        # slice-compare so an empty option value cannot raise IndexError
                        if fieldcontent[0:1] == '"':
                            # quoted value, strip surrounding quotes
                            content = fieldcontent[1:-1]
                        else:
                            content = fieldcontent

                        if fieldname in record:
                            # if same field repeats, put items in list
                            if type(record[fieldname]) != list:
                                record[fieldname] = [record[fieldname]]
                            record[fieldname].append(content)
                        else:
                            record[fieldname] = content

                # add missing fields, applying defaults where defined
                for rule_field in self._rule_fields:
                    if rule_field not in record:
                        if rule_field in self._rule_defaults:
                            record[rule_field] = self._rule_defaults[rule_field]
                        else:
                            record[rule_field] = None

                # flatten repeated fields into newline separated text
                for fieldname in record:
                    if type(record[fieldname]) == list:
                        record[fieldname] = '\n'.join(record[fieldname])

                rule_info_record['metadata'] = record

            yield rule_info_record

    def is_changed(self):
        """ check if rules on disk are probably different from rules in cache
        :return: boolean, True when the cache needs to be (re)built
        """
        if os.path.exists(self.cachefile):
            last_mtime = 0
            all_rule_files = self.list_local()
            for filename in all_rule_files:
                file_mtime = os.stat(filename).st_mtime
                if file_mtime > last_mtime:
                    last_mtime = file_mtime

            try:
                db = sqlite3.connect(self.cachefile)
                cur = db.cursor()
                # cache is current when all 3 tables exist and the recorded
                # newest mtime / file count match what is on disk right now
                cur.execute("select count(*) from sqlite_master WHERE type='table'")
                table_count = cur.fetchall()[0][0]
                cur.execute('SELECT max(timestamp), max(files) FROM stats')
                results = cur.fetchall()
                if last_mtime == results[0][0] and len(all_rule_files) == results[0][1] and table_count == 3:
                    return False
            except sqlite3.DatabaseError:
                # if for some reason the cache is unreadable, continue and report changed
                pass
        return True

    def create(self):
        """ create a new rule metadata cache (sqlite database)
        :return: None
        """
        # lock create process, only one process may rebuild the cache
        lock = open(self.cachefile + '.LCK', 'w')
        try:
            fcntl.flock(lock, fcntl.LOCK_EX | fcntl.LOCK_NB)
        except IOError:
            # other process is already creating the cache, wait, let the other process do its work and return.
            fcntl.flock(lock, fcntl.LOCK_EX)
            fcntl.flock(lock, fcntl.LOCK_UN)
            return

        # remove existing DB
        if os.path.exists(self.cachefile):
            os.remove(self.cachefile)

        db = sqlite3.connect(self.cachefile)
        cur = db.cursor()

        cur.execute("create table stats (timestamp number, files number)")
        cur.execute("""create table rules (sid number, msg TEXT, classtype TEXT,
                                           rev INTEGER, gid INTEGER, reference TEXT,
                                           enabled BOOLEAN, action text, source TEXT)""")
        cur.execute("create table local_rule_changes(sid number primary key, action text, last_mtime number)")
        last_mtime = 0
        all_rule_files = self.list_local()
        for filename in all_rule_files:
            file_mtime = os.stat(filename).st_mtime
            if file_mtime > last_mtime:
                last_mtime = file_mtime
            rules = []
            for rule_info_record in self.list_rules(filename=filename):
                if rule_info_record['metadata'] is not None:
                    rules.append(rule_info_record['metadata'])

            cur.executemany('insert into rules(%(fieldnames)s) '
                            'values (%(fieldvalues)s)' % {'fieldnames': (','.join(self._rule_fields)),
                                                          'fieldvalues': ':' + (',:'.join(self._rule_fields))}, rules)
        # record newest mtime and file count so is_changed() can detect staleness
        cur.execute('INSERT INTO stats (timestamp,files) VALUES (?,?) ', (last_mtime, len(all_rule_files)))
        db.commit()
        # release lock
        fcntl.flock(lock, fcntl.LOCK_UN)

    def update_local_changes(self):
        """ read local rules.config containing changes on installed ruleset and update to "local_rule_changes" table
        :return: None
        """
        if os.path.exists(self.cachefile):
            db = sqlite3.connect(self.cachefile)
            cur = db.cursor()
            cur.execute('select max(last_mtime) from local_rule_changes')
            last_mtime = cur.fetchall()[0][0]
            rule_config_fn = '%s../rules.config' % rule_source_directory
            # guard against a missing rules.config; os.stat would raise on it
            if os.path.exists(rule_config_fn) and os.stat(rule_config_fn).st_mtime != last_mtime:
                # make sure only one process is updating this table
                lock = open(self.cachefile + '.LCK', 'w')
                try:
                    fcntl.flock(lock, fcntl.LOCK_EX | fcntl.LOCK_NB)
                except IOError:
                    # other process is already updating the cache, wait, let the other process do its work and return.
                    fcntl.flock(lock, fcntl.LOCK_EX)
                    fcntl.flock(lock, fcntl.LOCK_UN)
                    return
                # delete and insert local changes
                cur.execute('delete from local_rule_changes')
                local_changes = self.list_local_changes()
                for sid in local_changes:
                    sql_params = (sid, local_changes[sid]['action'], local_changes[sid]['mtime'])
                    cur.execute('insert into local_rule_changes(sid, action, last_mtime) values (?,?,?)', sql_params)
                db.commit()
                # release lock
                fcntl.flock(lock, fcntl.LOCK_UN)

    def search(self, limit, offset, filter_txt, sort_by):
        """ search installed rules
        :param limit: limit number of rows
        :param offset: limit offset
        :param filter_txt: text to search, used format fieldname1,fieldname2/searchphrase include % to match on a part
        :param sort_by: order by, list of fields and possible asc/desc parameter
        :return: dict with 'rows' (matching records) and 'total_rows' (count before limit/offset)
        """
        result = {'rows': []}
        if os.path.exists(self.cachefile):
            db = sqlite3.connect(self.cachefile)
            cur = db.cursor()

            # construct query including filters; installed_action reflects the
            # local override when one exists, otherwise the rule's own action
            sql = """select *
                     from (
                         select rules.*, case when rc.action is null then rules.action else rc.action end installed_action
                         from rules
                         left join local_rule_changes rc on rules.sid = rc.sid
                     ) a
                     """
            sql_filters = {}
            additional_search_fields = ['installed_action']
            where_open = False
            for filtertag in shlex.split(filter_txt):
                fieldnames = filtertag.split('/')[0]
                searchcontent = '/'.join(filtertag.split('/')[1:])
                # emit 'where' exactly once, even when the first filter group
                # contains no valid fieldnames (sql_filters would stay empty)
                if where_open:
                    sql += ' and ( '
                else:
                    sql += ' where ( '
                    where_open = True
                # collect per-field clauses first so they always join with 'or',
                # also when valid and invalid fieldnames are mixed in one tag
                clauses = []
                for fieldname in map(lambda x: x.lower().strip(), fieldnames.split(',')):
                    if fieldname in self._rule_fields or fieldname in additional_search_fields:
                        if searchcontent.find('*') == -1:
                            clauses.append('cast(' + fieldname + " as text) like :" + fieldname + " ")
                        else:
                            clauses.append('cast(' + fieldname + " as text) like '%'|| :" + fieldname + " || '%' ")
                        sql_filters[fieldname] = searchcontent.replace('*', '')
                    else:
                        # not a valid fieldname, add a tag to make sure our sql statement is valid
                        clauses.append(' 1 = 1 ')
                sql += ' or '.join(clauses)
                sql += ' ) '

            # apply sort order (if any), only on known field names
            sql_sort = []
            for sortField in sort_by.split(','):
                if sortField.split(' ')[0] in self._rule_fields or sortField.split(' ')[0] in additional_search_fields:
                    if sortField.split(' ')[-1].lower() == 'desc':
                        sql_sort.append('%s desc' % sortField.split()[0])
                    else:
                        sql_sort.append('%s asc' % sortField.split()[0])

            # count total number of rows (before limit/offset is applied)
            cur.execute('select count(*) from (%s) a' % sql, sql_filters)
            result['total_rows'] = cur.fetchall()[0][0]

            if len(sql_sort) > 0:
                sql += ' order by %s' % (','.join(sql_sort))

            if str(limit) != '0' and str(limit).isdigit():
                sql += ' limit %s' % limit
                if str(offset) != '0' and str(offset).isdigit():
                    sql += ' offset %s' % offset

            # fetch results
            cur.execute(sql, sql_filters)
            while True:
                row = cur.fetchone()
                if row is None:
                    break

                record = {}
                for fieldNum in range(len(cur.description)):
                    record[cur.description[fieldNum][0]] = row[fieldNum]
                result['rows'].append(record)

        return result

    def list_class_types(self):
        """ list distinct classtypes of the installed rules
        :return: sorted list of classtype values, empty when there's no cache
        """
        result = []
        if os.path.exists(self.cachefile):
            db = sqlite3.connect(self.cachefile)
            cur = db.cursor()
            cur.execute('SELECT DISTINCT classtype FROM rules')
            for record in cur.fetchall():
                result.append(record[0])

            return sorted(result)
        else:
            return result