// Copyright (c) 2000, 2001, 2002, 2003 by David Scherer and others.
// See the file license.txt for complete license terms.
// See the file authors.txt for a complete list of contributors.

// Most functions are inlined in this header.

#include "vector.h"
#include "tmatrix.h"
#include "exceptions.h"

#include <cmath>
#include <istream>
#include <limits>
#include <ostream>
#include <sstream>
#include <stdexcept>

#include <boost/python/class.hpp>
#include <boost/python/def.hpp>
#include <boost/python/implicit.hpp>
#include <boost/python/operators.hpp>
#include <boost/python/init.hpp>
#include <boost/python/overloads.hpp>
#include <boost/python/return_value_policy.hpp>
#include <boost/python/copy_const_reference.hpp>
#include <boost/python/tuple.hpp>
#include <boost/python/to_python_converter.hpp>

#include <boost/python/numeric.hpp>
#include "num_util.h"

namespace visual {

using boost::python::numeric::array;
using boost::python::object;
// Operations on Numeric arrays
namespace {

/* Verify that `arr` can be treated as packed rows of three doubles.
 *
 * Accepts contiguous Float64 arrays of shape (N, 3), or the flat shape (3,),
 * which callers treat as a single vector.
 *
 * Throws std::invalid_argument when the element type is not Float64, when
 * the array is not contiguous, or when the shape is neither (N, 3) nor (3,).
 */
void
validate_array( const array& arr)
{
	std::vector<int> dims = shape(arr);
	if (type(arr) != double_t) {
		throw std::invalid_argument( "Array must be of type Float64.");
	}
	// Several callers walk the raw data() buffer three doubles at a time, so
	// strided views (such as slices) must be rejected.
	if (!iscontiguous(arr)) {
		throw std::invalid_argument( "Array must be contiguous. "
			"(Did you pass a slice?)");
	}
	if (dims.size() == 1) {
		// A flat 3-element array is accepted as a single vector.
		if (dims[0] == 3)
			return;
		throw std::invalid_argument( "Array must be Nx3 in shape.");
	}
	if (dims.size() != 2 || dims[1] != 3) {
		throw std::invalid_argument( "Array must be Nx3 in shape.");
	}
}

} // !namespace anonymous


/* Magnitude(s) of Numeric array data.
 *
 * A flat (3,) array is treated as one vector and yields a Python float; an
 * (N, 3) array yields a 1-D array of length N holding each row's magnitude.
 * Throws std::invalid_argument if validate_array() rejects the input.
 */
object
mag_a( const array& arr)
{
	validate_array( arr);
	std::vector<int> in_dims = shape(arr);
	const bool is_single = in_dims.size() == 1 && in_dims[0] == 3;
	if (is_single)
		return object( vector(arr).mag());

	const int n_rows = in_dims[0];
	std::vector<int> out_dims( 1, n_rows);
	array result = makeNum( out_dims);
	for (int row = 0; row < n_rows; ++row)
		result[row] = vector(arr[row]).mag();
	return result;
}

/* Squared magnitude(s) of Numeric array data.
 *
 * A flat (3,) array is treated as one vector and yields a Python float; an
 * (N, 3) array yields a 1-D array of length N holding each row's squared
 * magnitude. Throws std::invalid_argument if validate_array() rejects the
 * input.
 */
object
mag2_a( const array& arr)
{
	validate_array( arr);
	std::vector<int> in_dims = shape(arr);
	if (in_dims.size() == 1 && in_dims[0] == 3)
		return object( vector(arr).mag2());

	const int n_rows = in_dims[0];
	std::vector<int> out_dims( 1, n_rows);
	array result = makeNum( out_dims);
	for (int row = 0; row != n_rows; ++row)
		result[row] = vector(arr[row]).mag2();
	return result;
}

/* Unit vector(s) of Numeric array data.
 *
 * A flat (3,) array is treated as one vector and yields a single vector
 * object; an (N, 3) array yields an array of the same shape in which each
 * row is replaced by its normalized form. Throws std::invalid_argument if
 * validate_array() rejects the input.
 */
object
norm_a( const array& arr)
{
	validate_array( arr);
	std::vector<int> in_dims = shape(arr);
	const bool is_single = in_dims.size() == 1 && in_dims[0] == 3;
	if (is_single)
		return object( vector(arr).norm());

	// The output keeps the input's (N, 3) shape.
	array result = makeNum( in_dims);
	const int n_rows = in_dims[0];
	for (int row = 0; row < n_rows; ++row)
		result[row] = vector(arr[row]).norm();
	return result;
}

/* Row-wise dot product of two arrays of 3-vectors.
 *
 * Both arguments must pass validate_array() and have identical shapes. For
 * (N, 3) inputs, the result is a 1-D array of length N whose i-th element is
 * dot(arg1[i], arg2[i]). A pair of flat (3,) arrays is treated as a single
 * row and yields a length-1 array.
 *
 * Throws std::invalid_argument on a type, layout, or shape mismatch.
 */
array
dot_a( const array& arg1, const array& arg2)
{
	validate_array( arg1);
	validate_array( arg2);
	std::vector<int> dims1 = shape(arg1);
	std::vector<int> dims2 = shape(arg2);
	if (dims1 != dims2) {
		throw std::invalid_argument( "Array shape mismatch.");
	}
	// A flat (3,) array holds exactly one vector. Using dims1[0] (== 3) as the
	// row count here would read 9 doubles from each 3-double buffer.
	const int rows = (dims1.size() == 1) ? 1 : dims1[0];

	std::vector<int> dims_ret(1);
	dims_ret[0] = rows;
	array ret = makeNum( dims_ret);
	// Inputs are validated as contiguous, so rows are packed 3 doubles apart.
	const double* arg1_i = (double*)data(arg1);
	const double* arg2_i = (double*)data(arg2);
	for (int i = 0; i < rows; ++i, arg1_i += 3, arg2_i += 3) {
		ret[i] = vector(arg1_i).dot( vector(arg2_i));
	}
	return ret;
}

/* Row-wise cross product of two arrays of 3-vectors.
 *
 * Both arguments must pass validate_array() and have identical shapes. The
 * result has that same shape; row i holds cross(arg1[i], arg2[i]). A pair of
 * flat (3,) arrays is treated as a single row.
 *
 * Throws std::invalid_argument on a type, layout, or shape mismatch.
 */
array
cross_a_a( const array& arg1, const array& arg2)
{
	validate_array( arg1);
	validate_array( arg2);
	std::vector<int> dims1 = shape(arg1);
	std::vector<int> dims2 = shape(arg2);
	if (dims1 != dims2) {
		throw std::invalid_argument( "Array shape mismatch.");
	}
	// A flat (3,) array holds exactly one vector; dims1[0] (== 3) is not a row
	// count there, and using it would overrun every buffer below.
	const int rows = (dims1.size() == 1) ? 1 : dims1[0];

	array ret = makeNum( dims1);
	const double* arg1_i = (double*)data(arg1);
	const double* arg2_i = (double*)data(arg2);
	double* ret_i = (double*)data(ret);
	double* const ret_stop = ret_i + 3*rows;
	for ( ; ret_i < ret_stop; ret_i += 3, arg1_i += 3, arg2_i += 3) {
		vector product = vector(arg1_i).cross( vector( arg2_i));
		ret_i[0] = product.get_x();
		ret_i[1] = product.get_y();
		ret_i[2] = product.get_z();
	}
	return ret;
}

/* Cross product of every row of `arg1` with the fixed vector `arg2`.
 *
 * `arg1` must pass validate_array(). The result has the same shape as
 * `arg1`; row i holds cross(arg1[i], arg2). A flat (3,) array is treated as a
 * single row.
 */
array
cross_a_v( const array& arg1, const vector& arg2)
{
	validate_array( arg1);
	std::vector<int> dims = shape( arg1);
	// A flat (3,) array holds exactly one vector; dims[0] (== 3) is not a row
	// count there, and using it would overrun both buffers.
	const int rows = (dims.size() == 1) ? 1 : dims[0];

	array ret = makeNum( dims);
	const double* arg1_i = (double*)data( arg1);
	double* ret_i = (double*)data( ret);
	double* const ret_stop = ret_i + 3*rows;
	for ( ; ret_i < ret_stop; ret_i += 3, arg1_i += 3) {
		vector product = vector( arg1_i).cross( arg2);
		ret_i[0] = product.get_x();
		ret_i[1] = product.get_y();
		ret_i[2] = product.get_z();
	}
	return ret;
}

/* Cross product of the fixed vector `arg1` with every row of `arg2`.
 *
 * `arg2` must pass validate_array(). The result has the same shape as
 * `arg2`; row i holds cross(arg1, arg2[i]). A flat (3,) array is treated as a
 * single row.
 */
array
cross_v_a( const vector& arg1, const array& arg2)
{
	validate_array( arg2);
	std::vector<int> dims = shape( arg2);
	// A flat (3,) array holds exactly one vector; dims[0] (== 3) is not a row
	// count there, and using it would overrun both buffers.
	const int rows = (dims.size() == 1) ? 1 : dims[0];

	array ret = makeNum( dims);
	const double* arg2_i = (double*)data( arg2);
	double* ret_i = (double*)data( ret);
	double* const ret_stop = ret_i + 3*rows;
	for ( ; ret_i < ret_stop; ret_i += 3, arg2_i += 3) {
		vector product = arg1.cross( vector( arg2_i));
		ret_i[0] = product.get_x();
		ret_i[1] = product.get_y();
		ret_i[2] = product.get_z();
	}
	return ret;
}

	
// Cross product this x v, computed component by component.
vector 
vector::cross( const vector& v) const throw()
{
	const double cx = this->y*v.z - this->z*v.y;
	const double cy = this->z*v.x - this->x*v.z;
	const double cz = this->x*v.y - this->y*v.x;
	return vector( cx, cy, cz);
}

/* Construct a vector from an arbitrary Python sequence of 2 or 3 numbers.
 * A 2-element sequence leaves z at 0.
 *
 * Boost.Python cannot be told to fall back on an implicit conversion only
 * after every exact overload has failed. So most places that take a vector
 * argument (chiefly the set_foo() functions) also get an explicit overload
 * that takes a generic Python object and funnels it through this
 * constructor. That generic overload matches *anything* as far as
 * Boost.Python is concerned -- including a real vector -- and it is much
 * slower than accepting a vector directly, so the vector overload must be
 * tried first. Boost.Python tries overloads in the reverse of the order in
 * which they were registered with class_<>::def() and def().
 *
 * Rule: register the generic overload (usually named set_foo_t() for
 * historical reasons) BEFORE the vector overload. Getting the order wrong
 * produces neither a compile-time error nor a run-time exception, only a
 * heavy performance penalty.
 *
 * Throws std::invalid_argument if the sequence length is not 2 or 3; a
 * failure to convert an element to double propagates from
 * boost::python::extract.
 */
vector::vector( const boost::python::object& t)
	: x(0), y(0), z(0)
{
	using boost::python::extract;
	const int count = length(t);
	if (count != 2 && count != 3) {
		throw std::invalid_argument( 
			"Vectors must be constructed from sequences of 2 or 3 float members.");
	}
	// Components are read last-to-first: z (if present), then y, then x.
	if (count == 3)
		z = extract<double>( t[2]);
	y = extract<double>( t[1]);
	x = extract<double>( t[0]);
}
  
// Return the components as a Python (x, y, z) tuple.
boost::python::tuple 
vector::as_tuple( void) const
{
	const double cx = this->x;
	const double cy = this->y;
	const double cz = this->z;
	return boost::python::make_tuple( cx, cy, cz);
}

// Unit vector in the direction of *this. The zero vector is returned as
// (0,0,0) instead of producing NaN from a division by zero.
vector 
vector::norm( void) const throw()
{
	const double len = this->mag();
	const double scale = (len != 0.0) ? 1.0 / len : len;
	return vector( x*scale, y*scale, z*scale);
}


// Scalar projection of *this onto v: dot(*this, v) / |v|.
// The projection onto the zero vector is undefined; return 0 instead of the
// NaN that 0/0 would produce, consistent with norm() and diff_angle(),
// which also special-case zero vectors.
double 
vector::comp( const vector& v) const throw()
{
	const double v_mag = v.mag();
	if (v_mag == 0.0)
		return 0.0;
	return this->dot( v) / v_mag;
}

// Vector projection of *this onto v: (dot(*this, v) / |v|^2) * v.
// Projecting onto the zero vector yields the zero vector instead of the
// NaN components that 0/0 would produce.
vector 
vector::proj( const vector& v) const throw()
{
	const double v_mag2 = v.mag2();
	if (v_mag2 == 0.0)
		return vector( 0, 0, 0);
	return (this->dot( v) / v_mag2) * v;
}

// True when the dot product with v is exactly zero (no tolerance is applied).
bool 
vector::orthogonal( const vector& v) const throw()
{
	const double d = this->dot( v);
	return d == 0.0;
}

// Angle between *this and v, in radians, in the range [0, pi].
// Returns 0 if either vector is the zero vector.
double 
vector::diff_angle( const vector& v) const throw()
{
	if (this->mag2() == 0.0 || v.mag2() == 0.0)
		return 0.0;
	// Dotting the unit vectors keeps the cosine close to [-1, 1], but rounding
	// can still push it just outside that range (e.g. for parallel vectors),
	// where acos() returns NaN. Clamp before taking the arccosine.
	double cosine = this->norm().dot( v.norm());
	if (cosine > 1.0)
		cosine = 1.0;
	else if (cosine < -1.0)
		cosine = -1.0;
	return std::acos( cosine);
}
  
// Python-evaluable representation, e.g. "vector(1, 2, 3)".
// The output is meant to be Python code that rebuilds an equal vector, so each
// component must round-trip exactly. numeric_limits<double>::digits10 (15 for
// IEEE double) is not enough for that; digits10 + 2 (17, the value C++11 calls
// max_digits10) is.
std::string 
vector::repr( void) const
{
	std::ostringstream ret;
	ret.precision( std::numeric_limits<double>::digits10 + 2);
	ret << "vector(" << x << ", " << y << ", " << z << ")";
	return ret.str();
}
  
// Python __getitem__: indices 0..2 select x, y, z, and -3..-1 alias them.
// Any other index throws std::out_of_range naming the index as passed.
double
vector::py_getitem( int index) const
{
	const int i = (index < 0) ? index + 3 : index;
	if (i == 0)
		return x;
	if (i == 1)
		return y;
	if (i == 2)
		return z;

	std::ostringstream s;
	s << "vector index out of bounds: " << index;
	throw std::out_of_range( s.str() );
}

// Python __setitem__: indices 0..2 assign x, y, z, and -3..-1 alias them.
// Any other index throws std::out_of_range naming the index as passed.
void
vector::py_setitem(int index, double value)
{
	switch ((index < 0) ? index + 3 : index) {
		case 0:
			x = value;
			return;
		case 1:
			y = value;
			return;
		case 2:
			z = value;
			return;
		default:
			break;
	}
	std::ostringstream s;
	s << "vector index out of bounds: " << index;
	throw std::out_of_range( s.str() );
}

void
vector::py_scale( double s)
{
	*this = norm()*s;
}

void
vector::py_scale2( double s2)
{
	*this = norm()*std::sqrt(s2);
}
        
// Lexicographic ordering on (x, y, z).
bool 
vector::operator<( const vector& v) const throw()
{
	if (this->x == v.x) {
		if (this->y == v.y)
			return this->z < v.z;
		return this->y < v.y;
	}
	return this->x < v.x;
}

// Rotate *this by `angle` about `axis`. The rotation matrix is built by
// py_rotation() with a zero vector as its fourth argument (presumably the
// point the axis passes through -- confirm against py_rotation).
vector
vector::rotate( double angle, vector axis) const throw()
{
	tmatrix rotation;
	const vector origin( 0, 0, 0);
	py_rotation( rotation, angle, axis, origin);
	return rotation.times_v( *this);
}

bool
vector::linear_multiple_of( const vector& other) const
{
	vector _this = norm();
	vector _other = other.norm();
	
	return _this == _other || _this == -_other;
}

// Unary plus (__pos__): returns a copy of its argument.
vector
vector_pos( const vector& self)
{
	const vector copy( self);
	return copy;
}

// Assign the x component, holding the owner's write lock when there is one.
void
shared_vector::set_x( const double& x)
{
	if (!owner) {
		this->x = x;
		return;
	}
	write_lock L( *owner);
	this->x = x;
}

// Assign the y component, holding the owner's write lock when there is one.
void
shared_vector::set_y( const double& y)
{
	if (!owner) {
		this->y = y;
		return;
	}
	write_lock L( *owner);
	this->y = y;
}

// Assign the z component, holding the owner's write lock when there is one.
void
shared_vector::set_z( const double& z)
{
	if (!owner) {
		this->z = z;
		return;
	}
	write_lock L( *owner);
	this->z = z;
}
	

// Copy all three components from v, holding the owner's write lock when
// there is one.
const shared_vector&
shared_vector::operator=( const vector& v)
{
	if (!owner) {
		this->x = v.x;
		this->y = v.y;
		this->z = v.z;
		return *this;
	}
	write_lock L( *owner);
	this->x = v.x;
	this->y = v.y;
	this->z = v.z;
	return *this;
}


// Assign from a Python tuple of 2 or 3 numbers (see vector(const object&)).
// Like every other mutator here, it takes the owner's write lock only when
// there is an owner; the original locked *owner unconditionally, which
// dereferenced a null pointer for unowned vectors. The tuple is converted
// before locking because that conversion can throw std::invalid_argument.
const shared_vector&
shared_vector::operator=( boost::python::tuple t)
{
	vector v(t);
	if (owner) {
		write_lock L( *owner);
		this->x = v.x;
		this->y = v.y;
		this->z = v.z;
	}
	else {
		this->x = v.x;
		this->y = v.y;
		this->z = v.z;
	}
	return *this;
}

// Divide every component by the integer s (converted to double by the
// division), holding the owner's write lock when there is one.
const shared_vector&
shared_vector::operator/=( const int& s)
{
	if (!owner) {
		this->x /= s;
		this->y /= s;
		this->z /= s;
		return *this;
	}
	write_lock L( *owner);
	this->x /= s;
	this->y /= s;
	this->z /= s;
	return *this;
}

// Multiply every component by the integer s, holding the owner's write lock
// when there is one.
const shared_vector&
shared_vector::operator*=( const int& s)
{
	if (!owner) {
		this->x *= s;
		this->y *= s;
		this->z *= s;
		return *this;
	}
	write_lock L( *owner);
	this->x *= s;
	this->y *= s;
	this->z *= s;
	return *this;
}

// Component-wise add v, holding the owner's write lock when there is one.
const shared_vector&
shared_vector::operator+=( const vector& v)
{
	if (!owner) {
		this->x += v.x;
		this->y += v.y;
		this->z += v.z;
		return *this;
	}
	write_lock L( *owner);
	this->x += v.x;
	this->y += v.y;
	this->z += v.z;
	return *this;
}

// Component-wise subtract v, holding the owner's write lock when there is one.
const shared_vector&
shared_vector::operator-=( const vector& v)
{
	if (!owner) {
		this->x -= v.x;
		this->y -= v.y;
		this->z -= v.z;
		return *this;
	}
	write_lock L( *owner);
	this->x -= v.x;
	this->y -= v.y;
	this->z -= v.z;
	return *this;
}
    
// Multiply every component by s, holding the owner's write lock when there
// is one.
const shared_vector&
shared_vector::operator*=( const double& s)
{
	if (!owner) {
		this->x *= s;
		this->y *= s;
		this->z *= s;
		return *this;
	}
	write_lock L( *owner);
	this->x *= s;
	this->y *= s;
	this->z *= s;
	return *this;
}
    
// Divide every component by s, holding the owner's write lock when there
// is one.
const shared_vector&
shared_vector::operator/=( const double& s)
{
	if (!owner) {
		this->x /= s;
		this->y /= s;
		this->z /= s;
		return *this;
	}
	write_lock L( *owner);
	this->x /= s;
	this->y /= s;
	this->z /= s;
	return *this;
}

// Same as vector::py_setitem, holding the owner's write lock when there is one.
void
shared_vector::py_setitem(int index, double value)
{
	if (!owner) {
		vector::py_setitem(index, value);
		return;
	}
	write_lock L(*owner);
	vector::py_setitem(index, value);
}

// Same as vector::py_scale, holding the owner's write lock when there is one.
void
shared_vector::py_scale( double s)
{
	if (!owner) {
		vector::py_scale(s);
		return;
	}
	write_lock L(*owner);
	vector::py_scale(s);
}

// Same as vector::py_scale2, holding the owner's write lock when there is one.
void
shared_vector::py_scale2( double s)
{
	if (!owner) {
		vector::py_scale2(s);
		return;
	}
	write_lock L(*owner);
	vector::py_scale2(s);
}


namespace {
// Overload generators that let Python callers omit trailing arguments that
// have C++ defaults.
// vector_rotate: the member vector::rotate, callable with 1 or 2 arguments
// (angle, and optionally axis).
// member_rotate: despite its name, this wraps the FREE function rotate(),
// callable with 2 or 3 arguments (vector, angle, and optionally axis).
BOOST_PYTHON_MEMBER_FUNCTION_OVERLOADS( vector_rotate, vector::rotate, 1, 2)
BOOST_PYTHON_FUNCTION_OVERLOADS( member_rotate, rotate, 2, 3)
} // !namespace anonymous

struct vector_from_tuple
{
	vector_from_tuple()
	{
		boost::python::converter::registry::push_back( 
			&convertible,
			&construct,
			boost::python::type_id<visual::vector>());
	}
	
	static void* convertible( PyObject* obj)
	{
		using boost::python::handle;
		using boost::python::allow_null;
		handle<> obj_iter( allow_null( PyObject_GetIter(obj)));
		if (!obj_iter.get()) {
			PyErr_Clear();
			return 0;
		}
		int obj_size = PyObject_Length(obj);
		if (obj_size < 0) {
			PyErr_Clear();
			return 0;
		}
		if (obj_size != 2 && obj_size != 3)
			return 0;
		return obj;
	}
	static void construct( 
		PyObject* _obj, 
		boost::python::converter::rvalue_from_python_stage1_data* data)
	{
		using namespace boost::python;
		
		object obj = object(handle<>(borrowed(_obj)));
		void* storage = (
			(boost::python::converter::rvalue_from_python_storage<vector>*)
			data)->storage.bytes;
		new (storage) vector( obj);
		data->convertible = storage;
	}
};

/* Register the vector types and vector math functions with Boost.Python.
 *
 * Exposes visual::vector as "vector" and visual::shared_vector as "Vector",
 * plus module-level free functions, and installs an rvalue converter that
 * lets Python sequences of length 2 or 3 be passed where a vector is expected.
 *
 * Order matters: Boost.Python tries overloads of the same name in the
 * reverse of the order in which they were def()'d, so within each group
 * below the later registrations are attempted first.
 */
void vector_init_type()
{
	using namespace boost::python;
	
	// Numeric-array overloads of the free functions registered just below.
	// Because these are registered first, the vector overloads are attempted
	// before them.
	// TODO: round out the set.
	def( "mag", mag_a);
	def( "dot", dot_a);
	def( "cross", cross_a_a);
	def( "cross", cross_a_v);
	def( "cross", cross_v_a);
	def( "mag2", mag2_a);
	def( "norm", norm_a);

	// Free functions for vectors
	def( "dot", dot, "The dot product between two vectors.");
	def( "cross", cross, "The cross product between two vectors.");
	def( "mag", mag, "The magnitude of a vector.");
	def( "mag2", mag2, "A vector's magnitude squared.");
	def( "norm", norm, "Returns the unit vector of its argument.");
	def( "comp", comp, "The scalar projection of arg1 to arg2.");
	def( "proj", proj, "The vector projection of arg1 to arg2.");
	def( "diff_angle", diff_angle, "The angle between two vectors, in radians.");
	def( "rotate", rotate, member_rotate(
		"rotate( vector, float angle, vector axis=vector(0,1,0)) -> vector\n"
		" Rotate a vector about an axis through an angle.",
		args("angle", "axis")));

	// Typed member-function pointers, exposed below as __truediv__ and
	// __itruediv__.
	vector (vector::* truediv)( double) const = &vector::operator/;
	const vector& (vector::* itruediv)( double) = &vector::operator/=;
	
	// The vector class, constructible from zero, one, two or three doubles.
	class_<vector>("vector", init< optional<double, double, double> >())
		// Conversion from any Python sequence of 2 or 3 numbers.
		.def( init<const object&>())
		// Explicit copy.
		.def( init<vector>())
		// member variables.
		.def_readwrite( "x", &vector::x)
		.def_readwrite( "y", &vector::y)
		.def_readwrite( "z", &vector::z)
		// Member functions masquerading as properties; assigning to them
		// rescales the vector while keeping its direction.
		.add_property( "mag", &vector::mag, &vector::py_scale)
		.add_property( "mag2", &vector::mag2, &vector::py_scale2)
		// Member functions
		.def( "dot", &vector::dot, "The dot product of this vector and another.")
		.def( "cross", &vector::cross, "The cross product of this vector and another.")
		.def( "norm", &vector::norm, "The unit vector of this vector.")
		.def( "comp", &vector::comp, "The scalar projection of this vector onto another.")
		.def( "proj", &vector::proj, "The vector projection of this vector onto another.")
		.def( "diff_angle", &vector::diff_angle, "The angle between this vector "
			"and another, in radians.")
		.def( "clear", &vector::clear, "Zero the state of this vector.  Potentially "
			"useful for reusing a temporary variable.")
		.def( "rotate", &vector::rotate, vector_rotate( "Rotate this vector about "
			"the specified axis through the specified angle, in radians", 
			args( "angle", "axis")))
		.def( "__abs__", &vector::mag, "Return the length of this vector.")
		.def( "__pos__", vector_pos, "Returns the vector itself.")
		// Some support for the sequence protocol.
		.def( "__len__", &vector::py_len)
		.def( "__getitem__", &vector::py_getitem)
		.def( "__setitem__", &vector::py_setitem)
		// Use this to quickly convert vectors to tuples.
		.def( "astuple", &vector::as_tuple, "Convert this vector to a tuple.  "
			"Same as tuple(vector), but much faster.")
		// Member operators
		.def( -self)
		.def( self + self)
		.def( self += self)
		.def( self - self)
		.def( self -= self)
		.def( self * double())
		.def( self *= double())
		.def( self / double())
		.def( self /= double())
		.def( double() * self)
		// Same as self /= double, when "from __future__ import division" is in effect.
		.def( "__itruediv__", itruediv, return_value_policy<copy_const_reference>())
		// Same as self / double, when "from __future__ import division" is in effect.
		.def( "__truediv__",  truediv)
		.def( self_ns::str(self))        // Support ">>> print foo"
		.def( "__repr__", &vector::repr) // Support ">>> foo"
		;

	// The in-place operator/= overload taking const double&, exposed below as
	// __itruediv__.
	const shared_vector& (shared_vector::* sitruediv)( const double&) = 
		&shared_vector::operator/=;
	// "Vector": a vector whose mutators take the owner's write_lock when it
	// has an owner. Not constructible from Python (no_init).
	class_<shared_vector, bases<vector>, boost::noncopyable>( "Vector", no_init)
		.def( self += other<vector>())
		.def( self -= other<vector>())
		.def( self *= double())
		.def( self /= double())
		.def( "__itruediv__", sitruediv, return_value_policy<copy_const_reference>())
		.def( "__setitem__", &shared_vector::py_setitem)
		// Component writes go through the locking set_x/set_y/set_z setters.
		.add_property( "x", make_getter(&shared_vector::x), &shared_vector::set_x)
		.add_property( "y", make_getter(&shared_vector::y), &shared_vector::set_y)
		.add_property( "z", make_getter(&shared_vector::z), &shared_vector::set_z)
		.add_property( "mag", &shared_vector::mag, &shared_vector::py_scale)
		.add_property( "mag2", &shared_vector::mag2, &shared_vector::py_scale2)
		;

	// Allow automagic conversions from shared_vector to vector.
	implicitly_convertible<shared_vector, vector>();
	
	// Pass a sequence to some functions that expect type visual::vector.
	// Constructing the temporary registers the converter.
	vector_from_tuple();
}


} // !namespace visual
