--[[
Rspamd filter priority:
prefilter: MCP(10), extract MCP(9), AntiVirus whitelist(6), AntiVirus(5), Blacklist/Whitelist(2), cache(1)
filter: spam, SpamAssassin rules
postfilter: header tagging + cache(10)
]] --

local rspamd_logger = require "rspamd_logger"

-- Load the spam-result cache helper from the Rspamd configuration directory.
-- spam_check uses cache.enabled, cache.settings.symbol and cache.set from it.
local local_conf = rspamd_paths['CONFDIR']
local cache = dofile(local_conf .. '/cache.lua')

-- Header names placed into the rmilter reply's add_headers by the
-- RMILTER_HEADERS postfilter.
local SPAM_STATUS_HEADER = 'X-Synology-Spam-Status'
local SPAM_FLAG_HEADER = 'X-Synology-Spam-Flag'
local VIRUS_HEADER = 'X-Synology-Virus-Status'
local MCP_HEADER = 'X-Synology-MCP-Status'

-- Symbols whose names start with MCP_SYMBOL_PREFIX are MCP symbols.
-- MCP_SCORE_SYMBOL is inserted by the RMILTER_MCP_HEADER prefilter with the
-- summed MCP score as its first option; mcp_check reads it back.
local MCP_SYMBOL_PREFIX = 'SYNO_MCP_'
local MCP_SCORE_SYMBOL = 'SYNO_MCP_SCORE'
-- Antivirus detection symbols, checked in this order; the first one present
-- supplies the virus names reported in VIRUS_HEADER.
local ANTIVIRUS_SYMBOLS = {'CLAMAV_VIRUS', 'SYNOAV_MCAFEE_VIRUS'}
-- Antivirus whitelist symbols; excluded from the symbol list in the spam status.
local ANTIVIRUS_WHITELIST_SYMBOLS = {
	'ANTIVIRUS_WHITELIST_SENDER_IP',
	'ANTIVIRUS_WHITELIST_SENDER_DOMAIN',
	'ANTIVIRUS_WHITELIST_SENDER_EMAIL',
	'ANTIVIRUS_WHITELIST_RECIPIENT_EMAIL'
}

-- Spam list symbol prefixes: a WHITELIST_* symbol forces a non-spam verdict,
-- a BLACKLIST_* symbol forces a spam verdict (spam_check tests whitelist first).
local SPAM_BLACKLIST_PREFIX = 'BLACKLIST_'
local SPAM_WHITELIST_PREFIX = 'WHITELIST_'

--- Check whether a string starts with the given prefix.
-- An empty prefix matches every string.
-- @tparam string str string to test
-- @tparam string prefix expected leading substring
-- @treturn boolean true when `str` begins with `prefix`
local function is_prefix(str, prefix)
	local prefix_len = string.len(prefix)
	local head = string.sub(str, 1, prefix_len)
	return head == prefix
end

--- Tell whether a symbol result belongs to the MCP family.
-- @tparam table symbol symbol entry from task:get_symbols_all() (uses `.name`)
-- @treturn boolean true when the name starts with MCP_SYMBOL_PREFIX
local function is_mcp_symbol(symbol)
	local name = symbol.name
	return is_prefix(name, MCP_SYMBOL_PREFIX)
end

--- Return true if `name` equals one of the strings in the array `names`.
-- @tparam table names sequence of symbol names
-- @tparam string name name to look for
-- @treturn boolean
local function name_in_list(names, name)
	for _, candidate in ipairs(names) do
		if candidate == name then
			return true
		end
	end

	return false
end

--- Tell whether a symbol result is an antivirus detection symbol.
-- @tparam table symbol symbol entry from task:get_symbols_all() (uses `.name`)
-- @treturn boolean true when the name is listed in ANTIVIRUS_SYMBOLS
local function is_antivirus_symbol(symbol)
	return name_in_list(ANTIVIRUS_SYMBOLS, symbol['name'])
end

--- Tell whether a symbol result is an antivirus whitelist symbol.
-- @tparam table symbol symbol entry from task:get_symbols_all() (uses `.name`)
-- @treturn boolean true when the name is listed in ANTIVIRUS_WHITELIST_SYMBOLS
local function is_antivirus_whitelist_symbol(symbol)
	return name_in_list(ANTIVIRUS_WHITELIST_SYMBOLS, symbol['name'])
end

--- Tell whether a symbol should be listed in the spam status header.
-- MCP symbols, antivirus symbols and antivirus whitelist symbols are excluded;
-- everything else counts as a spam symbol.
-- @tparam table symbol symbol entry from task:get_symbols_all()
-- @treturn boolean
local function is_spam_symbol(symbol)
	return not (is_mcp_symbol(symbol)
		or is_antivirus_symbol(symbol)
		or is_antivirus_whitelist_symbol(symbol))
end

--- Tell whether a symbol is a spam whitelist hit.
-- @tparam table symbol symbol entry from task:get_symbols_all() (uses `.name`)
-- @treturn boolean true when the name starts with SPAM_WHITELIST_PREFIX
local function is_whitelist_symbol(symbol)
	local name = symbol.name
	return is_prefix(name, SPAM_WHITELIST_PREFIX)
end

--- Tell whether a symbol is a spam blacklist hit.
-- @tparam table symbol symbol entry from task:get_symbols_all() (uses `.name`)
-- @treturn boolean true when the name starts with SPAM_BLACKLIST_PREFIX
local function is_blacklist_symbol(symbol)
	local name = symbol.name
	return is_prefix(name, SPAM_BLACKLIST_PREFIX)
end

--- Decide the spam verdict and fill the spam headers.
-- Sets SPAM_FLAG_HEADER ('yes'/'no') and SPAM_STATUS_HEADER in `add`.
-- Precedence: a WHITELIST_* symbol forces 'no', then a BLACKLIST_* symbol
-- forces 'yes'. Otherwise the verdict is taken from the cache symbol (when the
-- cache is enabled and has an entry), or computed from the default metric and
-- the spam symbols and then stored in the cache.
-- @param task rspamd task object
-- @param add table of header name -> value to add (mutated)
-- @param del table of headers to remove (not used by this check)
local function spam_check(task, add, del)
	local spam_opt = rspamd_config:get_all_opt('spam') or {}
	local spam_enable = spam_opt['enable']
	local spam_learn_enable = spam_opt['learn_enable']

	-- Accept both the string 'no' and the boolean that UCL produces for an
	-- unquoted `no`.
	if spam_enable == 'no' or spam_enable == false then
		return
	end

	local action = task:get_metric_action('default')
	-- score = {[1] = current score, [2] = required score}
	local score = task:get_metric_score('default')

	-- symbol: {[group] = group, [options] = {option1, option2, ...}, [name] = name, [score] = score}
	local symbols = task:get_symbols_all()

	-- whitelist: checked first, wins over the blacklist
	for _, symbol in ipairs(symbols) do
		if is_whitelist_symbol(symbol) then
			add[SPAM_FLAG_HEADER] = 'no'
			add[SPAM_STATUS_HEADER] = 'score=0, required ' .. score[2] .. ', ' .. symbol['name'] .. ' 0'
			rspamd_logger.infox(task, "Hit whitelist; symbol %s", symbol)
			return
		end
	end
	-- blacklist: reported with the required score as the message score
	for _, symbol in ipairs(symbols) do
		if is_blacklist_symbol(symbol) then
			add[SPAM_FLAG_HEADER] = 'yes'
			add[SPAM_STATUS_HEADER] = 'score=' .. score[2] .. ', required ' .. score[2] .. ', ' .. symbol['name'] .. ' ' .. score[2]
			rspamd_logger.infox(task, "Hit blacklist; symbol %s", symbol)
			return
		end
	end

	local spam_status
	local cache_hit = false
	if cache.enabled then
		local cache_symbol = task:get_symbol(cache.settings.symbol)
		if cache_symbol ~= nil then
			rspamd_logger.infox(task, "Cache hit; getting spam status from cache")
			cache_hit = true

			-- The cached value is "<flag>;<status>". Tolerate a missing option:
			-- the status is then recomputed below.
			local first = cache_symbol[1]
			local options = first and first['options']
			local cache_spam_result = options and options[1]
			local sep_index = cache_spam_result and string.find(cache_spam_result, ';', 1, true)
			if sep_index then
				add[SPAM_FLAG_HEADER] = string.sub(cache_spam_result, 1, sep_index - 1)
				spam_status = string.sub(cache_spam_result, sep_index + 1)
			end
		end
	end

	if spam_status == nil then
		if cache.enabled then
			rspamd_logger.infox(task, "Cache miss; calculate spam status")
		end

		-- Collect the comma-separated status fields and join them once.
		local fields = {'score=' .. score[1] .. ', required ' .. score[2]}

		if spam_learn_enable == 'yes' or spam_learn_enable == true then
			-- Thresholds may be missing or given as strings in the config.
			local threshold_spam = tonumber(spam_opt['learn_threshold_spam'])
			local threshold_non_spam = tonumber(spam_opt['learn_threshold_non_spam'])

			if threshold_spam and score[1] >= threshold_spam then
				fields[#fields + 1] = 'autolearn=spam'
			elseif threshold_non_spam and score[1] <= threshold_non_spam then
				fields[#fields + 1] = 'autolearn=ham'
			end
		end

		for _, symbol in ipairs(symbols) do
			if is_spam_symbol(symbol) then
				fields[#fields + 1] = symbol['name'] .. ' ' .. symbol['score']
			end
		end

		spam_status = table.concat(fields, ', ')

		-- NOTE(review): only the 'add header' action marks the message as spam.
		-- Any other action (including stronger ones such as 'reject', if
		-- configured) yields 'no'. Confirm that only add_header is configured.
		if action == 'add header' then
			add[SPAM_FLAG_HEADER] = 'yes'
		else
			add[SPAM_FLAG_HEADER] = 'no'
		end
	end

	add[SPAM_STATUS_HEADER] = spam_status

	rspamd_logger.infox(task, "Spam Flag: %1", add[SPAM_FLAG_HEADER])
	rspamd_logger.infox(task, "Spam Status: %1", add[SPAM_STATUS_HEADER])
	if add[SPAM_FLAG_HEADER] == 'yes' then
		rspamd_logger.errx(task, "msgid=<%1>: Spam, %2", task:get_message_id(), add[SPAM_STATUS_HEADER])
	end

	-- Store freshly computed results in the cache as "<flag>;<status>".
	if cache.enabled and not cache_hit then
		cache.set(task, add[SPAM_FLAG_HEADER] .. ';' .. add[SPAM_STATUS_HEADER])
	end
end

--- Fill the virus header from the antivirus symbols.
-- The first symbol of ANTIVIRUS_SYMBOLS present on the task supplies the virus
-- names (its options). VIRUS_HEADER is set to 'yes, <names>' ('yes' when no
-- names are reported) or 'no'.
-- @param task rspamd task object
-- @param add table of header name -> value to add (mutated)
-- @param del table of headers to remove (not used by this check)
local function antivirus_check(task, add, del)
	local antivirus_opt = rspamd_config:get_all_opt('antivirus') or {}
	local antivirus_enable = antivirus_opt['enable']

	-- Accept both the string 'no' and the boolean UCL yields for an unquoted `no`.
	if antivirus_enable == 'no' or antivirus_enable == false then
		return
	end

	-- nil means no antivirus symbol was found.
	local virus_names
	for _, virus_symbol in ipairs(ANTIVIRUS_SYMBOLS) do
		local virus = task:get_symbol(virus_symbol)
		if virus then
			-- The symbol may carry no options; treat that as "no names reported".
			virus_names = (virus[1] and virus[1]['options']) or {}
			break
		end
	end

	if virus_names then
		local names = table.concat(virus_names, ', ')
		if names ~= '' then
			add[VIRUS_HEADER] = 'yes, ' .. names
		else
			add[VIRUS_HEADER] = 'yes'
		end
		rspamd_logger.errx(task, "msgid=<%1>, Virus, %2", task:get_message_id(), names)
	else
		add[VIRUS_HEADER] = 'no'
	end
end

--- Fill the MCP header from the SYNO_MCP_* symbols.
-- Reads the total MCP score from the first option of MCP_SCORE_SYMBOL and
-- compares it with mcp.actions.add_header. MCP_HEADER is set to 'no' or to
-- 'yes, score=<s>, required <r>[, <other MCP symbol names>]'.
-- @param task rspamd task object
-- @param add table of header name -> value to add (mutated)
-- @param del table of headers to remove (not used by this check)
local function mcp_check(task, add, del)
	local mcp_opt = rspamd_config:get_all_opt('mcp') or {}
	local mcp_enable = mcp_opt['enable']

	-- Accept both the string 'no' and the boolean UCL yields for an unquoted `no`.
	if mcp_enable == 'no' or mcp_enable == false then
		return
	end

	local actions = mcp_opt['actions'] or {}
	local mcp_score_required = tonumber(actions['add_header'])
	if mcp_score_required == nil then
		-- Without a threshold there is no verdict; previously this crashed the postfilter.
		rspamd_logger.warnx(task, "MCP is enabled but mcp.actions.add_header is not a number; skipping MCP check")
		return
	end

	local mcp_score = 0
	local mcp_symbols = {}
	for _, symbol in ipairs(task:get_symbols_all()) do
		if is_mcp_symbol(symbol) then
			if symbol['name'] == MCP_SCORE_SYMBOL then
				-- The option holds the score as a string; fall back to 0 if absent or unparsable.
				local options = symbol['options']
				mcp_score = (options and tonumber(options[1])) or 0
			else
				mcp_symbols[#mcp_symbols + 1] = symbol['name']
			end
		end
	end

	if mcp_score >= mcp_score_required then
		local mcp_header = 'yes, score=' .. mcp_score .. ', required ' .. mcp_score_required
		local mcp_symbol_string = table.concat(mcp_symbols, ', ')
		-- Avoid a dangling ", " when no individual MCP symbols matched.
		if mcp_symbol_string ~= '' then
			mcp_header = mcp_header .. ', ' .. mcp_symbol_string
		end

		add[MCP_HEADER] = mcp_header
		rspamd_logger.errx(task, "msgid=<%1>, MCP, %2", task:get_message_id(), mcp_symbol_string)
	else
		add[MCP_HEADER] = 'no'
	end
end

-- Prefilter: store the total MCP score as an option of SYNO_MCP_SCORE and remove it from the default metric score
rspamd_config:register_symbol({
	name = 'RMILTER_MCP_HEADER',
	type = 'prefilter',
	priority = 9,
	callback = function(task)
		-- Sum the scores of all SYNO_MCP_* symbols, record the total as the
		-- option of MCP_SCORE_SYMBOL (mcp_check reads it back), then subtract
		-- the total from the default metric score.
		local mcp_opt = rspamd_config:get_all_opt('mcp') or {}
		local mcp_enable = mcp_opt['enable']

		-- Accept both the string 'no' and the boolean UCL yields for an unquoted `no`.
		if mcp_enable == 'no' or mcp_enable == false then
			return
		end

		local mcp_score = 0
		for _, symbol in ipairs(task:get_symbols_all()) do
			if is_mcp_symbol(symbol) then
				mcp_score = mcp_score + (symbol['score'] or 0)
			end
		end

		task:insert_result(MCP_SCORE_SYMBOL, 0.0, tostring(mcp_score))

		-- remove mcp score; score = {[1] = current score, [2] = required score}
		local score = task:get_metric_score('default')
		task:set_metric_score('default', score[1] - mcp_score)
	end
})

-- Postfilter: build the MCP, spam and virus headers and return them in the rmilter reply
rspamd_config:register_symbol({
	name = 'RMILTER_HEADERS',
	type = 'postfilter',
	priority = 10,
	callback = function(task)
		-- Headers handed to rmilter for removal, and the headers each check adds.
		local headers_to_remove = { ['X-Spam'] = 1, ['X-Virus'] = 1 }
		local headers_to_add = {}

		-- Run the checks in a fixed order: MCP, then spam, then virus.
		local checks = { mcp_check, spam_check, antivirus_check }
		for _, run_check in ipairs(checks) do
			run_check(task, headers_to_add, headers_to_remove)
		end

		-- The verdict is reported through headers; reset the metric action.
		task:set_metric_action('default', 'no action')
		task:set_rmilter_reply({
			add_headers = headers_to_add,
			remove_headers = headers_to_remove,
		})
	end
})
