///###////////////////////////////////////////////////////////////////////////
//
// Burton Computer Corporation
// http://www.burton-computer.com
// $Id: SpamFilter.cc,v 1.41 2004/01/26 20:19:25 bburton Exp $
//
// Copyright (C) 2000 Burton Computer Corporation
// ALL RIGHTS RESERVED
//
// This program is open source software; you can redistribute it
// and/or modify it under the terms of the Q Public License (QPL)
// version 1.0. Use of this software in whole or in part, including
// linking it (modified or unmodified) into other programs is
// subject to the terms of the QPL.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// Q Public License for more details.
//
// You should have received a copy of the Q Public License
// along with this program; see the file LICENSE.txt.  If not, visit
// the Burton Computer Corporation or CoolDevTools web site
// QPL pages at:
//
//    http://www.burton-computer.com/qpl.html
//

#include <memory>
#include <stdexcept>
#include <algorithm>
#include <cmath>
#include "LockFile.h"
#include "TokenSelector.h"
#include "SpamFilter.h"

// Probability bounds used by computeRatio(): ratios are clamped to
// [MIN_PROB, MAX_PROB], and a zero count yields NEARLY_ZERO instead.
// MIN_PROB is also the unit used to scale tie-break ratios in scoreToken().
static const double MIN_PROB = 1.0e-6;
static const double MAX_PROB = 0.999999;
static const double NEARLY_ZERO = 1.0e-7;

// Sort-key tuning for wdf_for_sort() (cap and divisor applied to a token's
// in-message count) and rounded_prob() (probability banding multiplier).
static const int MAX_COUNT_FOR_SORT = 20;
static const int COUNT_SORT_DIVISOR = 3;
static const double PROB_SORT_MULTIPLE = 1.0e5;

// When top-term extension is enabled, tokens whose distance from the mean
// is not below this value keep being used after the term limit is reached.
static const double EXT_ARRAY_THRESHOLD = 0.4999;

// Name of the lock file created in the database's parent directory by lock().
static const string LOCK_FILENAME("lock");

// Index type for the token vectors handled in this file.
typedef vector<Token *>::size_type vtokindex_t;

// Constructs a filter with the default tuning values below, switches it to
// SCORE_NORMAL via setScoreMode(), and installs one default TokenSelector.
// The selector is heap-allocated and owned by this object; it is deleted by
// clearTokenSelectors(), which the destructor calls.
// NOTE(review): m_goodBias and m_minWordCount are read by scoreTerm() but
// are not initialized here -- presumably setScoreMode() sets them; confirm.
SpamFilter::SpamFilter()
  : m_termsForScore(27),
    m_defaultMinWordCount(5),
    m_maxWordRepeats(2),
    m_extendTopTerms(false),
    m_newWordScore(0.3),
    m_minDistanceForScore(0.0),
    m_minArraySize(0),
    m_waterCounts(false),
    m_defaultThreshold(0.6)
{
  setScoreMode(SCORE_NORMAL);
  m_tokenSelectors.push_back(new TokenSelector());
}

// Closes the database and releases the lock through close(), called with
// its default argument (declared outside this file -- presumably "do not
// abandon writes", so pending changes are flushed; confirm). Then it
// deletes all owned token selectors.
// NOTE(review): if close() throws, clearTokenSelectors() is skipped and the
// selectors leak.
SpamFilter::~SpamFilter()
{
  close();
  clearTokenSelectors();
}

void SpamFilter::clearTokenSelectors()
{
  for (vector<TokenSelector *>::iterator i = m_tokenSelectors.begin(); i != m_tokenSelectors.end(); ++i) {
    delete (*i);
  }
  m_tokenSelectors.clear();
}

void SpamFilter::lock(const File &raw_db_file,
                      bool read_only)
{
  File db_file(FrequencyDB::removeTypePrefix(raw_db_file.getPath()));
  File lock_file(db_file.parent(), LOCK_FILENAME);
  m_lock.set(new LockFile(lock_file));
  m_lock->lock(read_only ? LockFD::SHARED_LOCK : LockFD::EXCLUSIVE_LOCK);
}

void SpamFilter::open(const File &db_file,
                      bool read_only)
{
  lock(db_file, read_only);
  if (!m_db.open(db_file.getPath(), read_only)) {
    throw runtime_error("unable to open words database");
  }
}

void SpamFilter::open(const File &shared_db_file,
                      const File &private_db_file,
                      bool read_only)
{
  lock(private_db_file, read_only);
  if (!m_db.open(shared_db_file.getPath(), private_db_file.getPath(), read_only)) {
    throw runtime_error("unable to open words databases");
  }
}

// Closes the database if it is open, flushing it first unless
// abandon_writes is true. The lock file is always released, even when the
// database was not open.
void SpamFilter::close(bool abandon_writes)
{
  const bool db_was_open = m_db.isOpen();
  if (db_was_open && !abandon_writes) {
    m_db.flush();
  }
  if (db_was_open) {
    m_db.close();
  }
  m_lock.clear();
}

// Returns count / total_count clamped to [MIN_PROB, MAX_PROB].
// A zero count short-circuits to NEARLY_ZERO, which is smaller than
// MIN_PROB. The clamp is max(MIN_PROB, min(MAX_PROB, ratio)), so a positive
// count over a zero total (an infinite ratio) yields MAX_PROB.
double SpamFilter::computeRatio(double count,
                                double total_count)
{
  if (count != 0) {
    const double raw_ratio = count / total_count;
    const double capped = min(MAX_PROB, raw_ratio);
    return max(MIN_PROB, capped);
  }
  return NEARLY_ZERO;
}

// Computes the spam probability of one term from its database counts.
// The good count is first multiplied by m_goodBias.
// - If the biased total is below m_minWordCount, the term is treated as
//   new and m_newWordScore is returned.
// - Otherwise the result is spam_ratio / (good_ratio + spam_ratio), where
//   each ratio is the term's count over the matching message count. Every
//   ratio, and the result, is clamped by computeRatio().
double SpamFilter::scoreTerm(int good_count,
                             int spam_count,
                             int good_message_count,
                             int spam_message_count)
{
  const int biased_good_count = good_count * m_goodBias;

  if (biased_good_count + spam_count < m_minWordCount) {
    return m_newWordScore;
  }

  const double good_ratio = computeRatio(biased_good_count, good_message_count);
  const double spam_ratio = computeRatio(spam_count, spam_message_count);
  return computeRatio(spam_ratio, good_ratio + spam_ratio);
}

// Within-document frequency key used when sorting tokens.
// The raw count is divided by COUNT_SORT_DIVISOR, so small differences
// (e.g. 3 vs. 2 occurrences) land in the same bucket and don't have a
// disproportionate effect on the sort. The result is capped at
// MAX_COUNT_FOR_SORT for sanity.
inline long wdf_for_sort(int count)
{
  const int bucket = count / COUNT_SORT_DIVISOR;
  return (bucket < MAX_COUNT_FOR_SORT) ? bucket : MAX_COUNT_FOR_SORT;
}

// Bands a probability for sorting. It multiplies by PROB_SORT_MULTIPLE
// (1e5) and drops the fractional part, so values within the same 1e-5 band
// produce the same key and tiny differences don't reorder tokens.
// Note that the result is the scaled integral value, not a probability.
inline double rounded_prob(double prob)
{
  const double scaled = prob * PROB_SORT_MULTIPLE;
  return floor(scaled);
}

// Looks up tok's good/spam counts in the database, stores them on the
// token, computes its score with scoreTerm(), and sets two sort keys used
// by token_qsort_criterion (which sorts both in descending order):
//  - sort count: the banded distance |score - 0.5| (bits 6 and up), the
//    capped within-document frequency (bits 1-5), and a low bit that is set
//    when score < 0.5. At equal distance and frequency, this makes
//    good-leaning tokens sort ahead of spam-leaning ones.
//  - tie-break count: the token's count ratio in the corpus it leans
//    toward (spam if score > 0.5, otherwise good), in units of MIN_PROB.
// Precondition: tok->getCount() > 0.
void SpamFilter::scoreToken(Token *tok,
                            int good_message_count,
                            int spam_message_count)
{
  assert(tok->getCount() > 0);

  int good_count, spam_count;
  m_db.getWordCounts(tok->getWord(), good_count, spam_count);
  tok->setDBGoodCount(good_count);
  tok->setDBSpamCount(spam_count);
  if (tok->getDBTotalCount() < 0 && is_debug) {
    cerr << "WARNING: token " << tok->getWord() << " has negative count" << endl;
  }

  double score = scoreTerm(tok->getDBGoodCount(), tok->getDBSpamCount(), good_message_count, spam_message_count);
  assert(score >= MIN_PROB);
  assert(score <= MAX_PROB);
  tok->setScore(score);

  // fabs() instead of unqualified abs(): abs can resolve to int abs(int)
  // from <cstdlib>. That would truncate the distance (always < 1) to 0 and
  // erase the probability band from the sort key.
  long sort_count = static_cast<long>(rounded_prob(fabs(score - 0.5))) << 6;
  sort_count |= (wdf_for_sort(tok->getCount()) & 0x1f) << 1;
  if (score < 0.5) {
    sort_count |= 1;
  }
  assert(sort_count >= 0);
  tok->setSortCount(sort_count);

  double tie_break_ratio;
  if (score > 0.5) {
    tie_break_ratio = computeRatio(spam_count, spam_message_count);
  } else {
    tie_break_ratio = computeRatio(good_count, good_message_count);
  }
  tok->setTieBreakCount(static_cast<long>(tie_break_ratio / MIN_PROB));
}

static int token_qsort_criterion(Token **t1p, Token **t2p)
{
  Token *t1 = *t1p;
  Token *t2 = *t2p;

  long diff = t2->getSortCount() - t1->getSortCount();
  if (diff != 0) {
    return diff;
  }

  diff = t2->getTieBreakCount() - t1->getTieBreakCount();
  if (diff != 0) {
    return diff;
  }

  // as a last resort sort alphabetically
  return t1->getWord().compare(t2->getWord());
}

void SpamFilter::scoreTokens(const Message &msg)
{
  int good_message_count, spam_message_count;
  m_db.getMessageCounts(good_message_count, spam_message_count);

  for (int i = 0; i < msg.getTokenCount(); ++i) {
    Token *tok = msg.getToken(i);
    assert(tok);
    scoreToken(tok, good_message_count, spam_message_count);
  }
}

void SpamFilter::getSortedTokens(const Message &msg,
                                 TokenSelector *selector,
                                 int max_tokens,
                                 vector<Token *> &tokens)
{
  selector->selectTokens(msg, tokens);
  if (tokens.size() == 0) {
    return;
  }

  // STL sort is bugged so we have to use qsort!
  vtokindex_t num_tokens = tokens.size();
  NewPtr<Token *> sorted(new Token *[num_tokens]);
  for (vtokindex_t i = 0; i < num_tokens; ++i) {
    sorted.get()[i] = tokens[i];
  }

  qsort(sorted.get(), num_tokens, sizeof(Token *),
        (int (*) (const void *, const void *))token_qsort_criterion);

  if (is_debug) {
    for (vtokindex_t i = 0; i < num_tokens; ++i) {
      Token *tok = sorted.get()[i];
      cerr << "SORTED " << tok->getWord()
           << " count " << tok->getCount()
           << " score " << tok->getScore()
           << " dist " << tok->getDistanceFromMean()
           << " good " << tok->getDBGoodCount()
           << " spam " << tok->getDBSpamCount()
           << " tiebreak " << tok->getTieBreakCount()
           << endl;
    }
  }

  tokens.clear();
  vtokindex_t max_size = (vtokindex_t)max_tokens;
  for (vtokindex_t i = 0; i < max_size && i < num_tokens; ++i) {
    Token *tok = sorted.get()[i];
    if (tok->getDistanceFromMean() >= m_minDistanceForScore || i < m_minArraySize) {
      tokens.push_back(tok);
    }
  }
}

// Selects and sorts msg's tokens with selector, records each token used as
// one of msg's top tokens, and accumulates:
//   spamness  = product of score over the used factors
//   goodness  = product of (1 - score) over the used factors
//   num_terms = number of factors multiplied in
// Each used token contributes min(repeat limit, token count) factors.
// Normally at most m_termsForScore terms are taken, and each word repeats
// at most m_maxWordRepeats times. With m_waterCounts, both limits are
// derived from the number of candidate tokens instead. With
// m_extendTopTerms, tokens whose distance from the mean is not below
// EXT_ARRAY_THRESHOLD keep being used after the term limit is reached.
void SpamFilter::computeScoreProducts(Message &msg,
                                      TokenSelector *selector,
                                      double &spamness,
                                      double &goodness,
                                      double &num_terms)
{
  vector<Token *> candidates;
  getSortedTokens(msg, selector, msg.getTokenCount(), candidates);

  msg.clearTopTokens();

  int repeat_limit = m_maxWordRepeats;
  int term_limit = m_termsForScore;
  if (m_waterCounts) {
    term_limit = max(15, (int)candidates.size() / 25);
    repeat_limit = max(1, term_limit / 15);
  }

  goodness = 1.0;
  spamness = 1.0;
  num_terms = 0;

  const vtokindex_t candidate_total = candidates.size();
  for (vtokindex_t idx = 0; idx < candidate_total; ++idx) {
    Token *tok = candidates[idx];

    if (num_terms >= term_limit) {
      if (!m_extendTopTerms || tok->getDistanceFromMean() < EXT_ARRAY_THRESHOLD) {
        break;
      }
    }

    const int count = tok->getCount();
    const double score = tok->getScore();

    msg.addTopToken(tok);

    assert(count > 0);

    const int times = min(repeat_limit, count);
    if (is_debug) {
      cerr << "** TOKEN " << tok->getWord() << ": score " << score << ": times " << times << endl;
    }

    num_terms += times;
    for (int rep = 0; rep < times; ++rep) {
      spamness = spamness * score;
      goodness = goodness * (1.0 - score);
    }
  }
}

//
// Paul Graham's original formula from "A Plan For Spam":
//   score = spamness / (spamness + goodness)
// It yields excellent results, but nearly all messages score as 0 or 1.
// A NaN result (e.g. both products underflowing to zero) becomes 0.5.
//
double SpamFilter::originalScoreMessage(Message &msg,
                                        TokenSelector *selector)
{
  double spamness, goodness, num_terms;
  computeScoreProducts(msg, selector, spamness, goodness, num_terms);

  const double denominator = spamness + goodness;
  double score = spamness / denominator;
  if (is_nan(score)) {
    score = 0.5;
  }

  if (is_debug) {
    cerr << "** SPAMNESS " << spamness << ": GOODNESS " << goodness << ": SCORE "
         << score << endl;
  }

  return score;
}

//
// Robinson's simple algorithm.  Refer to his article for details:
// http://radio.weblogs.com/0101454/stories/2002/09/16/spamDetection.html
//
// With n = num_terms:
//   P = 1 - goodness^(1/n),  Q = 1 - spamness^(1/n)
//   S = ((P - Q) / (P + Q) + 1) / 2
// Like the other formulas in this file, a NaN result is replaced by 0.5.
//
double SpamFilter::alt1ScoreMessage(Message &msg,
                                    TokenSelector *selector)
{
  double spamness, goodness, num_terms;
  computeScoreProducts(msg, selector, spamness, goodness, num_terms);

  double p = 1.0 - pow(goodness, 1.0 / num_terms);
  double q = 1.0 - pow(spamness, 1.0 / num_terms);
  double s = (p - q) / (p + q);
  s = (s + 1.0) / 2.0;

  // With no scoring terms, both products stay 1.0, so p == q == 0 and s is
  // 0/0 == NaN. Report a neutral score instead of propagating NaN into
  // the averaged message score.
  if (is_nan(s)) {
    s = 0.5;
  }

  if (is_debug) {
    cerr << "SCORE P " << p << ": Q " << q << ": SCORE " << s << endl;
  }

  return s;
}

//
// Paul Graham's formula applied to the nth roots of the products
// (n = num_terms). It gives the same classifications as the original
// formula but spreads scores out more evenly, and it gives a better
// distribution than Robinson's simple formula. A NaN result becomes 0.5.
//
double SpamFilter::normalScoreMessage(Message &msg,
                                      TokenSelector *selector)
{
  double spamness, goodness, num_terms;
  computeScoreProducts(msg, selector, spamness, goodness, num_terms);

  const double exponent = 1.0 / num_terms;
  const double spam_root = pow(spamness, exponent);
  const double good_root = pow(goodness, exponent);

  double score = spam_root / (spam_root + good_root);
  if (is_nan(score)) {
    score = 0.5;
  }

  if (is_debug) {
    cerr << "** SPAMNESS " << spam_root << ": GOODNESS " << good_root << ": SCORE "
         << score << endl;
  }

  return score;
}

// Scores all of msg's tokens, then averages the message score produced by
// each configured token selector under the current score mode. The
// message is classified as spam when the average is at least
// getSpamThreshold().
SpamFilter::Score SpamFilter::scoreMessage(Message &msg)
{
  scoreTokens(msg);

  const vector<TokenSelector *>::size_type selector_count = m_tokenSelectors.size();
  double total = 0;
  for (vector<TokenSelector *>::size_type n = 0; n < selector_count; ++n) {
    total += scoreMessage(msg, m_tokenSelectors[n]);
  }
  const double average = total / static_cast<double>(selector_count);

  return Score(average, m_scoreMode, average >= getSpamThreshold());
}

// Scores msg with one token selector, using the formula selected by
// m_scoreMode. Any mode other than SCORE_ALT1 or SCORE_ORIGINAL is expected
// to be SCORE_NORMAL (asserted).
double SpamFilter::scoreMessage(Message &msg,
                                TokenSelector *selector)
{
  switch (m_scoreMode) {
  case SCORE_ALT1:
    return alt1ScoreMessage(msg, selector);

  case SCORE_ORIGINAL:
    return originalScoreMessage(msg, selector);

  default:
    assert(m_scoreMode == SCORE_NORMAL);
    return normalScoreMessage(msg, selector);
  }
}

// Scores msg in the given score mode. If mode differs from the current
// mode, the filter switches to it with setScoreMode() for the duration of
// the call and then switches back. The previous mode is restored even when
// scoring throws, so a failed call does not leave the filter in the wrong
// mode.
SpamFilter::Score SpamFilter::scoreMessage(Message &msg,
                                           ScoreMode_t mode)
{
  if (m_scoreMode == mode) {
    return scoreMessage(msg);
  }

  ScoreMode_t old_mode = m_scoreMode;
  setScoreMode(mode);
  try {
    Score score = scoreMessage(msg);
    setScoreMode(old_mode);
    return score;
  } catch (...) {
    setScoreMode(old_mode);
    throw;
  }
}

// Records msg in the database as a good message. If the database already
// contains it classified as spam, that entry is removed first.
// force_update is passed through to m_db.addMessage().
void SpamFilter::ensureGoodMessage(const Message &msg,
                                   bool force_update)
{
  bool stored_as_spam = false;
  const bool already_known = m_db.containsMessage(msg, stored_as_spam);
  if (already_known && stored_as_spam) {
    m_db.removeMessage(msg);
  }
  m_db.addMessage(msg, false, force_update);
}

// Records msg in the database as spam. If the database already contains
// it classified as good, that entry is removed first.
// force_update is passed through to m_db.addMessage().
void SpamFilter::ensureSpamMessage(const Message &msg,
                                   bool force_update)
{
  bool stored_as_spam = false;
  const bool already_known = m_db.containsMessage(msg, stored_as_spam);
  if (already_known && !stored_as_spam) {
    m_db.removeMessage(msg);
  }
  m_db.addMessage(msg, true, force_update);
}
