#!/usr/bin/python
"""
AppArmor Profile Generator
by PG @ Synology
Released Under the MIT License
"""

# Runtime settings.  Entries whose value is a bool or a str are the ones
# list_available_config() exposes as command-line options, which is why
# numeric settings such as 'untracked_log_size' are stored as strings.
config = {
	# Boolean switches.
	'webapi_only': False,
	'disable_filter': False,
	'show_raw_log': False,
	'show_untracked_log': False,
	'use_time_filter': False,
	'ignore_time_parse_failed': False,
	# Time window (MM/DD-HH:MM:SS) applied when 'use_time_filter' is set.
	'stime': '01/01-00:00:00',
	'etime': '12/31-23:59:59',
	# Log file to read.
	'log_path': '/var/log/kern.log',
	# Regexes; a log line matching any of them is dropped while reading.
	'filter_list': ['operation="profile_remove"', 'operation="profile_load"'],
	# Substrings used to group log lines that carry no profile="..." field.
	'untracked_log_category': ['change_hat', 'DefaultHat', 'callbacks suppressed'],
	# Maximum number of untracked lines echoed per group.
	'untracked_log_size': '30',
	# Requested-mask characters that are expanded into several permissions.
	'mask_mapping': {
		'c': 'rwk',
		'd': 'rwk',
	},
	# Order in which permissions are written in a generated rule.
	'mask_order': ['m', 'r', 'w', 'k', 'l', 'x'],
}

import re, fileinput, sys, time, datetime

# Compiled form of config['filter_list']; filled in by init().
filter_list_p = []
# Regex fragments for the profile="..." field of an AppArmor log line:
# profile_base_p + list_all_p matches the field and captures the profile name.
profile_base_p = 'profile="'
list_all_p = '([^"]+)"'

def init():
	"""Prepare the module-level regex state from config.

	Compiles every pattern of config['filter_list'] into filter_list_p and,
	when config['webapi_only'] is set, narrows the profile-matching fragments
	(profile_base_p / list_all_p) to profiles nested under the webapi
	entry.cgi profile.
	"""
	global list_all_p, profile_base_p
	# Replace the contents in place so that calling init() again does not
	# accumulate duplicate compiled patterns.
	filter_list_p[:] = [re.compile(pattern) for pattern in config['filter_list']]
	if config['webapi_only']:
		# Raw string: '\.' is not a valid escape in a normal string literal.
		list_all_p = r'([A-Za-z0-9\.]+)(/[^"]*)?"'
		profile_base_p = 'profile="/usr/syno/synoman/webapi/entry.cgi//'


def refine_log(data):
	"""Collapse the masks recorded for each path into one permission string.

	data maps a path to an iterable of requested-mask strings (iterating a
	dict yields its keys).  Returns a dict mapping each path to the merged
	permission string.
	"""

	def merge_masks(record):
		expansions = config['mask_mapping']
		# A dict keeps each flag once, in first-seen order.
		seen = {}
		for mask in record:
			for ch in mask:
				# Characters listed in mask_mapping stand for several flags.
				replacement = expansions[ch] if ch in expansions else (ch,)
				for flag in replacement:
					seen[flag] = 1

		order = config['mask_order']

		def rank(flag):
			# Flags missing from mask_order rank -1 and therefore come first.
			if flag in order:
				return order.index(flag)
			return -1

		# 'x' is emitted as 'ix'.
		return ''.join('ix' if flag == 'x' else flag for flag in sorted(seen, key=rank))

	return {path: merge_masks(data[path]) for path in data}

def analyze(target, lines):
	"""Print the profile rules suggested by the log lines that mention target.

	Every line containing target as a substring is classified, in this order
	of preference, as a file rule (name="..." ... requested_mask="..."), a
	capability (capname="...") or a network rule (family="..."
	sock_type="...").  Lines that mention target but fit none of these are
	listed under "Parse Fail List" at the end.

	With config['show_raw_log'] set, the matching raw log lines are echoed
	under each rule; otherwise one line per path, capability and socket type
	is printed.

	Returns 1 when at least one line mentioned target, otherwise 0.
	"""
	rule_p = re.compile(r'name="([^"]+)".+requested_mask="([A-Za-z]+)"')
	capability_p = re.compile(r'capname="([^"]+)"')
	network_p = re.compile(r'family="([a-z]+)" sock_type="([a-z]+)"')

	data = {}             # path -> mask -> {'count': hits, 'cmd': [log lines]}
	capability_data = {}  # capability name -> [log lines]
	network_data = {}     # (family, sock_type) -> [log lines]
	parse_fail = []
	ret_flag = 0
	hit_cnt = 0

	for line in lines:
		# Plain substring test: target is matched literally, anywhere in the line.
		if target not in line:
			continue
		ret_flag = 1

		result = rule_p.search(line)
		capability_result = capability_p.search(line)
		network_result = network_p.search(line)

		if result:
			path = result.group(1)
			requested_mask = result.group(2)
			entry = data.setdefault(path, {}).setdefault(requested_mask, {'count': 0, 'cmd': []})
			entry['count'] += 1
			entry['cmd'].append(line)
		elif capability_result:
			capability_data.setdefault(capability_result.group(1), []).append(line)
		elif network_result:
			key = (network_result.group(1), network_result.group(2))
			network_data.setdefault(key, []).append(line)
		else:
			parse_fail.append(line)
			continue

		hit_cnt += 1

	print('\033[1;33m{0}  {1} matched \033[0m'.format(target, hit_cnt))

	if config['show_raw_log']:
		for path in sorted(data):
			for mask in data[path]:
				print('\033[1;36m{0:60} {1}\033[0m'.format(path, mask))
				for log in data[path][mask]['cmd']:
					sys.stdout.write(log)
		for name in sorted(capability_data):
			print('\033[1;36mcapability {0},\033[0m'.format(name))
			for log in capability_data[name]:
				sys.stdout.write(log)
		for key in sorted(network_data):
			print('\033[1;36mnetwork    {0:6}{1},\033[0m'.format(key[0], key[1]))
			for log in network_data[key]:
				sys.stdout.write(log)
	else:
		refined_data = refine_log(data)
		for path in sorted(refined_data):
			print('{0:60} {1},'.format(path, refined_data[path]))
		for name in sorted(capability_data):
			print('capability {0},'.format(name))
		for key in sorted(network_data):
			print('network    {0:6}{1},'.format(key[0], key[1]))

	if parse_fail:
		print('\033[0;31mParse Fail List\033[0m')
		for line in parse_fail:
			print('\033[0;31m {0} \033[0m'.format(line))
	return ret_flag

def list_profiles(lines):
	"""Return the sorted, de-duplicated profile names found in lines.

	A name is whatever the first group of profile_base_p + list_all_p
	captures; lines without a match are ignored.
	"""
	profile_p = re.compile(profile_base_p + list_all_p)
	found = (profile_p.search(line) for line in lines)
	names = set(hit.group(1) for hit in found if hit)
	return sorted(names)

def list_available_config():
	"""Return the config keys whose current value is exactly a bool or a str."""
	return [name for name in config if type(config[name]) in (bool, str)]

def parse_argv_and_apply():
	"""Apply --name and --name=value command-line arguments to config.

	Only keys reported by list_available_config() can be set; anything else
	is ignored.  Arguments are processed left to right, so a later argument
	overrides an earlier one.

	* Boolean options: "--name" turns the option on.  "--name=value" also
	  turns it on unless value is one of 0 / false / no / off (case
	  insensitive), which turns it off.  The stored value is always a bool.
	* String options: "--name=value" stores value; a bare "--name" or an
	  empty value leaves the option untouched.
	"""
	settable = list_available_config()
	# Anchored at both ends so that "--" sequences inside a value or a path
	# are never mistaken for another option.
	option_p = re.compile(r'--([A-Za-z_]+)(?:=(.*))?\Z', re.DOTALL)
	false_words = ('0', 'false', 'no', 'off')

	# argv[0] is the script path, not an option.
	for arg in sys.argv[1:]:
		found = option_p.match(arg)
		if not found:
			continue
		name, value = found.group(1), found.group(2)
		if name not in settable:
			continue
		if isinstance(config[name], bool):
			config[name] = value is None or value.lower() not in false_words
		elif value:
			config[name] = value

def read_log():
	"""Read config['log_path'] and return (kept_lines, filtered_count).

	A line is dropped, and counted in filtered_count, when

	* filtering is enabled (config['disable_filter'] is false) and the
	  module's chk_filter() helper accepts it, or
	* config['use_time_filter'] is set and the timestamp at the start of the
	  line lies outside the [config['stime'], config['etime']] window.

	Lines whose timestamp cannot be parsed are kept; unless
	config['ignore_time_parse_failed'] is set a notice is printed for each.

	Raises ValueError when config['stime'] or config['etime'] is neither
	MM/DD-HH:MM:SS nor MM/DD-HH.
	"""
	# The timestamps carry no year.  strptime() would default to 1900, which
	# is not a leap year, so "Feb 29" would be rejected; a fixed leap year is
	# prepended to every timestamp instead.  Only the relative order matters.
	base_year = '2000'
	time_p = re.compile(r'^(\w+ +\d+ +\d+:\d+:\d+)')

	def parse_bound(text):
		dated = base_year + '/' + text
		try:
			return datetime.datetime.strptime(dated, '%Y/%m/%d-%H:%M:%S')
		except ValueError:
			return datetime.datetime.strptime(dated, '%Y/%m/%d-%H')

	def parse_log_time(line):
		# Returns None when the line has no usable leading timestamp.
		found = time_p.search(line)
		if not found:
			return None
		try:
			return datetime.datetime.strptime(base_year + ' ' + found.group(1), '%Y %b %d %H:%M:%S')
		except ValueError:
			return None

	stime = parse_bound(config['stime'])
	etime = parse_bound(config['etime'])

	lines = []
	filter_cnt = 0
	try:
		for line in fileinput.input(config['log_path']):
			if not config['disable_filter'] and chk_filter(line):
				filter_cnt += 1
				continue
			if config['use_time_filter']:
				log_time = parse_log_time(line)
				if log_time is None:
					if not config['ignore_time_parse_failed']:
						print('\033[0;31mParse Time failed: {0}\033[0m'.format(line))
				elif log_time < stime or log_time > etime:
					filter_cnt += 1
					continue
			lines.append(line)
	finally:
		# Release the file and reset fileinput's module-level state.
		fileinput.close()

	return lines, filter_cnt

def show_log_only(lines):
	"""Write lines to stdout unchanged; return True if there was at least one."""
	sys.stdout.writelines(lines)
	return len(lines) != 0

def list_unknown_log(lines):
	"""Return, in order, the lines that profile_base_p + list_all_p does not match."""
	profile_p = re.compile(profile_base_p + list_all_p)
	return [line for line in lines if not profile_p.search(line)]

def show_unknown_log(lines):
	"""Print a per-category count of the untracked log lines.

	Each line is counted under every entry of
	config['untracked_log_category'] it contains, or under "Unknown" when it
	contains none.  With config['show_untracked_log'] set, up to
	config['untracked_log_size'] lines of each group are echoed, last line
	first.  Prints nothing when lines is empty.
	"""
	if not len(lines):
		return

	categories = config['untracked_log_category']
	limit = int(config['untracked_log_size'])

	def dump_newest(entries):
		# Last line first, capped at `limit` lines (nothing when limit <= 0).
		for entry in entries[::-1][:max(limit, 0)]:
			sys.stdout.write(entry)

	print('\033[0;32mSome logs are not tracked, use --show_untracked_log to check it\033[0m')

	grouped = {}
	uncategorised = []
	for line in lines:
		matched = [name for name in categories if name in line]
		if not matched:
			uncategorised.append(line)
		for name in matched:
			grouped.setdefault(name, []).append(line)

	for name in grouped:
		print("[\033[1;32m{0}\033[0m] : \033[1;31m{1}\033[0m lines".format(name, len(grouped[name])))
		if config['show_untracked_log']:
			dump_newest(grouped[name])

	print("[\033[1;35mUnknown\033[0m] : \033[1;31m{0}\033[0m lines".format(len(uncategorised)))
	if config['show_untracked_log']:
		dump_newest(uncategorised)

###############################################################################################

def chk_filter(line):
	"""Return True when line matches any compiled pattern in filter_list_p."""
	return any(p.search(line) for p in filter_list_p)


def main():
	"""Command-line entry point; returns the process exit status.

	sys.argv[1] selects the mode:

	* --list  print the list of profile names found in the log
	* --all   analyze every profile found, then summarise untracked lines
	* --log   print the (filtered) log lines
	* other   treated as the name to analyze

	The status is 1 when the selected mode found something in the log and 0
	otherwise; the usage screen (no arguments) exits with 0.
	"""
	if len(sys.argv) < 2:
		print('Usage : \n {0} --all \n {0} --list \n {0} --log \n {0} webapi'.format(__file__))
		print('Available configs:')
		for name in list_available_config():
			if type(config[name]) == bool:
				print('\033[0;36m--{0}\033[0m'.format(name))
			elif type(config[name]) == str:
				print('\033[0;32m--{0}={1}\033[0m'.format(name, config[name]))
		return 0

	parse_argv_and_apply()
	init()

	if config['use_time_filter']:
		print('Parsing log time')
	lines, filter_cnt = read_log()
	print('\033[0;32mlog size: {0} lines, {1} were removed by filter\033[0m'.format(len(lines), filter_cnt))

	mode = sys.argv[1]

	if mode == '--list':
		result = list_profiles(lines)
		print(result)
		return 1 if len(result) > 0 else 0

	if mode == '--all':
		exit_flag = 0
		for profile in list_profiles(lines):
			if analyze(profile, lines) > 0:
				exit_flag = 1
		show_unknown_log(list_unknown_log(lines))
		print("\033[0;32m============================================================\033[0m")
		return exit_flag

	if mode == '--log':
		return 1 if show_log_only(lines) else 0

	return analyze(mode, lines)


if __name__ == '__main__':
	# sys.exit rather than the site-provided exit(), which is meant for the
	# interactive shell and is missing when the interpreter runs with -S.
	sys.exit(main())
