# rulecache.py
"""
    Copyright (c) 2015 Ad Schellevis
    All rights reserved.

    Redistribution and use in source and binary forms, with or without
    modification, are permitted provided that the following conditions are met:

    1. Redistributions of source code must retain the above copyright notice,
     this list of conditions and the following disclaimer.

    2. Redistributions in binary form must reproduce the above copyright
     notice, this list of conditions and the following disclaimer in the
     documentation and/or other materials provided with the distribution.

    THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES,
    INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY
    AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
    AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY,
    OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
    SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
    INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
    CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
    ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
    POSSIBILITY OF SUCH DAMAGE.

    --------------------------------------------------------------------------------------

    shared module for the suricata scripts, manages the cache of installed rules for easy access
"""
import os
import os.path
import glob
import sqlite3
import shlex
from contextlib import closing

from lib import rule_source_directory


class RuleCache(object):
    """ Cache of installed suricata rules.

    Rule files matching ``<rule_source_directory>*.rules`` are parsed and the fields listed in ``_rule_fields``
    are stored in the sqlite database ``<rule_source_directory>rules.sqlite``, so they can be searched without
    re-reading the rule files.
    """
    def __init__(self):
        # suricata rule settings, source directory and sqlite cache file to use
        self.cachefile = '%srules.sqlite' % rule_source_directory
        # fields stored per rule as columns in the cache; 'source' and 'enabled' are set by listRules from the
        # file name and the comment marker, the others are read from the rule options
        self._rule_fields = ['sid', 'msg', 'classtype', 'rev', 'gid', 'source', 'enabled', 'reference']
        # value used when a rule does not define a field (fields not listed here default to None)
        self._rule_defaults = {'classtype': '##none##'}

    def _open_db(self):
        """ open a connection to the cache database
        :return: sqlite3 connection wrapped in contextlib.closing, so a ``with`` block closes it on exit
                 (a bare sqlite3 connection used as context manager only commits/rolls back, it stays open)
        """
        return closing(sqlite3.connect(self.cachefile))

    @staticmethod
    def _last_mtime(filenames):
        """ determine the most recent modification time of the given files
        :param filenames: list of file paths
        :return: newest st_mtime, or 0 when the list is empty
        """
        last_mtime = 0
        for filename in filenames:
            last_mtime = max(last_mtime, os.stat(filename).st_mtime)
        return last_mtime

    def listLocal(self):
        """ list the installed rule files
        :return: list of paths matching <rule_source_directory>*.rules
        """
        return glob.glob('%s*.rules' % rule_source_directory)

    def listRules(self, filename):
        """ generator function to list rule file content including metadata
        :param filename: path of the rule file to parse
        :return: generator of dicts {'rule': <line>, 'metadata': <dict or None>}; metadata is only set for lines
                 containing 'msg:' and then holds every field of self._rule_fields (repeated options are joined
                 with a newline)
        """
        # read everything up front so the file is closed before the first record is yielded
        with open(filename) as rule_file:
            lines = rule_file.read().split('\n')
        source = filename.split('/')[-1]

        for rule in lines:
            rule_info_record = {'rule': rule, 'metadata': None}
            msg_pos = rule.find('msg:')
            if msg_pos != -1:
                # define basic record, a leading '#' marks a disabled rule
                record = {'enabled': True, 'source': source}
                if rule.strip()[0] == '#':
                    record['enabled'] = False

                # options start at msg:, drop trailing whitespace and the last character (closing parenthesis)
                rule_metadata = rule[msg_pos:].rstrip()[:-1]
                for field in rule_metadata.split(';'):
                    if ':' not in field:
                        # option without a value (e.g. nocase), none of the cached fields look like this
                        continue
                    field_name, _, field_content = field.partition(':')
                    field_name = field_name.strip()
                    field_content = field_content.strip()
                    if field_name not in self._rule_fields:
                        continue
                    if field_content.startswith('"'):
                        content = field_content[1:-1]
                    else:
                        content = field_content

                    if field_name in record:
                        # if same field repeats, put items in list
                        if not isinstance(record[field_name], list):
                            record[field_name] = [record[field_name]]
                        record[field_name].append(content)
                    else:
                        record[field_name] = content

                for rule_field in self._rule_fields:
                    if rule_field not in record:
                        record[rule_field] = self._rule_defaults.get(rule_field)

                # perform type conversions, repeated fields are stored as newline separated text
                for field_name, value in record.items():
                    if isinstance(value, list):
                        record[field_name] = '\n'.join(value)

                rule_info_record['metadata'] = record

            yield rule_info_record

    def isChanged(self):
        """ check if rules on disk are probably different from rules in cache
        (compares the newest rule file mtime and the number of rule files with the stats stored by create)
        :return: boolean
        """
        if not os.path.exists(self.cachefile):
            return True

        all_rule_files = self.listLocal()
        last_mtime = self._last_mtime(all_rule_files)
        try:
            with self._open_db() as db:
                cur = db.cursor()
                cur.execute('select max(timestamp), max(files) from stats')
                timestamp, files = cur.fetchone()
        except sqlite3.DatabaseError:
            # if for some reason the cache is unreadable, report changed so it gets rebuilt
            return True
        return not (last_mtime == timestamp and len(all_rule_files) == files)

    def create(self):
        """ create new cache, replacing any existing cache file
        :return: None
        """
        if os.path.exists(self.cachefile):
            os.remove(self.cachefile)

        all_rule_files = self.listLocal()
        # collect mtimes before reading, a file changed while reading is then detected by isChanged
        last_mtime = self._last_mtime(all_rule_files)
        insert_sql = 'insert into rules(%s) values (%s)' % (','.join(self._rule_fields),
                                                            ','.join(':' + field for field in self._rule_fields))

        with self._open_db() as db:
            cur = db.cursor()
            cur.execute('create table stats (timestamp number, files number)')
            cur.execute("""create table rules (sid number, msg text, classtype text,
                                               rev integer, gid integer,reference text,
                                               enabled boolean,source text)""")
            for filename in all_rule_files:
                rules = [rule_info_record['metadata'] for rule_info_record in self.listRules(filename=filename)
                         if rule_info_record['metadata'] is not None]
                cur.executemany(insert_sql, rules)
            cur.execute('insert into stats (timestamp,files) values (?,?) ', (last_mtime, len(all_rule_files)))
            db.commit()

    def _build_filter(self, filter_text):
        """ translate a search filter (see search) into a sql where clause
        :param filter_text: space separated (shlex quoted) tags fieldname1,fieldname2/searchphrase
        :return: tuple (where clause or empty string, dict of named sql parameters)
        """
        conditions = []
        params = {}
        for tag_num, filtertag in enumerate(shlex.split(filter_text or '')):
            fieldnames, _, searchcontent = filtertag.partition('/')
            search_value = searchcontent.replace('*', '')
            field_conditions = []
            for fieldname in fieldnames.split(','):
                fieldname = fieldname.lower().strip()
                if fieldname not in self._rule_fields:
                    # not a valid fieldname, never put it in the query
                    continue
                # parameter names are unique per tag, so the same field in another tag keeps its own value
                param = '%s_%d' % (fieldname, tag_num)
                if '*' in searchcontent:
                    field_conditions.append("cast(" + fieldname + " as text) like '%'|| :" + param + " || '%'")
                else:
                    field_conditions.append("cast(" + fieldname + " as text) like :" + param)
                params[param] = search_value
            # a tag without any valid fieldname does not restrict the result
            if field_conditions:
                conditions.append('( ' + ' or '.join(field_conditions) + ' )')

        if not conditions:
            return '', {}
        return 'where ' + ' and '.join(conditions), params

    def _build_sort(self, sort_by):
        """ translate a sort specification into sql order by terms
        :param sort_by: comma separated list of "fieldname [asc|desc]"
        :return: list of "fieldname asc|desc" strings, unknown fieldnames are skipped
        """
        sql_sort = []
        for sort_field in (sort_by or '').split(','):
            parts = sort_field.split()
            if parts and parts[0] in self._rule_fields:
                direction = 'desc' if parts[-1].lower() == 'desc' else 'asc'
                sql_sort.append('%s %s' % (parts[0], direction))
        return sql_sort

    def search(self, limit, offset, filter, sort_by):
        """ search installed rules
        :param limit: limit number of rows (0 or a non numeric value means no limit)
        :param offset: limit offset (only applied together with a limit)
        :param filter: text to search, space separated tags in the format fieldname1,fieldname2/searchphrase;
                       fields within a tag are or-ed, tags are and-ed. Include % to match on a part, or * to
                       match the phrase anywhere in the field. Unknown fieldnames are ignored.
        :param sort_by: order by, comma separated list of fields and possible asc/desc parameter
        :return: dict with 'rows' (list of dicts, one per rule) and, when the cache exists, 'total_rows'
                 (number of matches ignoring limit/offset)
        """
        result = {'rows': []}
        if not os.path.exists(self.cachefile):
            return result

        # construct query including filters, fieldnames are validated, search values are bound parameters
        where_sql, sql_filters = self._build_filter(filter)
        sql = 'select * from rules ' + where_sql
        sql_sort = self._build_sort(sort_by)

        with self._open_db() as db:
            cur = db.cursor()

            # count total number of rows
            cur.execute('select count(*) from (%s) a' % sql, sql_filters)
            result['total_rows'] = cur.fetchone()[0]

            if sql_sort:
                sql += ' order by %s' % ','.join(sql_sort)

            # limit/offset are only interpolated after checking they consist of digits
            if str(limit) != '0' and str(limit).isdigit():
                sql += ' limit %s' % limit
                if str(offset) != '0' and str(offset).isdigit():
                    sql += ' offset %s' % offset

            # fetch results
            cur.execute(sql, sql_filters)
            columns = [column[0] for column in cur.description]
            for row in cur.fetchall():
                result['rows'].append(dict(zip(columns, row)))

        return result

    def listClassTypes(self):
        """
        :return: sorted list of distinct classtypes in the cache (empty list when there is no cache)
        """
        if not os.path.exists(self.cachefile):
            return []
        with self._open_db() as db:
            cur = db.cursor()
            cur.execute('select distinct classtype from rules')
            return sorted(record[0] for record in cur.fetchall())