#include "stdafx.h"
#include "Driver.h"
#include "Exception.h"
#include "Core/Convert.h"

#ifdef POSIX
#include <dlfcn.h>
#endif

// Fallback definition for SDK headers that do not provide this LoadLibraryEx flag (presumably
// older Windows SDKs). Used as the flag for the last-resort "system paths" lookup of libpq.dll.
#ifndef LOAD_LIBRARY_SEARCH_SYSTEM32
#define LOAD_LIBRARY_SEARCH_SYSTEM32 0x00000800
#endif

namespace sql {

	// Note: These are globals, as loading of dynamic libraries is truly a global concern. It is not
	// tied to different Engines.

	// Lock to protect the globals below. Taken by createPGDriver and destroyPGDriver.
	util::Lock pgDriverLock;

	// Global object that contains function pointers to the library. Filled in by loadPgDriver;
	// createPGDriver treats a non-null PQconnectdb as "the library is already loaded".
	PGDriver pgDriver;

	// Number of drivers handed out by createPGDriver and not yet released by destroyPGDriver.
	// NOTE(review): nothing in this file unloads the library based on this count yet.
	size_t pgInstances = 0;

	// Throw an SQLError reporting that the function 'name' could not be resolved in the loaded
	// PostgreSQL client library (libpq). 'name' is the narrow C name of the missing symbol.
	static void throwNameError(Engine &e, const char *name) {
		StrBuf *msg = new (e) StrBuf();
		*msg << S("Could not find the function '") << toWChar(e, name)->v << S("' in the libpq library.");
		throw new (e) SQLError(msg->toS());
	}

#if defined(WINDOWS)

	// MSVC "hack" to get the address of the module: the linker-provided __ImageBase symbol marks the
	// start of the image this code is linked into. loadPgDriver casts its address to an HMODULE to
	// look for libpq.dll next to the SQL library's own DLL.
	extern "C" IMAGE_DOS_HEADER __ImageBase;

	// Try to load the DLL 'name' from the directory containing the module 'relative' (NULL means
	// the executable of the current process). Returns the module handle, or NULL if the directory
	// could not be determined, the file does not exist there, or loading it failed. Explanations
	// for a few common load failures are appended to 'error'.
	static HMODULE tryLoad(HMODULE relative, const wchar_t *name, StrBuf *error) {
		// Get the full path of 'relative'. When the buffer is too small, GetModuleFileName truncates
		// the path and returns the buffer size, so grow the buffer and retry until the path fits.
		std::vector<wchar_t> buffer(MAX_PATH + 1, 0);
		size_t length = 0;
		while (true) {
			DWORD written = GetModuleFileName(relative, &buffer[0], DWORD(buffer.size()));
			if (written == 0)
				return NULL; // Could not determine the module path at all.
			if (written < buffer.size()) {
				length = written;
				break;
			}
			buffer.resize(buffer.size() * 2);
		}

		// Find the last path separator; everything up to and including it is the directory.
		size_t pathEnd = length;
		while (pathEnd > 0 && buffer[pathEnd - 1] != '\\' && buffer[pathEnd - 1] != '/')
			pathEnd--;

		// Replace the file name with 'name'.
		size_t nameLength = wcslen(name);
		buffer.resize(pathEnd + nameLength + 1);
		for (size_t j = 0; j < nameLength; j++)
			buffer[pathEnd + j] = name[j];
		buffer[pathEnd + nameLength] = 0;

		// Does the file exist at all?
		if (GetFileAttributes(&buffer[0]) == INVALID_FILE_ATTRIBUTES)
			return NULL;

		// The LOAD_WITH_ALTERED_SEARCH_PATH is important to ensure that the dependent dlls are
		// found by the loader.
		HMODULE result = LoadLibraryEx(&buffer[0], NULL, LOAD_WITH_ALTERED_SEARCH_PATH);
		if (!result) {
			// Only consult the error code on failure: successful calls do not necessarily reset it.
			DWORD lastError = GetLastError();
			if (lastError == ERROR_BAD_EXE_FORMAT) {
				*error << S("\nThe file ") << &buffer[0] << S(" is for an incompatible platform.");
			} else if (lastError == ERROR_MOD_NOT_FOUND) {
				*error << S("\nOne of the dependent DLLs for the file ") << &buffer[0] << S(" was not found.");
			}
		}
		return result;
	}

	// Locate and load libpq.dll, then resolve all functions listed in DriverFns.inc into the global
	// 'pgDriver'. Search order: next to the SQL library's own DLL, next to the executable, and
	// finally LoadLibraryEx with LOAD_LIBRARY_SEARCH_SYSTEM32. Called by createPGDriver while it
	// holds pgDriverLock.
	// Throws SQLError if the DLL cannot be loaded or a required function is missing. In the latter
	// case pgDriver.PQconnectdb is reset to NULL and the DLL is released, so the next
	// createPGDriver call starts over instead of using a half-filled driver.
	void loadPgDriver(Engine &e) {
		const wchar_t *name = L"libpq.dll";

		HMODULE moduleHandle = (HMODULE)&__ImageBase;

#if 0
		// Try to get the current DLL handle. Only needed if we are not compiling on MSVC.
		// NOTE(review): 'findDriverLib' is not defined in this file; point it at any function in
		// this module before enabling this code.
		if (!GetModuleHandleEx(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS
								| GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT,
								(LPWSTR)&findDriverLib, // ptr to some function in the module
								&moduleHandle)) {
			throw new (e) SQLError(new (e) Str(S("Could not find the path to the SQL lib DLL file.")));
		}
#endif

		StrBuf *errors = new (e) StrBuf();
		*errors << S("Could not find and load the PostgreSQL driver (libpq.dll).");

		// Relative to the SQL dll.
		HMODULE lib = tryLoad(moduleHandle, name, errors);
		if (!lib)
			// Relative to the Storm.exe.
			lib = tryLoad(NULL, name, errors);

		if (!lib)
			// In system paths.
			lib = LoadLibraryEx(name, NULL, LOAD_LIBRARY_SEARCH_SYSTEM32);

		if (!lib)
			throw new (e) SQLError(errors->toS());

		// Resolve all functions. A missing required function is remembered instead of being
		// reported right away: throwing halfway through could leave PQconnectdb set (which
		// createPGDriver treats as "already loaded") while other required entries are still null.
		const char *missing = NULL;

#define PG_OPTIONAL_FUNCTION(name, result, ...)							\
		pgDriver.name = (PGDriver::Ptr ## name)GetProcAddress(lib, #name);

#define PG_FUNCTION(name, result, ...)									\
		PG_OPTIONAL_FUNCTION(name, result, __VA_ARGS__)					\
		if (!pgDriver.name && !missing)									\
			missing = #name;

#include "DriverFns.inc"

#undef PG_FUNCTION
#undef PG_OPTIONAL_FUNCTION

		if (missing) {
			// Mark the driver as not loaded so that the next call starts over, and release our
			// reference to the DLL.
			pgDriver.PQconnectdb = NULL;
			FreeLibrary(lib);
			throwNameError(e, missing);
		}
	}


#elif defined(POSIX)

	// Load the PostgreSQL client library (libpq) through dlopen, trying the names in 'names' in
	// order, and resolve all functions listed in DriverFns.inc into the global 'pgDriver'. Called
	// by createPGDriver while it holds pgDriverLock.
	// Throws SQLError if no library could be loaded or a required function is missing. In the
	// latter case pgDriver.PQconnectdb is reset to null and the library handle is closed, so the
	// next createPGDriver call starts over instead of using a half-filled driver.
	void loadPgDriver(Engine &e) {
		const char *names[] = {
			"libpq.so.5",
			"libpq.so",
			null,
		};

		void *loaded = null;
		for (const char **name = names; loaded == null && *name != null; name++) {
			// Note: This respects the rpath of the library if we set it.
			loaded = dlopen(*name, RTLD_NOW | RTLD_LOCAL);
		}

		if (!loaded) {
			StrBuf *msg = new (e) StrBuf();
			*msg << S("Failed to load the PostgreSQL library (libpq). Is it installed?\n");
			*msg << S("Searched for the following names:");
			for (const char **name = names; *name != null; name++) {
				*msg << S("\n") << toWChar(e, *name)->v;
			}

			throw new (e) SQLError(msg->toS());
		}

		// Resolve all functions. A missing required function is remembered instead of being
		// reported right away: throwing halfway through could leave PQconnectdb set (which
		// createPGDriver treats as "already loaded") while other required entries are still null.
		const char *missing = null;

#define PG_OPTIONAL_FUNCTION(name, result, ...)							\
		pgDriver.name = (PGDriver::Ptr ## name)dlsym(loaded, #name);
#define PG_FUNCTION(name, result, ...)									\
		PG_OPTIONAL_FUNCTION(name, result, __VA_ARGS__)					\
		if (!pgDriver.name && !missing)									\
			missing = #name;

#include "DriverFns.inc"

#undef PG_FUNCTION
#undef PG_OPTIONAL_FUNCTION

		if (missing) {
			// Mark the driver as not loaded so that the next call starts over, and release our
			// handle to the library.
			pgDriver.PQconnectdb = null;
			dlclose(loaded);
			throwNameError(e, missing);
		}
	}


#endif

	// Return the process-wide PostgreSQL driver, loading libpq on first use, and register one more
	// user of it. Any exception thrown by loadPgDriver propagates to the caller. The whole call is
	// serialized through pgDriverLock.
	const PGDriver *createPGDriver(Engine &e) {
		util::Lock::L guard(pgDriverLock);

		// A null PQconnectdb marks the library as not (yet) loaded.
		const bool needsLoad = !pgDriver.PQconnectdb;
		if (needsLoad)
			loadPgDriver(e);

		++pgInstances;
		return &pgDriver;
	}

	// Release one user of the driver returned by createPGDriver. Only the instance count is
	// decremented (under pgDriverLock); the library itself stays loaded. 'driver' is not used.
	void destroyPGDriver(const PGDriver *driver) {
		(void)driver;

		util::Lock::L guard(pgDriverLock);
		--pgInstances;

		// TODO: Unload after a delay?
	}

}
