rulecache.py 10 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26
"""
    Copyright (c) 2015 Ad Schellevis
    All rights reserved.

    Redistribution and use in source and binary forms, with or without
    modification, are permitted provided that the following conditions are met:

    1. Redistributions of source code must retain the above copyright notice,
     this list of conditions and the following disclaimer.

    2. Redistributions in binary form must reproduce the above copyright
     notice, this list of conditions and the following disclaimer in the
     documentation and/or other materials provided with the distribution.

    THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES,
    INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY
    AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
    AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY,
    OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
    SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
    INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
    CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
    ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
    POSSIBILITY OF SUCH DAMAGE.

    --------------------------------------------------------------------------------------

    shared module for suricata scripts, handles the installed rules cache for easy access
"""
30

31 32 33 34
import os
import os.path
import glob
import sqlite3
35
import shlex
36 37
from lib import rule_source_directory

38 39 40 41

class RuleCache(object):
    """ Cache of the installed suricata rules.

    Rule files matching ``<rule_source_directory>*.rules`` are parsed and stored in a sqlite
    database (``<rule_source_directory>rules.sqlite``), which can then be searched or queried
    for the installed classtypes.
    """

    def __init__(self):
        # suricata rule settings, source directory and cache (sqlite) file to use
        self.cachefile = '%srules.sqlite' % rule_source_directory
        # fields extracted per rule; also the only column names accepted by search() filters/sorting
        self._rule_fields = ['sid', 'msg', 'classtype', 'rev', 'gid', 'source', 'enabled', 'reference']
        # values for fields missing from a rule (fields not listed here default to None)
        self._rule_defaults = {'classtype': '##none##'}

    @staticmethod
    def list_local():
        """ list the installed rule files
        :return: list of paths matching <rule_source_directory>*.rules
        """
        return glob.glob('%s*.rules' % rule_source_directory)

    @staticmethod
    def _newest_mtime(filenames):
        """ return the most recent modification time of the given files (0 when there are none) """
        last_mtime = 0
        for filename in filenames:
            file_mtime = os.stat(filename).st_mtime
            if file_mtime > last_mtime:
                last_mtime = file_mtime
        return last_mtime

    def list_rules(self, filename):
        """ generator function to list rule file content including metadata
        :param filename: path of the rules file to parse
        :return: generator of dicts {'rule': <raw line>, 'metadata': <dict or None>}, one per line.
                 metadata is only set for lines containing 'msg:' and holds every field in
                 self._rule_fields; fields repeated within a rule are joined with newlines.
        """
        import re
        # options are separated by ';', but Suricata requires ';' inside option values to be
        # escaped as '\;', so only split on unescaped semicolons
        option_separator = re.compile(r'(?<!\\);')
        source = filename.split('/')[-1]

        # read everything up front so the file is closed even if the generator is not exhausted
        with open(filename) as rule_file:
            lines = rule_file.read().split('\n')

        for rule in lines:
            rule_info_record = {'rule': rule, 'metadata': None}
            if rule.find('msg:') != -1:
                # define basic record, a leading '#' marks a disabled (commented out) rule
                record = {'enabled': not rule.strip().startswith('#'), 'source': source}

                # options from 'msg:' onwards, dropping the last character (expected closing ')')
                rule_metadata = rule[rule.find('msg:'):-1]
                for field in option_separator.split(rule_metadata):
                    fieldname, separator, fieldcontent = field.partition(':')
                    fieldname = fieldname.strip()
                    if not separator or fieldname not in self._rule_fields:
                        # keyword without a value (e.g. 'nocase') or a field we don't store
                        continue
                    fieldcontent = fieldcontent.strip()
                    if fieldcontent.startswith('"'):
                        content = fieldcontent[1:-1]
                    else:
                        content = fieldcontent

                    if fieldname in record:
                        # if same field repeats, put items in list
                        if not isinstance(record[fieldname], list):
                            record[fieldname] = [record[fieldname]]
                        record[fieldname].append(content)
                    else:
                        record[fieldname] = content

                for rule_field in self._rule_fields:
                    if rule_field not in record:
                        record[rule_field] = self._rule_defaults.get(rule_field)

                # perform type conversions, repeated fields are stored as newline separated text
                for fieldname, value in list(record.items()):
                    if isinstance(value, list):
                        record[fieldname] = '\n'.join(value)

                rule_info_record['metadata'] = record

            yield rule_info_record

    def is_changed(self):
        """ check if rules on disk are probably different from rules in cache
        compares the newest rule file mtime and the number of rule files with the values stored by create()
        :return: boolean
        """
        if not os.path.exists(self.cachefile):
            return True

        all_rule_files = self.list_local()
        last_mtime = self._newest_mtime(all_rule_files)
        try:
            db = sqlite3.connect(self.cachefile)
            try:
                cur = db.cursor()
                cur.execute('SELECT max(timestamp), max(files) FROM stats')
                results = cur.fetchall()
            finally:
                db.close()
        except sqlite3.DatabaseError:
            # if for some reason the cache is unreadable, report changed
            return True

        return not (last_mtime == results[0][0] and len(all_rule_files) == results[0][1])

    def create(self):
        """ create new cache, replacing any existing cache file
        :return: None
        """
        if os.path.exists(self.cachefile):
            os.remove(self.cachefile)

        all_rule_files = self.list_local()
        # collect the mtime before parsing, so a file modified while parsing is detected by is_changed()
        last_mtime = self._newest_mtime(all_rule_files)
        insert_sql = 'insert into rules(%s) values (%s)' % (
            ','.join(self._rule_fields),
            ','.join(':%s' % fieldname for fieldname in self._rule_fields)
        )

        db = sqlite3.connect(self.cachefile)
        try:
            cur = db.cursor()
            cur.execute('CREATE TABLE stats (timestamp number, files number)')
            cur.execute("""CREATE TABLE rules (sid number, msg TEXT, classtype TEXT,
                                               rev INTEGER, gid INTEGER,reference TEXT,
                                               enabled BOOLEAN,source TEXT)""")
            for filename in all_rule_files:
                rules = [rule_info_record['metadata'] for rule_info_record in self.list_rules(filename=filename)
                         if rule_info_record['metadata'] is not None]
                cur.executemany(insert_sql, rules)
            cur.execute('INSERT INTO stats (timestamp,files) VALUES (?,?) ', (last_mtime, len(all_rule_files)))
            db.commit()
        finally:
            db.close()

    def search(self, limit, offset, filter_txt, sort_by):
        """ search installed rules
        :param limit: limit number of rows (0 or non numeric means no limit)
        :param offset: limit offset (only applied together with a limit)
        :param filter_txt: space separated (shell quoted) filters, format fieldname1,fieldname2/searchphrase.
                           fields within one filter are or-ed, filters are and-ed. The phrase is matched with
                           sql LIKE (use % as wildcard); a phrase containing * matches on a part of the field.
                           unknown field names are ignored.
        :param sort_by: order by, comma separated list of fields with optional asc/desc
        :return: dict with 'rows' (list of dicts) and, when the cache exists, 'total_rows' (count before limit)
        """
        result = {'rows': []}
        if not os.path.exists(self.cachefile):
            return result

        db = sqlite3.connect(self.cachefile)
        try:
            cur = db.cursor()

            # construct query including filters, field names are validated, values bound as parameters
            sql = 'select * from rules '
            sql_filters = {}
            where_clauses = []
            for filtertag in shlex.split(filter_txt or ''):
                fieldnames, _, searchcontent = filtertag.partition('/')
                conditions = []
                for fieldname in fieldnames.split(','):
                    fieldname = fieldname.strip().lower()
                    if fieldname not in self._rule_fields:
                        continue
                    # unique parameter per condition, so repeated fields don't overwrite each other
                    param = 'p%d' % len(sql_filters)
                    if '*' in searchcontent:
                        conditions.append("cast(%s as text) like '%%' || :%s || '%%'" % (fieldname, param))
                    else:
                        conditions.append('cast(%s as text) like :%s' % (fieldname, param))
                    sql_filters[param] = searchcontent.replace('*', '')
                if conditions:
                    where_clauses.append('( %s )' % ' or '.join(conditions))
            if where_clauses:
                sql += ' where ' + ' and '.join(where_clauses)

            # apply sort order (if any)
            sql_sort = []
            for sort_field in (sort_by or '').split(','):
                parts = sort_field.split()
                if parts and parts[0].lower() in self._rule_fields:
                    direction = 'desc' if parts[-1].lower() == 'desc' else 'asc'
                    sql_sort.append('%s %s' % (parts[0].lower(), direction))

            # count total number of rows
            cur.execute('select count(*) from (%s) a' % sql, sql_filters)
            result['total_rows'] = cur.fetchall()[0][0]

            if sql_sort:
                sql += ' order by %s' % (','.join(sql_sort))

            if str(limit) != '0' and str(limit).isdigit():
                sql += ' limit %s' % limit
                if str(offset) != '0' and str(offset).isdigit():
                    sql += ' offset %s' % offset

            # fetch results
            cur.execute(sql, sql_filters)
            columns = [description[0] for description in cur.description]
            for row in cur:
                result['rows'].append(dict(zip(columns, row)))
        finally:
            db.close()

        return result

    def list_class_types(self):
        """
        :return: sorted list of distinct classtypes in the cache (empty when there is no cache)
        """
        if not os.path.exists(self.cachefile):
            return []

        db = sqlite3.connect(self.cachefile)
        try:
            cur = db.cursor()
            cur.execute('SELECT DISTINCT classtype FROM rules')
            return sorted(record[0] for record in cur.fetchall())
        finally:
            db.close()