#include <stdio.h>
#include <stdarg.h>
#include <string.h>
#include <setjmp.h>
#include "pprint.h"
#include "symbol.h"
#include "formula.h"

// Pretty printing parameters

// Indentation step used for nested pblock/pspace calls
#define INDENT 2
// Line width handed to pprint by pp()
#define MARGIN 72

// Error handling

// Jump target armed by setjmp in gensmv; leave() longjmps back to it
// to abort the whole translation.
static jmp_buf env;

// Abort the current translation: unwind with longjmp to the setjmp in
// gensmv, which then returns 1.  Never returns to its caller.

static void
leave(void)
{
  longjmp(env, 1);
}

/* Report a printf-style diagnostic on stderr, then abort the
   translation through leave(); control never comes back here. */
static void
fatal(const char *fmt, ...)
{
  va_list args;

  va_start(args, fmt);
  (void) vfprintf(stderr, fmt, args);
  va_end(args);

  leave();
}

// Simplifying formula constructors

/* Negate TRAN, folding the constants: !TRUE gives FALSE and !FALSE
   gives TRUE; anything else is wrapped in a NOT node. */
static tran_t
con_tran_not(tran_t tran)
{
  if (tran_type(tran) == TRAN_TRUE)
    return mk_tran_false();
  if (tran_type(tran) == TRAN_FALSE)
    return mk_tran_true();
  return mk_tran_not(tran);
}

/* Conjoin ARG1 and ARG2 with constant folding: a TRUE operand is
   dropped (the other operand is returned as is) and a FALSE operand
   makes the whole conjunction FALSE.  ARG1 is inspected first. */
static tran_t
con_tran_and(tran_t arg1, tran_t arg2)
{
  if (tran_type(arg1) == TRAN_TRUE)
    return arg2;
  if (tran_type(arg1) == TRAN_FALSE)
    return mk_tran_false();
  if (tran_type(arg2) == TRAN_TRUE)
    return arg1;
  if (tran_type(arg2) == TRAN_FALSE)
    return mk_tran_false();
  return mk_tran_and(arg1, arg2);
}

/* Disjoin ARG1 and ARG2 with constant folding: a TRUE operand makes the
   whole disjunction TRUE and a FALSE operand is dropped (the other
   operand is returned as is).  ARG1 is inspected first. */
static tran_t
con_tran_or(tran_t arg1, tran_t arg2)
{
  if (tran_type(arg1) == TRAN_TRUE)
    return mk_tran_true();
  if (tran_type(arg1) == TRAN_FALSE)
    return arg2;
  if (tran_type(arg2) == TRAN_TRUE)
    return mk_tran_true();
  if (tran_type(arg2) == TRAN_FALSE)
    return arg1;
  return mk_tran_or(arg1, arg2);
}

/* Build ARG1 -> ARG2 with constant folding:
     TRUE -> b      is  b
     FALSE -> b     is  TRUE
     a -> TRUE      is  TRUE
     a -> FALSE     is  !a
   and anything else becomes an IMPLY node. */
static tran_t
con_tran_imply(tran_t arg1, tran_t arg2)
{
  if (tran_type(arg1) == TRAN_TRUE)
    return arg2;
  if (tran_type(arg1) == TRAN_FALSE || tran_type(arg2) == TRAN_TRUE)
    return mk_tran_true();
  if (tran_type(arg2) == TRAN_FALSE)
    return mk_tran_not(arg1);
  return mk_tran_imply(arg1, arg2);
}

// Pretty printing driver

/* Render the pretty-printing tree P on standard output using a line
   width of MARGIN, and pass back whatever pprint returns. */
static int
pp(pretty_t p)
{
  FILE *out = stdout;
  int status = pprint(out, p, MARGIN);
  return status;
}

// Names of the SMV state variables declared by emit_decls:
// class, permission, type, role, user, and the boolean "okay" flag.
#define CLASS_VAR "c"
#define PERMISSION_VAR "p"
#define TYPE_VAR "t"
#define ROLE_VAR "r"
#define USER_VAR "u"
#define OKAY_VAR "k"
// Next-state references are written next(VAR); the NEXT_*_VAR names are
// built by string-literal concatenation, e.g. NEXT_TYPE_VAR is "next(t)".
#define NEXT_PREFIX "next("
#define NEXT_SUFFIX ")"
#define NEXT_TYPE_VAR NEXT_PREFIX TYPE_VAR NEXT_SUFFIX
#define NEXT_ROLE_VAR NEXT_PREFIX ROLE_VAR NEXT_SUFFIX
#define NEXT_USER_VAR NEXT_PREFIX USER_VAR NEXT_SUFFIX
#define NEXT_OKAY_VAR NEXT_PREFIX OKAY_VAR NEXT_SUFFIX

// SMV constants and operators emitted by the printer.  Note that
// SMV_NEXT is the CTL operator "EX" (used for TRAN_NEXT), and the until
// form is printed as E[ f U g ] from the three SMV_UNTIL_* pieces.
#define SMV_TRUE "TRUE"
#define SMV_FALSE "FALSE"
#define SMV_EQ "="
#define SMV_NEQ "!="
#define SMV_IN "in"
#define SMV_NOT "!"
#define SMV_AND "&"
#define SMV_OR "|"
#define SMV_IFF "<->"
#define SMV_IMPLY "->"
#define SMV_AX "AX"
#define SMV_NEXT "EX"
#define SMV_UNTIL_OPEN "E["
#define SMV_UNTIL_MIDDLE "U"
#define SMV_UNTIL_CLOSE "]"

// Pretty printing routines used to generate SMV

/* Print "next(VAR)" in front of NEXT. */
static pretty_t
pnext(const char *var, pretty_t next)
{
  pretty_t closing = pstring(NEXT_SUFFIX, next);
  pretty_t inner = pstring(var, closing);
  return pstring(NEXT_PREFIX, inner);
}

// Type of a symbol printer: given a symbol and the pretty_t that should
// follow it, return a pretty_t that prints the symbol (e.g. anno_cls,
// anno_type) and then the follower.
typedef pretty_t (*psym_t)(symbol_t, pretty_t);

// Print a list of symbols using set notation.
//
// Helper for pset: prints SYM and then every symbol of LIST with PSYM,
// separated by "," plus a breakable space, and finishes with the
// closing "}".  The opening "{" is supplied by pset.
static pretty_t
pset_aux(psym_t psym, symbol_t sym, symbol_list_t list)
{
  pretty_t rest = list
    ? pstring(",", pbreak(1, pset_aux(psym, symbol_list_head(list),
                                      symbol_list_tail(list))))
    : pstring("}", PPRINT_NULL);
  return psym(sym, rest);
}

/* Print LIST as an SMV set ("{a, b, c}") in front of NEXT, using PSYM
   for each member.  An empty list prints as "{}"; a non-empty one is
   wrapped in a pblock with indentation 1. */
static pretty_t
pset(psym_t psym, symbol_list_t list, pretty_t next)
{
  pretty_t members;

  if (!list)
    return pstring("{}", next);

  members = pset_aux(psym, symbol_list_head(list), symbol_list_tail(list));
  return pblock(1, pstring("{", members), next);
}

// Symbol printers

/* Return nonzero when STRING ends with SUFFIX (an empty SUFFIX always
   matches).  The lengths are kept as size_t, the type strlen returns,
   so very long strings cannot truncate or overflow an int. */
static int
has_suffix(const char *string, const char *suffix)
{
  size_t glen = strlen(string);
  size_t xlen = strlen(suffix);

  if (xlen > glen)
    return 0;
  return strcmp(string + (glen - xlen), suffix) == 0;
}

/* Print a class symbol as its name with "_c" appended. */
static pretty_t
anno_cls(symbol_t sym, pretty_t next)
{
  const char *label = symbol_name(sym);
  pretty_t suffix = pstring("_c", next);
  return pstring(label, suffix);
}

/* Print a permission symbol as its name with "_p" appended. */
static pretty_t
anno_perm(symbol_t sym, pretty_t next)
{
  const char *label = symbol_name(sym);
  pretty_t suffix = pstring("_p", next);
  return pstring(label, suffix);
}

/* Print a type symbol verbatim.  Type names must already end in "_t";
   any other name is rejected through fatal. */
static pretty_t
anno_type(symbol_t sym, pretty_t next)
{
  const char *type_name = symbol_name(sym);

  if (!has_suffix(type_name, "_t")) {
    fatal("bad type name: %s\n", type_name);
    return PPRINT_NULL;		/* not reached: fatal longjmps away */
  }
  return pstring(type_name, next);
}

/* Print a role symbol verbatim.  Role names must already end in "_r";
   any other name is rejected through fatal. */
static pretty_t
anno_role(symbol_t sym, pretty_t next)
{
  const char *role_name = symbol_name(sym);

  if (!has_suffix(role_name, "_r")) {
    fatal("bad role name: %s\n", role_name);
    return PPRINT_NULL;		/* not reached: fatal longjmps away */
  }
  return pstring(role_name, next);
}

/* Print a user symbol, appending "_u" when the name lacks that suffix.
   This is dangerous: distinct users "x" and "x_u" would both be printed
   as "x_u". */
static pretty_t
anno_user(symbol_t sym, pretty_t next)
{
  const char *user_name = symbol_name(sym);

  if (!has_suffix(user_name, "_u"))
    next = pstring("_u", next);
  return pstring(user_name, next);
}

/* Print the membership test "VAR is one of LIST" (its negation when NEG)
   in front of NEXT, using PSYM for the members:
     empty list      FALSE                 (TRUE when NEG)
     one member      VAR = m               (VAR != m when NEG)
     several         VAR in {m1, m2, ...}  (!(VAR in {...}) when NEG) */
static pretty_t
pin(psym_t psym, const char *var, symbol_list_t list, int neg, pretty_t next)
{
  pretty_t tail;

  if (!list)
    return pstring(neg ? SMV_TRUE : SMV_FALSE, next);

  if (!symbol_list_tail(list)) {
    tail = psym(symbol_list_head(list), next);
    tail = pstring(neg ? " " SMV_NEQ " " : " " SMV_EQ " ", tail);
    return pstring(var, tail);
  }

  tail = neg ? pstring(")", next) : next;
  tail = pstring(" " SMV_IN " ", pset(psym, list, tail));
  tail = pstring(var, tail);
  return neg ? pstring(SMV_NOT "(", tail) : tail;
}

/* Print "VAR = next(VAR)" (or "VAR != next(VAR)" when NEG) in front of
   NEXT. */
static pretty_t
psame(const char *var, int neg, pretty_t next)
{
  const char *op = neg ? " " SMV_NEQ " " : " " SMV_EQ " ";
  pretty_t rhs = pnext(var, next);
  return pstring(var, pstring(op, rhs));
}

// Precedence levels, from weakest to tightest binding.  MIN_PREC is the
// level of a context that never needs parentheses, and MAX_PREC is what
// prec() assigns to every node that is not an operator.
typedef enum {MIN_PREC, IMPLY, IFF, OR, AND, NOT, NEXT, MAX_PREC} prec_t;

/* Binding strength of TRAN's top-level operator when printed as SMV.
   The prefix temporal operators (AX, EX) and E[ .. U .. ] share the
   NEXT level.  Constants, atoms and the okay variable get MAX_PREC, so
   they are never parenthesized. */
static prec_t
prec(tran_t tran)
{
  switch (tran_type(tran)) {
  case TRAN_AX:			/* printed as a prefix operator, like EX */
  case TRAN_NEXT:
  case TRAN_UNTIL:
    return NEXT;
  case TRAN_NOT:
    return NOT;
  case TRAN_AND:
    return AND;
  case TRAN_OR:
    return OR;
  case TRAN_IFF:
    return IFF;
  case TRAN_IMPLY:
    return IMPLY;
  default:
    return MAX_PREC;
  }
}

// Associativity of a binary operator.  ASSOC means fully associative
// (either operand may repeat the operator without parentheses).  LEFT
// and RIGHT also name the operand side when passed to arg_prec.
typedef enum {ASSOC, LEFT, RIGHT, NON_ASSOC} assoc_t;

/* Associativity used when printing TRAN's operator: AND and OR are
   fully associative, IMPLY groups to the right, IFF to the left, and
   everything else is treated as non-associative. */
static assoc_t
assoc(tran_t tran)
{
  assoc_t result = NON_ASSOC;

  switch (tran_type(tran)) {
  case TRAN_IFF:
    result = LEFT;
    break;
  case TRAN_IMPLY:
    result = RIGHT;
    break;
  case TRAN_OR:
  case TRAN_AND:
    result = ASSOC;
    break;
  default:
    break;
  }
  return result;
}

/* Level at which one operand of a binary operator must be printed.
   SIDE is the operand's side (LEFT or RIGHT).  OP_ASSOC and OP_PREC
   describe the operator.  An operand on the side the operator groups
   toward, or either operand of a fully associative operator, may be
   printed at the operator's own level.  The other operand needs one
   level tighter, so an equal-precedence operator there gets
   parentheses.  The parameters are deliberately not named assoc/prec,
   to avoid shadowing the functions of those names. */
static prec_t
arg_prec(assoc_t side, assoc_t op_assoc, prec_t op_prec)
{
  if (side == op_assoc || op_assoc == ASSOC)
    return op_prec;
  return (prec_t) (op_prec + 1);
}

/* Forward declaration: ptran and the operator printers below
   (pbinary, paxop, pnextop, puntil) call each other recursively. */
static pretty_t
ptran(tran_t tran, int neg, int level, pretty_t next);

/* Print the binary node TRAN as "arg1 OP arg2" in front of NEXT, with a
   breakable space before OP (OP carries its own trailing space).  Each
   operand is printed at the level arg_prec derives from OP_PREC and the
   operator's associativity. */
static pretty_t
pbinary(tran_t tran, prec_t op_prec, const char *op, pretty_t next)
{
  assoc_t grouping = assoc(tran);
  pretty_t right = ptran(tran_arg2(tran), 0,
                         arg_prec(RIGHT, grouping, op_prec), next);
  pretty_t tail = pbreak(1, pstring(op, right));
  return ptran(tran_arg1(tran), 0, arg_prec(LEFT, grouping, op_prec), tail);
}

/* Print "AX operand" (or "!AX operand" when NEG) as an indented block in
   front of NEXT.  The operand is printed at level NEXT, so anything
   binding more weakly than a temporal prefix gets parentheses. */
static pretty_t
paxop(tran_t tran, int neg, pretty_t next)
{
  const char *prefix = neg ? SMV_NOT SMV_AX " " : SMV_AX " ";
  pretty_t operand = ptran(tran, 0, NEXT, PPRINT_NULL);
  return pblock(INDENT, pstring(prefix, operand), next);
}

/* Print "EX operand" (or "!EX operand" when NEG) as an indented block in
   front of NEXT.  The operand is printed at level NEXT, so anything
   binding more weakly than a temporal prefix gets parentheses. */
static pretty_t
pnextop(tran_t tran, int neg, pretty_t next)
{
  const char *prefix = neg ? SMV_NOT SMV_NEXT " " : SMV_NEXT " ";
  pretty_t operand = ptran(tran, 0, NEXT, PPRINT_NULL);
  return pblock(INDENT, pstring(prefix, operand), next);
}

/* Print "E[arg1 U arg2]" (or "!E[...]" when NEG) in front of NEXT.  The
   brackets delimit both operands, so each is printed at MIN_PREC.  The
   "U arg2]" part is its own block and may start on a new line. */
static pretty_t
puntil(tran_t arg1, tran_t arg2, int neg, pretty_t next)
{
  const char *opener = neg ? SMV_NOT SMV_UNTIL_OPEN : SMV_UNTIL_OPEN;
  pretty_t rhs = ptran(arg2, 0, MIN_PREC,
                       pstring(SMV_UNTIL_CLOSE, PPRINT_NULL));
  pretty_t rhs_block = pblock(INDENT, pstring(SMV_UNTIL_MIDDLE " ", rhs),
                              PPRINT_NULL);
  pretty_t body = ptran(arg1, 0, MIN_PREC, pbreak(1, rhs_block));
  return pblock(2 * INDENT, pstring(opener, body), next);
}

/* Print the formula TRAN as SMV text in front of NEXT.

   LEVEL is the weakest precedence TRAN may have and still appear
   without parentheses.  When TRAN binds more weakly than LEVEL, it is
   printed as "( ... )" -- or "!( ... )" when NEG is set -- with the
   contents printed at MIN_PREC.

   NEG asks for the negation of TRAN.  Within this file it is nonzero
   only for the operand of a TRAN_NOT.  The negation is folded into
   constants, set-membership atoms, the okay variable and the temporal
   operators instead of being printed as a separate "!".  A negated
   binary operator always arrives at level NOT, which is above every
   binary level, so it takes the parenthesized branch.  Hence the
   "assert neg is zero" notes on the binary cases. */
static pretty_t
ptran(tran_t tran, int neg, int level, pretty_t next)
{
  int tran_prec = prec(tran);
  if (level > tran_prec) {
    /* Too weak for this context: parenthesize and restart at MIN_PREC. */
    pretty_t p = pstring(" )", PPRINT_NULL);
    p = ptran(tran, 0, MIN_PREC, p);
    if (neg)
      p = pstring(SMV_NOT "(", p);
    else
      p = pstring("( ", p);
    return pblock(INDENT, p, next);
  }
  else {
    switch (tran_type(tran)) {
    /* Constants: negation just swaps them. */
    case TRAN_TRUE:
      return neg ? pstring(SMV_FALSE, next) : pstring(SMV_TRUE, next);
    case TRAN_FALSE:
      return neg ? pstring(SMV_TRUE, next) : pstring(SMV_FALSE, next);
    /* Set-membership atoms over current- and next-state variables. */
    case TRAN_CLASSES:
      return pin(anno_cls, CLASS_VAR, tran_list(tran), neg, next);
    case TRAN_PERMISSIONS:
      return pin(anno_perm, PERMISSION_VAR, tran_list(tran), neg, next);
    case TRAN_TYPES:
      return pin(anno_type, TYPE_VAR, tran_list(tran), neg, next);
    case NEXT_TYPES:
      return pin(anno_type, NEXT_TYPE_VAR, tran_list(tran), neg, next);
    /* SAME_* prints "var = next(var)". */
    case SAME_TYPES:
      return psame(TYPE_VAR, neg, next);
    case TRAN_ROLES:
      return pin(anno_role, ROLE_VAR, tran_list(tran), neg, next);
    case NEXT_ROLES:
      return pin(anno_role, NEXT_ROLE_VAR, tran_list(tran), neg, next);
    case SAME_ROLES:
      return psame(ROLE_VAR, neg, next);
    case TRAN_USERS:
      return pin(anno_user, USER_VAR, tran_list(tran), neg, next);
    case NEXT_USERS:
      return pin(anno_user, NEXT_USER_VAR, tran_list(tran), neg, next);
    case SAME_USERS:
      return psame(USER_VAR, neg, next);
    /* Negation: flip NEG and print the operand at NOT level, so double
       negations cancel without printing anything. */
    case TRAN_NOT:
      return ptran(tran_arg1(tran), !neg, tran_prec, next);
    case TRAN_AND:
      /* assert neg is zero */
      return pbinary(tran, tran_prec, SMV_AND " ", next);
    case TRAN_OR:
      /* assert neg is zero */
      return pbinary(tran, tran_prec, SMV_OR " ", next);
    case TRAN_IMPLY:
      /* assert neg is zero */
      return pbinary(tran, tran_prec, SMV_IMPLY " ", next);
    case TRAN_IFF:
      /* assert neg is zero */
      return pbinary(tran, tran_prec, SMV_IFF " ", next);
    /* The boolean okay state variable. */
    case TRAN_OKAY:
      if (neg)
	return pstring(SMV_NOT OKAY_VAR, next);
      else
	return pstring(OKAY_VAR, next);
    /* Temporal operators print their own "!" prefix when negated. */
    case TRAN_AX:
      return paxop(tran_arg1(tran), neg, next);
    case TRAN_NEXT:
      return pnextop(tran_arg1(tran), neg, next);
    case TRAN_UNTIL:
      return puntil(tran_arg1(tran), tran_arg2(tran), neg, next);
    default:
      /* NOTE(review): the message says "ptrans" but this function is
         ptran. */
      fatal("Internal error in ptrans: %s line %d\n", __FILE__, __LINE__);
      return 0;
    }
  }
}

/* Print the declaration of one enumerated state variable,
     " VAR var: {m1, m2, ...};"
   whose values are the symbols of LIST printed with PSYM, followed by a
   newline. */
static void
emit_state_var(const char* var,  psym_t psym, symbol_list_t list)
{
  pretty_t values = pset(psym, list, pstring(";", PPRINT_NULL));
  pretty_t decl = pstring(" VAR ",
                          pstring(var,
                                  pstring(":", pbreak(1, values))));

  pp(pblock(INDENT, decl, PPRINT_NULL));
  putchar('\n');
}

/* Print the "MODULE main" header, one enumerated VAR declaration for
   each symbol domain of LTS (types, roles, users, classes,
   permissions), and the boolean okay variable. */
static void
emit_decls(lts_t lts)
{
  symbol_list_t type_syms = lts_types(lts);
  symbol_list_t role_syms = lts_roles(lts);
  symbol_list_t user_syms = lts_users(lts);
  symbol_list_t class_syms = lts_classes(lts);
  symbol_list_t perm_syms = lts_permissions(lts);

  fputs("MODULE main\n\n", stdout);

  emit_state_var(TYPE_VAR, anno_type, type_syms);
  emit_state_var(ROLE_VAR, anno_role, role_syms);
  emit_state_var(USER_VAR, anno_user, user_syms);
  emit_state_var(CLASS_VAR, anno_cls, class_syms);
  emit_state_var(PERMISSION_VAR, anno_perm, perm_syms);

  printf(" VAR %s: boolean;\n", OKAY_VAR);
}

/* Print TITLE as an SMV "--" comment line, preceded by a blank line. */
static void
emit_comment(const char *title)
{
  fprintf(stdout, "\n-- %s\n", title);
}

/* Print an SMV section keyword (INIT, TRANS, SPEC) on its own line,
   indented by one space, with blank lines before and after it. */
static void
emit_section(const char *section)
{
  fprintf(stdout, "\n %s\n\n", section);
}

/* Print the INIT section "k -> init", where init is the initial-state
   formula of LTS.  It is printed at level IMPLY, which suffices because
   -> groups to the right. */
static void
emit_init(lts_t lts)
{
  emit_section("INIT");

  pretty_t body = ptran(lts_initial(lts), 0, IMPLY, PPRINT_NULL);
  body = pstring(OKAY_VAR " " SMV_IMPLY, pbreak(1, body));

  pp(pblock(INDENT, pspace(INDENT, body), PPRINT_NULL));
  putchar('\n');
}

/* Print the TRANS section "next(k) <-> k & trans", where trans is the
   transition formula of LTS.  Since & binds tighter than <->, this reads
   next(k) <-> (k & trans).  trans is printed at level AND, so only
   weaker operators inside it get parentheses. */
static void
emit_trans(lts_t lts)
{
  emit_section("TRANS");

  pretty_t body = ptran(lts_transition(lts), 0, AND, PPRINT_NULL);
  body = pstring(SMV_AND " ", body);
  body = pbreak(1, body);
  body = pstring(OKAY_VAR, body);
  body = pbreak(1, body);
  body = pstring(NEXT_OKAY_VAR " " SMV_IFF, body);

  pp(pblock(INDENT, pspace(INDENT, body), PPRINT_NULL));
  putchar('\n');
}

/* spec translation */

/* Convert a next-state expression into a current-state expression.
   NEXT_TYPES/NEXT_ROLES/NEXT_USERS become the corresponding current-state
   predicates.  Constants and class/permission predicates are returned
   unchanged.  The boolean connectives are rebuilt around converted
   operands, simplifying through con_tran_* (IFF is rebuilt without
   simplification).  Any other node kind is reported through fatal. */
static tran_t
unnext(tran_t tran)
{
  tran_t left, right;

  switch (tran_type(tran)) {
  case NEXT_TYPES:
    return mk_tran_types(tran_list(tran));
  case NEXT_ROLES:
    return mk_tran_roles(tran_list(tran));
  case NEXT_USERS:
    return mk_tran_users(tran_list(tran));
  case TRAN_TRUE:
  case TRAN_FALSE:
  case TRAN_CLASSES:
  case TRAN_PERMISSIONS:
    return tran;
  case TRAN_NOT:
    return con_tran_not(unnext(tran_arg1(tran)));
  case TRAN_AND:
    left = unnext(tran_arg1(tran));
    right = unnext(tran_arg2(tran));
    return con_tran_and(left, right);
  case TRAN_OR:
    left = unnext(tran_arg1(tran));
    right = unnext(tran_arg2(tran));
    return con_tran_or(left, right);
  case TRAN_IMPLY:
    left = unnext(tran_arg1(tran));
    right = unnext(tran_arg2(tran));
    return con_tran_imply(left, right);
  case TRAN_IFF:
    left = unnext(tran_arg1(tran));
    right = unnext(tran_arg2(tran));
    return mk_tran_iff(left, right);
  default:
    break;
  }

  fatal("Internal error in unnext: %s line %d\n", __FILE__, __LINE__);
  return 0;
}

/* Emit one SPEC section for a specification TRAN of the form
   PRE -> POST.  POST is converted to current-state form with unnext,
   and the printed property is

     PRE -> AX (k -> POST')

   with constant operands folded by the con_tran_* helpers.  Any other
   top-level form is rejected through fatal.  The shape is checked
   before anything is printed, so a rejected spec leaves no dangling
   "SPEC" header on stdout. */
static void
emit_spec(tran_t tran)
{
  if (tran_type(tran) != TRAN_IMPLY)
    fatal("bad spec form: not an implication\n");

  emit_section("SPEC");
  tran_t spec = unnext(tran_arg2(tran));
  spec = con_tran_imply(mk_tran_okay(), spec);
  spec = mk_tran_ax(spec);
  spec = con_tran_imply(tran_arg1(tran), spec);
  pretty_t p = ptran(spec, 0, MIN_PREC, PPRINT_NULL);
  p = pblock(INDENT, pspace(INDENT, p), PPRINT_NULL);
  pp(p);
  printf("\n");
}

/* Emit the SMV model for LTS: declarations, INIT, TRANS, then one SPEC
   per specification, under a "Specifications" comment when there is at
   least one.  A null LTS emits nothing.  Always returns 0; errors abort
   through fatal instead. */
static int
from_lts(lts_t lts)
{
  tran_list_t pending;

  if (!lts)
    return 0;

  emit_decls(lts);
  emit_comment("Initial States");
  emit_init(lts);
  emit_comment("Transition Relation");
  emit_trans(lts);

  pending = lts_specifications(lts);
  if (pending)
    emit_comment("Specifications");
  for (; pending; pending = tran_list_tail(pending))
    emit_spec(tran_list_head(pending));

  return 0;
}

/* diagrams */

/* Exception formula of DIAG: the disjunction of its except-state and
   except-action formulas, whichever of them are present, or FALSE when
   it has neither. */
static tran_t
diag_ex(diagram_t diag)
{
  tran_t state_ex = diagram_except_state(diag);
  tran_t action_ex = diagram_except_action(diag);

  if (!state_ex)
    return action_ex ? action_ex : mk_tran_false();
  if (!action_ex)
    return state_ex;
  return con_tran_or(state_ex, action_ex);
}

/* Number of arrows in the list ARROWS (0 for an empty list). */
static size_t
arrow_list_length(arrow_list_t arrows)
{
  size_t count = 0;

  while (arrows) {
    count++;
    arrows = arrow_list_tail(arrows);
  }
  return count;
}

/* The key action assertion specification function.

   Builds the formula for arrow number I of a diagram.  ARROWS is the
   remaining arrow list, whose head is arrow number J.  The caller,
   action_assertions, starts with J = 0 and the full list.  EX is the
   diagram's exception formula (diag_ex) and LAST its final state formula
   (diagram_state).  Writing a' for "action & !ex", the result is built
   as follows, with constants folded by the con_tran_* helpers.

   For an arrow before the I-th one (J < I):

     state & a' & EX f

   where f is the formula for the rest of the arrows, wrapped as
   E[a' U f] when arrow_more holds for this arrow.

   At the I-th arrow (J == I):

     state & ((!action & g) | (a' & EX h))

   where g = E[!ex U (LAST & k)], nxt is the next arrow's state (or LAST
   when this is the final arrow), and h = !nxt & g.  When arrow_more
   holds, h instead becomes E[!(nxt | ex) U (!action & !nxt & g)].

   NOTE(review): assumes I is less than the length of the initial list,
   so ARROWS never runs out before J reaches I.  action_assertions only
   passes such I. */
static tran_t
action_spec(size_t i, size_t j, arrow_list_t arrows, tran_t ex, tran_t last)
{
  arrow_t arrow = arrow_list_head(arrows);
  tran_t state = arrow_state(arrow);
  tran_t action = arrow_action(arrow);
  /* a' = the arrow's action taken while no exception holds */
  tran_t action_bar = con_tran_and(action, con_tran_not(ex));
  if (i <= j) {
    /* g: exception-free until the final state is reached with k set */
    tran_t g = con_tran_and(last, mk_tran_okay());
    g = mk_tran_until(con_tran_not(ex), g);
    /* state that follows this arrow */
    tran_t next;
    if (arrow_list_tail(arrows))
      next = arrow_state(arrow_list_head(arrow_list_tail(arrows)));
    else
      next = last;
    tran_t h = con_tran_and(con_tran_not(next), g);
    if (arrow_more(arrow)) {
      h = con_tran_and(con_tran_not(action), h);
      h = mk_tran_until(con_tran_not(con_tran_or(next, ex)), h);
    }
    tran_t f = mk_tran_next(h);
    f = con_tran_and(action_bar, f);
    f = con_tran_or(con_tran_and(con_tran_not(action), g), f);
    return con_tran_and(state, f);
  }
  else {
    /* Earlier arrow: take its action, then continue with arrow J + 1. */
    tran_t f = action_spec(i, j + 1, arrow_list_tail(arrows), ex, last);
    if (arrow_more(arrow))
      f = mk_tran_until(action_bar, f);
    f = mk_tran_next(f);
    f = con_tran_and(action_bar, f);
    return con_tran_and(state, f);
  }
}

/* Emit one SPEC per arrow of DIAG.  The SPEC for arrow k is the negation
   of action_spec(k, 0, ...) over the diagram's arrows, exception formula
   and final state. */
static void
action_assertions(diagram_t diag)
{
  arrow_list_t arrows = diagram_arrows(diag);
  tran_t ex = diag_ex(diag);
  tran_t last = diagram_state(diag);
  size_t count = arrow_list_length(arrows);
  size_t k = 0;

  while (k < count) {
    tran_t assertion = con_tran_not(action_spec(k, 0, arrows, ex, last));
    pretty_t doc;

    emit_section("SPEC");
    doc = ptran(assertion, 0, MIN_PREC, PPRINT_NULL);
    pp(pblock(INDENT, pspace(INDENT, doc), PPRINT_NULL));
    putchar('\n');
    k++;
  }
}

/* The key order assertion specification function.  It builds

     !(FIRST & E[!(ex | STATE) U (NEXT & E[!ex U (LAST & k)])])

   with constant operands folded by the con_tran_* helpers. */
static tran_t
order_spec(tran_t ex, tran_t first, tran_t state, tran_t next, tran_t last)
{
  tran_t goal = con_tran_and(last, mk_tran_okay());
  tran_t reach_last = mk_tran_until(con_tran_not(ex), goal);
  tran_t step = con_tran_and(next, reach_last);
  tran_t reach_next = mk_tran_until(con_tran_not(con_tran_or(ex, state)),
                                    step);
  return con_tran_not(con_tran_and(first, reach_next));
}

static void
order_assertions(diagram_t diag)
{
  arrow_list_t arrows = diagram_arrows(diag);
  if (!arrows)			/* Too short for order assertions */
    return;
  tran_t ex = diag_ex(diag);
  tran_t first = arrow_state(arrow_list_head(arrows));
  tran_t last = diagram_state(diag);
  arrows = arrow_list_tail(arrows);

  while (arrows) {
    tran_t state = arrow_state(arrow_list_head(arrows));
    arrows = arrow_list_tail(arrows);
    tran_t next;
    if (arrows)
      next = arrow_state(arrow_list_head(arrows));
    else
      next = last;
    tran_t a = order_spec(ex, first, state, next, last);
    emit_section("SPEC");
    pretty_t p = ptran(a, 0, MIN_PREC, PPRINT_NULL);
    p = pblock(INDENT, pspace(INDENT, p), PPRINT_NULL);
    pp(p);
    printf("\n");
  }
}

static void
from_diag(diagram_t diag)
{
  emit_comment("Event Assertions");
  action_assertions(diag);
  emit_comment("Order Assertions");
  order_assertions(diag);
}

static int
from_diags(diagram_list_t diagrams)
{
  for (; diagrams; diagrams = diagram_list_tail(diagrams))
    from_diag(diagram_list_head(diagrams));
  return 0;
}

int
gensmv(lts_t lts, diagram_list_t diagrams)
{
  if (setjmp(env))
    return 1;
  pprint_init();
  return from_lts(lts) || from_diags(diagrams);
}
