# GNU Enterprise Common - MySQL DB Driver - Schema Creation
#
# Copyright 2001-2005 Free Software Foundation
#
# This file is part of GNU Enterprise
#
# GNU Enterprise is free software; you can redistribute it
# and/or modify it under the terms of the GNU General Public
# License as published by the Free Software Foundation; either
# version 2, or (at your option) any later version.
#
# GNU Enterprise is distributed in the hope that it will be
# useful, but WITHOUT ANY WARRANTY; without even the implied
# warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
# PURPOSE. See the GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public
# License along with program; see the file COPYING. If not,
# write to the Free Software Foundation, Inc., 59 Temple Place
# - Suite 330, Boston, MA 02111-1307, USA.
#
# $Id: Creation.py 6851 2005-01-03 20:59:28Z jcater $

import os

from gnue.common.datasources.drivers.DBSIG2.Schema.Creation import \
    Creation as Base


# =============================================================================
# Class implementing schema creation for MySQL (3.x/4.x)
# =============================================================================

class Creation (Base.Creation):
  """
  Schema creation for MySQL (3.x/4.x).

  Extends the generic DBSIG2 schema creation with MySQL specific database
  setup, index dropping, default handling and native datatype mappings.
  """

  # Maximum identifier length (MySQL allows at most 64 characters for names of
  # databases, tables, columns and indices); presumably consumed by the
  # inherited _shortenName ()
  MAX_NAME_LENGTH = 64

  # ---------------------------------------------------------------------------
  # Create a new database
  # ---------------------------------------------------------------------------

  def createDatabase (self):
    """
    This function creates a new database as specified by the given connection
    and grants all privileges on it to the connection's user.

    The database is created with 'mysqladmin create'. Then the 'mysql' client
    grants all privileges to the user, for connections from any host ('%')
    and from 'localhost'. Both tools run with the credentials of the current
    account. To succeed, that account must have enough privileges to create
    databases and to grant privileges.

    The tools are started without a shell, and all values are quoted for SQL.
    Database names, user names or passwords that contain quotes or shell
    metacharacters therefore can neither break the commands nor be
    interpreted by a shell. The exit codes of the tools are not checked
    (best effort).
    """

    import subprocess

    parameters = self.connection.parameters
    dbname     = parameters.get ('dbname')
    username   = parameters.get ('username', 'gnue')
    password   = parameters.get ('password')
    host       = parameters.get ('host')
    port       = parameters.get ('port')

    def sqlString (value):
      # MySQL string literal: escape backslashes and double single quotes
      text = (u"%s" % value).replace (u"\\", u"\\\\").replace (u"'", u"''")
      return u"'%s'" % text

    def sqlIdentifier (value):
      # MySQL quoted identifier: embedded backticks are doubled
      return u"`%s`" % (u"%s" % value).replace (u"`", u"``")

    # Connection options shared by both command line tools
    options = []
    if host:
      options.append (u"--host=%s" % host)
    if port:
      options.append (u"--port=%s" % port)

    database = u"%s" % dbname

    subprocess.call (["mysqladmin"] + options + ["create", database])

    if password:
      identified = u" IDENTIFIED BY %s" % sqlString (password)
    else:
      identified = u""

    # Grant access for connections from any host ('%') as well as from
    # 'localhost'. The separate 'localhost' grant is presumably needed because
    # an anonymous ''@'localhost' account would otherwise take precedence over
    # the '%' entry.
    for client in (u'%', u'localhost'):
      sql = u"GRANT ALL PRIVILEGES ON %s.* TO %s@%s%s" \
          % (sqlIdentifier (dbname), sqlString (username), sqlString (client),
             identified)
      subprocess.call (["mysql"] + options + ["-e", sql, "-s", database])


  # ---------------------------------------------------------------------------
  # Handle special defaults
  # ---------------------------------------------------------------------------

  def _defaultwith (self, code, tableName, fieldDef, forAlter):
    """
    This function adds 'auto_increment' for 'serials' and checks for the proper
    fieldtype on 'timestamps'.

    @param code: code-tuple to merge the result in. The last entry of its body
        sequence (code [1]) is expected to be the definition of this field.
    @param tableName: name of the table
    @param fieldDef: dictionary describing the field with the default
    @param forAlter: TRUE if the definition is used in a table modification
    """
    kind = fieldDef ['defaultwith']

    if kind == 'serial':
      seq = self._getSequenceName (tableName, fieldDef)
      code [1] [-1] += " AUTO_INCREMENT"
      fieldDef ['default'] = "nextval ('%s')" % seq

    elif kind == 'timestamp' and fieldDef ['type'] != 'timestamp':
      # 'defaultwith timestamp' needs a native MySQL timestamp column, so the
      # field definition is recomposed with that type
      fieldDef ['type'] = 'timestamp'
      code [1] [-1] = self._composeField (tableName, fieldDef, forAlter)

      print (u_("WARNING: changing column type of '%(table)s.%(column)s' "
                "to 'timestamp'") % {'table': tableName,
                                     'column': fieldDef ['name']})


  # ---------------------------------------------------------------------------
  # Drop an old index
  # ---------------------------------------------------------------------------

  def dropIndex (self, tableName, indexName, codeOnly = False):
    """
    This function drops an index from the given table.

    @param tableName: name of the table to drop an index from
    @param indexName: name of the index to be dropped (shortened to
        MAX_NAME_LENGTH if necessary)
    @param codeOnly: if TRUE no operation takes place, but only the code will
        be returned.
    @return: a tuple of sequences (prologue, body, epilogue) containing the
        code to perform the action.
    """

    res = ([], [], [])

    # MySQL needs the table name in DROP INDEX
    indexName = self._shortenName (indexName)
    res [0].append (u"DROP INDEX %s ON %s%s"
                    % (indexName, tableName, self.END_COMMAND))

    if not codeOnly:
      self._executeCodeTuple (res)

    return res

  # ---------------------------------------------------------------------------
  # A key is an unsigned integer
  # ---------------------------------------------------------------------------

  def key (self, fieldDefinition):
    """
    Native datatype for a 'key'-field is 'unsigned integer'.

    @param fieldDefinition: dictionary describing the field
    @return: 'int unsigned'
    """
    return "int unsigned"


  # ---------------------------------------------------------------------------
  # Translate a string into an appropriate native type
  # ---------------------------------------------------------------------------

  def string (self, fieldDefinition):
    """
    This function returns an appropriate native type for a string. If a length
    is given and is at most 255 characters, the result is a varchar, otherwise
    text.

    @param fieldDefinition: dictionary describing the field
    @return: string with the native datatype
    """
    # A missing length and an explicit None both mean "unlimited"
    length = fieldDefinition.get ('length')

    if length is not None and length <= 255:
      return "varchar (%s)" % length

    return "text"


  # ---------------------------------------------------------------------------
  # Create an appropriate type for a number
  # ---------------------------------------------------------------------------

  def number (self, fieldDefinition):
    """
    This function returns an appropriate type for a number according to the
    given length and precision. Whole numbers are mapped onto the smallest
    integer type that holds the given number of digits; all others become
    decimals.

    @param fieldDefinition: dictionary describing the field
    @return: string with the native datatype
    """
    # Missing or None values are treated as 0
    scale  = fieldDefinition.get ('precision') or 0
    length = fieldDefinition.get ('length') or 0

    if scale != 0:
      return "decimal (%s,%s)" % (length, scale)

    if length <= 4:
      return "smallint"
    elif length <= 9:
      return "int"
    elif length <= 18:
      return "bigint"

    return "decimal (%s,0)" % length


  # ---------------------------------------------------------------------------
  # MySQL has no native boolean type
  # ---------------------------------------------------------------------------

  def boolean (self, fieldDefinition):
    """
    MySQL has no native boolean type, so this function returns 'tinyint (1)
    unsigned' instead.

    @param fieldDefinition: dictionary describing the field
    @return: 'tinyint (1) unsigned'
    """
    return "tinyint (1) unsigned"


  # ---------------------------------------------------------------------------
  # MySQL has a timestamp, which is needed for 'defaultwith timestamp'
  # ---------------------------------------------------------------------------

  def timestamp (self, fieldDefinition):
    """
    In MySQL timestamps are used for default values; otherwise this maps to
    the inherited type transformation of 'timestamp'.

    @param fieldDefinition: dictionary describing the field
    @return: string with the native datatype
    """
    if fieldDefinition.get ('defaultwith') == 'timestamp':
      return "timestamp"

    return Base.Creation.timestamp (self, fieldDefinition)
