"""
    Copyright (c) 2015 Ad Schellevis
    All rights reserved.

    Redistribution and use in source and binary forms, with or without
    modification, are permitted provided that the following conditions are met:

    1. Redistributions of source code must retain the above copyright notice,
     this list of conditions and the following disclaimer.

    2. Redistributions in binary form must reproduce the above copyright
     notice, this list of conditions and the following disclaimer in the
     documentation and/or other materials provided with the distribution.

    THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES,
    INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY
    AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
    AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY,
    OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
    SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
    INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
    CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
    ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
    POSSIBILITY OF SUCH DAMAGE.

    --------------------------------------------------------------------------------------

    shared module for suricata scripts, handles the installed rules cache for easy access
"""

import os
import os.path
import glob
import sqlite3
import shlex
import fcntl
from lib import rule_source_directory


class RuleCache(object):
    """
    """

    def __init__(self):
        # suricata rule settings, source directory and sqlite cache file to use
        self.cachefile = '%srules.sqlite' % rule_source_directory
        self._rule_fields = ['sid', 'msg', 'classtype', 'rev', 'gid', 'source', 'enabled', 'reference', 'action']
        self._rule_defaults = {'classtype': '##none##'}

    @staticmethod
    def list_local():
        all_rule_files = []
        for filename in glob.glob('%s*.rules' % rule_source_directory):
            all_rule_files.append(filename)

        return all_rule_files

    def list_rules(self, filename):
        """ generator function to list rule file content including metadata
        :param filename:
        :return:
        """
        # read the whole file up front so the handle is closed before records are yielded
        with open(filename) as rule_file:
            rule_lines = rule_file.read().split('\n')
        for rule in rule_lines:
            rule_info_record = {'rule': rule, 'metadata': None}
            if rule.find('msg:') != -1:
                # define basic record
                record = {'enabled': True, 'source': filename.split('/')[-1]}
                if rule.strip()[0] == '#':
                    record['enabled'] = False
                    record['action'] = rule.strip()[1:].split(' ')[0].replace('#', '')
                else:
                    record['action'] = rule.strip().split(' ')[0]

                rule_metadata = rule[rule.find('msg:'):-1]
                for field in rule_metadata.split(';'):
                    fieldname = field[0:field.find(':')].strip()
                    fieldcontent = field[field.find(':') + 1:].strip()
                    if fieldname in self._rule_fields:
                        if fieldcontent[0] == '"':
                            content = fieldcontent[1:-1]
                        else:
                            content = fieldcontent

                        if fieldname in record:
                            # if same field repeats, put items in list
                            if type(record[fieldname]) != list:
                                record[fieldname] = [record[fieldname]]
                            record[fieldname].append(content)
                        else:
                            record[fieldname] = content

                for rule_field in self._rule_fields:
                    if rule_field not in record:
                        if rule_field in self._rule_defaults:
                            record[rule_field] = self._rule_defaults[rule_field]
                        else:
                            record[rule_field] = None

                # perform type conversions
                for fieldname in record:
                    if type(record[fieldname]) == list:
                        record[fieldname] = '\n'.join(record[fieldname])

                rule_info_record['metadata'] = record

            yield rule_info_record

    def is_changed(self):
        """ check if rules on disk are probably different from rules in cache
        :return: boolean
        """
        if os.path.exists(self.cachefile):
            last_mtime = 0
            all_rule_files = self.list_local()
            for filename in all_rule_files:
                file_mtime = os.stat(filename).st_mtime
                if file_mtime > last_mtime:
                    last_mtime = file_mtime

            try:
                db = sqlite3.connect(self.cachefile)
                cur = db.cursor()
                cur.execute('SELECT max(timestamp), max(files) FROM stats')
                results = cur.fetchall()
                if last_mtime == results[0][0] and len(all_rule_files) == results[0][1]:
                    return False
            except sqlite3.DatabaseError:
                # if for some reason the cache is unreadable, continue and report changed
                pass
        return True

    def create(self):
        """ create new cache
        :return: None
        """
        # lock create process
        lock = open(self.cachefile + '.LCK', 'w')
        try:
            fcntl.flock(lock, fcntl.LOCK_EX | fcntl.LOCK_NB)
        except IOError:
            # another process is already creating the cache; wait for it to do its work, then return.
            fcntl.flock(lock, fcntl.LOCK_EX)
            fcntl.flock(lock, fcntl.LOCK_UN)
            return

        # remove existing DB
        if os.path.exists(self.cachefile):
            os.remove(self.cachefile)

        db = sqlite3.connect(self.cachefile)
        cur = db.cursor()

        cur.execute("CREATE TABLE stats (timestamp number, files number)")
        cur.execute("""CREATE TABLE rules (sid number, msg TEXT, classtype TEXT,
                                           rev INTEGER, gid INTEGER, reference TEXT,
                                           enabled BOOLEAN, action text, source TEXT)""")

        last_mtime = 0
        all_rule_files = self.list_local()
        for filename in all_rule_files:
            file_mtime = os.stat(filename).st_mtime
            if file_mtime > last_mtime:
                last_mtime = file_mtime
            rules = []
            for rule_info_record in self.list_rules(filename=filename):
                if rule_info_record['metadata'] is not None:
                    rules.append(rule_info_record['metadata'])

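            # bulk insert the parsed rules for this file; the template below expands to
            # "insert into rules(sid,msg,...) values (:sid,:msg,...)" using named parameters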
            cur.executemany('insert into rules(%(fieldnames)s) '
                            'values (%(fieldvalues)s)' % {'fieldnames': (','.join(self._rule_fields)),
                                                          'fieldvalues': ':' + (',:'.join(self._rule_fields))}, rules)
        cur.execute('INSERT INTO stats (timestamp,files) VALUES (?,?) ', (last_mtime, len(all_rule_files)))
        db.commit()
        # release lock
        fcntl.flock(lock, fcntl.LOCK_UN)

    def search(self, limit, offset, filter_txt, sort_by):
        """ search installed rules
        :param limit: limit number of rows
        :param offset: limit offset
        :param filter_txt: text to search, in the format fieldname1,fieldname2/searchphrase; include * to match on part of a field
        :param sort_by: order by, list of fields and possible asc/desc parameter
        :return: dict
        """
        result = {'rows': []}
        if os.path.exists(self.cachefile):
            db = sqlite3.connect(self.cachefile)
            cur = db.cursor()

            # construct query including filters
            sql = 'select * from rules '
            sql_filters = {}

            for filtertag in shlex.split(filter_txt):
                fieldnames = filtertag.split('/')[0]
                searchcontent = '/'.join(filtertag.split('/')[1:])
                if len(sql_filters) > 0:
                    sql += ' and ( '
                else:
                    sql += ' where ( '
                for fieldname in map(lambda x: x.lower().strip(), fieldnames.split(',')):
                    if fieldname in self._rule_fields:
                        if fieldname != fieldnames.split(',')[0].strip():
                            sql += ' or '
                        if searchcontent.find('*') == -1:
                            sql += 'cast(' + fieldname + " as text) like :" + fieldname + " "
                        else:
                            sql += 'cast(' + fieldname + " as text) like '%'|| :" + fieldname + " || '%' "
                        sql_filters[fieldname] = searchcontent.replace('*', '')
                    else:
                        # not a valid fieldname, add a harmless clause to keep the sql statement valid
                        sql += ' 1 = 1 '
                sql += ' ) '

            # apply sort order (if any)
            sql_sort = []
            for sortField in sort_by.split(','):
                if sortField.split(' ')[0] in self._rule_fields:
                    if sortField.split(' ')[-1].lower() == 'desc':
                        sql_sort.append('%s desc' % sortField.split()[0])
                    else:
                        sql_sort.append('%s asc' % sortField.split()[0])

            # count total number of rows
            cur.execute('select count(*) from (%s) a' % sql, sql_filters)
            result['total_rows'] = cur.fetchall()[0][0]

            if len(sql_sort) > 0:
                sql += ' order by %s' % (','.join(sql_sort))

            if str(limit) != '0' and str(limit).isdigit():
                sql += ' limit %s' % limit
                if str(offset) != '0' and str(offset).isdigit():
                    sql += ' offset %s' % offset

            # fetch results
            cur.execute(sql, sql_filters)
            while True:
                row = cur.fetchone()
                if row is None:
                    break

                record = {}
                for fieldNum in range(len(cur.description)):
                    record[cur.description[fieldNum][0]] = row[fieldNum]
                result['rows'].append(record)

        return result

    def list_class_types(self):
        """
        :return: list of installed classtypes
        """
        result = []
        if os.path.exists(self.cachefile):
            db = sqlite3.connect(self.cachefile)
            cur = db.cursor()
            cur.execute('SELECT DISTINCT classtype FROM rules')
            for record in cur.fetchall():
                result.append(record[0])

            return sorted(result)
        else:
            return result
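

# Minimal usage sketch (illustrative only, not executed by this module); assumes the
# configured rule_source_directory exists and contains *.rules files:
#
#   rc = RuleCache()
#   if rc.is_changed():
#       rc.create()
#   found = rc.search(limit=25, offset=0, filter_txt='msg/scan*', sort_by='sid')
#   print(found['total_rows'], [row['sid'] for row in found['rows']])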