/**
   @title  libtrackmem.cpp
   @author Emery Berger <http://www.cs.umass.edu/~emery>
   @brief  A shim to track memory consumption.

 */


#if defined(_WIN32)
#error "Windows not currently supported."
#endif

#include <dlfcn.h>
#include <sys/mman.h>
#include <unistd.h>
#include <stdlib.h>
#include <stdio.h>

/**
 * Running tally of memory obtained through the interposed allocation
 * entry points.  currMem is the amount currently attributed to the
 * process; maxMem is the high-water mark of currMem.  The destructor
 * reports both on stderr.
 */
class OutputResults {
public:
  OutputResults (void)
    : currMem (0),
      maxMem (0)
    {}

  /// Prints the final tally to stderr.
  ~OutputResults (void) {
    fprintf (stderr, "Yee hah!\n");
    // Print directly with a literal format: %zu matches size_t (the old %d
    // did not), and there is no intermediate fixed-size buffer to overflow
    // or to be misinterpreted as a format string.
    fprintf (stderr, "Current memory in use = %zu, max = %zu\n", currMem, maxMem);
  }

  /// Records sz newly acquired bytes and updates the high-water mark.
  void add (size_t sz) {
    currMem += sz;
    if (currMem > maxMem) {
      maxMem = currMem;
    }
  }

  size_t currMem;  ///< Bytes currently tracked as in use.
  size_t maxMem;   ///< Largest value currMem has reached.
};

// Global tally updated by the interposed mmap/munmap wrappers below; its
// destructor prints the totals to stderr when global objects are destroyed.
OutputResults stats;

// NOTE(review): v is not referenced anywhere in this file as visible here;
// confirm whether anything else uses it before removing.
int v;

extern "C" {

  // Signatures of the real entry points wrapped by this shim.  Each typedef
  // is already a function-pointer type, so dlsym's result is cast directly
  // to it (not to a pointer to it).
  typedef void * (*mmapFunction) (void *, size_t, int, int, int, off_t);
  typedef int (*munmapFunction) (void *, size_t);
  typedef int (*brkFunction) (void *);
  typedef void * (*sbrkFunction) (intptr_t);
  typedef void (*exitFunction) (int);

  // Subtracts sz bytes from the running total, clamping at zero so that
  // releasing memory this shim never counted (e.g. mappings made before the
  // shim was loaded) cannot wrap the unsigned counter around.
  static void releaseMem (size_t sz) {
    stats.currMem = (sz < stats.currMem) ? stats.currMem - sz : 0;
  }

  /**
   * Interposed mmap: forwards to the next mmap found by the dynamic linker
   * and, if the call succeeds, adds len bytes to the tracked total.
   */
  void * mmap (void * addr, size_t len, int prot, int flags, int fildes, off_t off) {
    static mmapFunction realfn = (mmapFunction) dlsym (RTLD_NEXT, "mmap");
    if (realfn == NULL) {
      // No underlying mmap to forward to; fail loudly instead of jumping
      // through a null pointer.
      abort ();
    }
    void * result = realfn (addr, len, prot, flags, fildes, off);
    // Compare against MAP_FAILED: casting a pointer to int truncates it and
    // is rejected outright on LP64 targets.
    if (result != MAP_FAILED) {
      stats.add (len);
    }
    return result;
  }

  /**
   * Interposed munmap: forwards to the next munmap found by the dynamic
   * linker and, if the call succeeds, removes len bytes from the total.
   */
  int munmap (void * addr, size_t len) {
    static munmapFunction realfn = (munmapFunction) dlsym (RTLD_NEXT, "munmap");
    if (realfn == NULL) {
      abort ();
    }
    int result = realfn (addr, len);
    if (result != -1) {
      releaseMem (len);
    }
    return result;
  }

  // NOTE(review): sbrk/brk interposition is compiled out; the reason is not
  // recorded here.  The bodies below are kept correct so they can be
  // re-enabled.
#if 0
  void * sbrk (intptr_t incr) {
    static sbrkFunction realfn = (sbrkFunction) dlsym (RTLD_NEXT, "sbrk");
    if (incr == 0) {
      // Query of the current break: nothing is allocated or released.
      return realfn (0);
    }
    void * result = realfn (incr);
    if (result != (void *) -1) {
      if (incr > 0) {
        stats.add ((size_t) incr);
      } else {
        // A negative increment shrinks the heap.
        releaseMem ((size_t) -incr);
      }
    }
    return result;
  }


  int brk (void * endds) {
    static brkFunction realfn = (brkFunction) dlsym (RTLD_NEXT, "brk");
    static void * currentBreakpoint = sbrk (0);
    int result = realfn (endds);
    if (result != -1) {
      size_t oldBrk = (size_t) currentBreakpoint;
      size_t newBrk = (size_t) endds;
      if (newBrk < oldBrk) {
        // Shrinking breakpoint.
        releaseMem (oldBrk - newBrk);
      } else {
        // Growing breakpoint: new minus old (the reverse underflows).
        stats.add (newBrk - oldBrk);
      }
      currentBreakpoint = endds;
    }
    return result;
  }
#endif

}


