0
0
mirror of https://github.com/rbock/sqlpp11.git synced 2024-12-27 08:31:06 +08:00
sqlpp11/scripts/ddl2cpp

423 lines
16 KiB
Plaintext
Raw Permalink Normal View History

#!/usr/bin/env python
##
2015-02-15 19:00:21 +01:00
# Copyright (c) 2013-2015, Roland Bock
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification,
# are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
# INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE
# OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED
# OF THE POSSIBILITY OF SUCH DAMAGE.
##
from __future__ import print_function
import sys
import re
import os
2016-05-14 14:30:08 +03:00
2016-05-05 01:58:53 +03:00
# Process exit codes passed to sys.exit(); they are still plain module-level
# constants (refactoring them into a dedicated structure is an open TODO).
ERROR_BAD_ARGS, ERROR_DATA_TYPE, ERROR_STRANGE_PARSING = 1, 10, 20
2016-05-05 23:08:27 +03:00
2016-05-05 01:58:53 +03:00
from pyparsing import CaselessLiteral, Literal, SkipTo, restOfLine, oneOf, ZeroOrMore, Optional, Combine, Suppress, \
2016-05-14 16:11:48 +03:00
WordStart, WordEnd, Word, alphas, alphanums, nums, QuotedString, nestedExpr, MatchFirst, OneOrMore, delimitedList, \
Or, Group, ParseException
# HELPERS
2016-05-05 01:58:53 +03:00
def get_include_guard_name(namespace, inputfile):
    """Build the include-guard macro name for a generated header.

    The namespace and the header's file name (without directories) are joined
    with '_', every run of non-alphanumeric characters becomes a single '_',
    and the result is upper-cased, e.g. ('my::ns', '/tmp/MyTable.h') ->
    'MY_NS_MYTABLE_H'.
    """
    raw = '{}_{}'.format(namespace, os.path.basename(inputfile))
    return re.sub("[^A-Za-z0-9]+", "_", raw).upper()
def identity_naming_func(s):
    """Return s unchanged (used for -identity-naming: keep names from the DDL)."""
    return s


def repl_camel_case_func(m):
    """re.sub callback that camel-cases one separator/character pair.

    m.group(1) is the separator (start of string, whitespace, '_' or a digit)
    and m.group(2) the non-space character after it. An '_' separator is
    dropped; any other separator is kept in front of the upper-cased char.
    """
    if m.group(1) == '_':
        return m.group(2).upper()
    return m.group(1) + m.group(2).upper()


def class_name_naming_func(s):
    """Convert a SQL name to UpperCamelCase, e.g. 'my_table' -> 'MyTable'.

    The first character and every character following '_', whitespace or a
    digit are upper-cased; underscores are removed.
    """
    # Raw string: "\s"/"\S" in a plain literal are invalid escape sequences
    # (DeprecationWarning since Python 3.6, SyntaxWarning since 3.12).
    return re.sub(r"(^|\s|[_0-9])(\S)", repl_camel_case_func, s)


def member_name_naming_func(s):
    """Camel-case a SQL name for use as a C++ member, e.g. 'my_column' -> 'myColumn'.

    Like class_name_naming_func, but the first character is left as-is, so
    the result is lowerCamelCase only if the input starts in lower case.
    """
    return re.sub(r"(\s|_|[0-9])(\S)", repl_camel_case_func, s)


# Default naming scheme; replaced by identity_naming_func when -identity-naming is given.
toClassName = class_name_naming_func
toMemberName = member_name_naming_func
2016-05-05 01:58:53 +03:00
def repl_func_for_args(m):
    """re.sub callback: drop a '-' separator and upper-case the next character.

    Any other separator (whitespace or a digit) is kept in front of the
    upper-cased character.
    """
    if m.group(1) == '-':
        return m.group(2).upper()
    return m.group(1) + m.group(2).upper()


def setArgumentBool(s, bool_value):
    """Assign bool_value to the module global named after command-line option s.

    The option is camel-cased and its first character lower-cased, e.g.
    '-fail-on-parse' -> failOnParse, so every option needs a module-level
    default variable of exactly that name.
    """
    # Raw string: "\s"/"\S" in a plain literal are invalid escape sequences.
    camel = re.sub(r"(\s|-|[0-9])(\S)", repl_func_for_args, s)
    # Lower-case only the first character (http://stackoverflow.com/a/3847369/5006740).
    var_name = camel[:1].lower() + camel[1:]
    globals()[var_name] = bool_value
2016-05-14 14:30:08 +03:00
def usage(optionalArgs=None):
    """Print the short command-line synopsis to stdout.

    optionalArgs is accepted so existing callers keep working, but it is not
    used; the full list of optional flags is printed by -help. The default is
    None rather than a shared mutable {} literal.
    """
    print('Usage: ddl2cpp <path to ddl> <path to target (without extension, e.g. /tmp/MyTable)> <namespace>\n'
          'ddl2cpp -help')
def beginHeader(pathToHeader, nsList):
    """Create pathToHeader and write the generated header's preamble.

    Writes a '// generated by' line with the full command line, an include
    guard, the sqlpp11 includes (table.h, data_types.h, char_sequence.h) and
    one opening 'namespace X {' per element of nsList.

    Returns the open file object; the caller finishes it with endHeader().
    """
    header = open(pathToHeader, 'w')
    # Derive the guard from the nsList argument instead of the module-level
    # `namespace`, so the guard always matches the namespaces opened below.
    # At the call sites nsList is namespace.split('::'), so the text is unchanged.
    guard = get_include_guard_name('::'.join(nsList), pathToHeader)
    print('// generated by ' + ' '.join(sys.argv), file=header)
    print('#ifndef ' + guard, file=header)
    print('#define ' + guard, file=header)
    print('', file=header)
    for include in ('table.h', 'data_types.h', 'char_sequence.h'):
        print('#include <' + INCLUDE + '/' + include + '>', file=header)
    print('', file=header)
    for ns in nsList:
        print('namespace ' + ns, file=header)
        print('{', file=header)
    return header
def endHeader(header, nsList):
    """Close the namespaces opened by beginHeader(), end the include guard, close header.

    Namespaces are closed innermost-first (reverse of nsList), so each
    '// namespace X' comment labels the brace that actually closes X.
    """
    for ns in reversed(nsList):
        print('} // namespace ' + ns, file=header)
    print('#endif', file=header)
    header.close()
2016-05-14 15:57:26 +03:00
def help_message():
    """Print the full usage text, including every optional flag, and exit(0).

    Reads the module-level ``optionalArgs`` dict (flag -> description).
    '-help' itself is shown in the synopsis line, not in the flag list.
    """
    toggles = [arg for arg in optionalArgs if arg != '-help']
    # Pad to the longest *displayed* flag so the descriptions line up.
    pad = max([len(arg) for arg in toggles] or [0])
    arg_string = '\n'
    for argument in toggles:
        padding = ' ' * (pad - len(argument))
        # Keys carry their leading dash ('-split-tables'); the accepted forms
        # are '-split-tables' and '-no-split-tables', shown as '-[no-]split-tables'.
        arg_string += ' [-[no-]' + argument[1:] + ']: ' + padding + optionalArgs[argument] + '\n'
    print('Usage: ddl2cpp [-help]\n\n OPTIONAL ARGUMENTS:\n' + arg_string +
          ' \n <path to ddl> <path to target> <namespace>\n'
          '\n'
          '<path to ddl> path to your SQL database/table definitions (SHOW CREATE TABLE SomeTable) \n'
          '<path to target> path to a generated C++ header file. Without extension (no *.h). '
          'With -split-tables: a directory, one header per table. \n'
          '<namespace> namespace you want. Usually a project/database name\n')
    sys.exit(0)
2016-05-05 23:08:27 +03:00
2016-05-14 14:30:08 +03:00
# Command-line toggles: flag -> help text shown by -help.
# '-some-flag' sets the module variable someFlag = True and '-no-some-flag'
# sets it to False (see setArgumentBool); each flag needs such a default.
optionalArgs = {
    '-timestamp-warning': "show warning about mysql timestamp data type",  # timestampWarning
    '-fail-on-parse': "abort instead of silent generation of unusable headers",  # failOnParse
    '-warn-on-parse': "warn about unusable headers, but continue",  # warnOnParse
    '-auto-id': "Assume column 'id' to have an automatic value as if AUTO_INCREMENT was specified (e.g. implicit for SQLite ROWID)",  # autoId
    '-identity-naming': "Use table and column names from the ddl (defaults to UpperCamelCase for tables and lowerCamelCase for columns)",  # identityNaming
    '-split-tables': "Make a header for each table name, using target as a directory",  # splitTables
    '-help': "show this help"
}
2016-05-05 01:58:53 +03:00
2016-05-14 15:57:26 +03:00
if '-help' in sys.argv:
    help_message()

# ARGUMENT PARSING
# Defaults for every optional flag; setArgumentBool() overwrites these module
# globals ('-some-flag' -> someFlag = True, '-no-some-flag' -> someFlag = False).
timestampWarning = True
failOnParse = False
warnOnParse = False
parseError = "Parsing error, possible reason: can't parse default value for a field"
autoId = False
identityNaming = False
splitTables = False

# Options may appear anywhere on the command line; every argument that is not
# a recognised option is collected as a positional argument, in order.
firstPositional = 1  # 1 + number of options seen (kept for compatibility)
positionalArgs = []
for arg in sys.argv[1:]:
    noArg = arg.replace('-no-', '-')
    if arg in optionalArgs:
        setArgumentBool(arg, True)
        firstPositional += 1
    elif noArg in optionalArgs:
        setArgumentBool(noArg, False)
        firstPositional += 1
    else:
        positionalArgs.append(arg)

# Three positionals are required. Checking only the total argc let e.g.
# "ddl2cpp -split-tables a.sql out" through and then fail with IndexError.
if len(positionalArgs) < 3:
    usage(optionalArgs)
    sys.exit(ERROR_BAD_ARGS)

if identityNaming:
    toClassName = identity_naming_func
    toMemberName = identity_naming_func

pathToDdl = positionalArgs[0]
# With -split-tables the target is a directory and one header per table is
# written into it; otherwise it is a single header path without extension.
pathToHeader = positionalArgs[1] + ('/' if splitTables else '.h')
namespace = positionalArgs[2]
2016-05-05 23:08:27 +03:00
2016-05-05 01:58:53 +03:00
# Include directory used in the generated '#include <INCLUDE/...>' lines, and
# the C++ namespace prefixed to the generated data types and traits.
INCLUDE = 'sqlpp11'
NAMESPACE = 'sqlpp'
# PARSER
def ddlWord(string):
    """Match the SQL keyword `string` case-insensitively, as a whole word.

    The WordStart/WordEnd guards (letters, digits, '_') stop e.g. "KEY" from
    matching inside "KEYS" or "MYKEY".
    """
    wordChars = alphanums + "_"
    return WordStart(wordChars) + CaselessLiteral(string) + WordEnd(wordChars)
2016-05-14 16:11:48 +03:00
# This function should be refactored if we find some database function which needs parameters.
# Right now it only handles argument-less calls such as NOW() in a MySQL DEFAULT value.
def ddlFunctionWord(string):
    """Match a call of the argument-less SQL function `string`, e.g. NOW().

    The name is matched case-insensitively and must be followed by one or
    more '(' and then one or more ')'. Unbalanced forms such as "NOW(()" are
    accepted too.

    NOTE(review): pyparsing skips leading whitespace before each element by
    default, so the ZeroOrMore(" ") part probably never consumes anything
    ("NOW( )" still matches because the blank is skipped before ")") -- confirm.
    """
    openParens = OneOrMore("(")
    blanks = ZeroOrMore(" ")
    closeParens = OneOrMore(")")
    return CaselessLiteral(string) + openParens + blanks + closeParens
2016-06-09 19:18:15 +03:00
# --- Tokens ---------------------------------------------------------------
# Quoted strings / identifiers: 'single', "double" ("" is an escaped quote)
# and `backtick`.
ddlString = Or([QuotedString("'"), QuotedString("\"", escQuote='""'), QuotedString("`")])
negativeSign = Literal('-')
ddlNum = Combine(Optional(negativeSign) + Word(nums + "."))
ddlTerm = Word(alphanums + "_$")
ddlName = Or([ddlTerm, ddlString])
ddlMathOp = Word("+><=-")
ddlBoolean = Or([ddlWord(op) for op in ("AND", "OR", "NOT")])
ddlArguments = "(" + delimitedList(Or([ddlString, ddlTerm, ddlNum])) + ")"

# Parenthesised condition: comparisons or "<name> NOT NULL" tests joined by
# AND/OR/NOT (presumably the body of CHECK constraints).
_ddlComparison = Group(ddlName + ddlMathOp + ddlName)
_ddlNotNullTest = Group(ddlName + ddlWord("NOT") + ddlWord("NULL"))
ddlMathCond = "(" + delimitedList(Or([_ddlComparison, _ddlNotNullTest]), delim=ddlBoolean) + ")"

# --- Column attributes (the results names are read by the generator) ------
ddlUnsigned = ddlWord("unsigned").setResultsName("isUnsigned")
ddlNotNull = Group(ddlWord("NOT") + ddlWord("NULL")).setResultsName("notNull")
ddlDefaultValue = ddlWord("DEFAULT").setResultsName("hasDefaultValue")
ddlAutoValue = Or([
    ddlWord(keyword)
    for keyword in ("AUTO_INCREMENT", "AUTOINCREMENT", "SMALLSERIAL", "SERIAL", "BIGSERIAL")
]).setResultsName("hasAutoValue")
ddlColumnComment = Group(ddlWord("COMMENT") + ddlString).setResultsName("comment")
# Leading keywords that mark a table-level constraint instead of a column.
ddlConstraint = Or([
    ddlWord(keyword)
    for keyword in ("CONSTRAINT", "PRIMARY", "FOREIGN", "KEY", "INDEX", "UNIQUE", "CHECK")
])

# --- Statements -------------------------------------------------------------
# MatchFirst: the first alternative that matches wins, so the order matters.
_ddlColumnPart = MatchFirst([
    ddlUnsigned, ddlNotNull, ddlAutoValue, ddlDefaultValue, ddlFunctionWord("NOW"),
    ddlTerm, ddlNum, ddlColumnComment, ddlString, ddlArguments, ddlMathCond,
])
ddlColumn = Group(Optional(ddlConstraint).setResultsName("isConstraint") + OneOrMore(_ddlColumnPart))
ddlIfNotExists = Optional(Group(ddlWord("IF") + ddlWord("NOT") + ddlWord("EXISTS")).setResultsName("ifNotExists"))
_ddlColumnList = Group(delimitedList(ddlColumn)).setResultsName("columns")
createTable = Group(
    ddlWord("CREATE") + ddlWord("TABLE") + ddlIfNotExists
    + ddlName.setResultsName("tableName") + "(" + _ddlColumnList + ")"
).setResultsName("create")
# ddlString.setDebug(True)  # uncomment to debug pyparsing

# Skip everything up to each CREATE TABLE statement; '--' and '#' comments are ignored.
ddl = ZeroOrMore(Suppress(SkipTo(createTable, False)) + createTable)
ddlComment = oneOf(["--", "#"]) + restOfLine
ddl.ignore(ddlComment)
# MAP SQL TYPES
# Lower-cased SQL type (plus ' unsigned' when the column was declared
# UNSIGNED) -> sqlpp11 data type name (prefixed with NAMESPACE when used).
# NOTE(review): the generator looks up only the first type token of a column
# (column[1]), so the multi-word keys below other than '... unsigned'
# currently look unreachable -- kept for when the grammar captures them.
types = {
    'tinyint': 'tinyint',
    'smallint': 'smallint',
    'smallserial': 'smallint',  # PostgreSQL
    'int2': 'smallint',  # PostgreSQL
    'integer': 'integer',
    'int': 'integer',
    'serial': 'integer',  # PostgreSQL
    'int4': 'integer',  # PostgreSQL
    'mediumint': 'integer',
    'bigint': 'bigint',
    'bigserial': 'bigint',  # PostgreSQL
    'int8': 'bigint',  # PostgreSQL
    'char': 'char_',
    'varchar': 'varchar',
    'character varying': 'varchar',  # PostgreSQL
    'text': 'text',
    'clob': 'text',
    'tinyblob': 'blob',
    'blob': 'blob',
    'mediumblob': 'blob',
    'longblob': 'blob',
    'bool': 'boolean',
    'boolean': 'boolean',
    'double': 'floating_point',
    'float8': 'floating_point',  # PostgreSQL
    'float': 'floating_point',
    'float4': 'floating_point',  # PostgreSQL
    'real': 'floating_point',
    'numeric': 'floating_point',  # PostgreSQL
    'decimal': 'floating_point',  # MYSQL
    'date': 'day_point',
    'datetime': 'time_point',
    'time': 'time_of_day',
    # PostgreSQL TIME [WITH[OUT] TIME ZONE] is a time of day, like plain 'time'
    # (previously mapped to time_point, which is a full timestamp).
    'time without time zone': 'time_of_day',  # PostgreSQL
    'time with time zone': 'time_of_day',  # PostgreSQL
    'timestamp': 'time_point',
    'timestamp without time zone': 'time_point',  # PostgreSQL
    'timestamp with time zone': 'time_point',  # PostgreSQL
    'timestamptz': 'time_point',  # PostgreSQL
    'enum': 'text',  # MYSQL
    'set': 'text',  # MYSQL
    'longtext': 'text',  # MYSQL
    'json': 'text',  # PostgreSQL
    'jsonb': 'text',  # PostgreSQL
    'tinyint unsigned': 'tinyint_unsigned',  # MYSQL
    'smallint unsigned': 'smallint_unsigned',  # MYSQL
    'integer unsigned': 'integer_unsigned',  # MYSQL
    'int unsigned': 'integer_unsigned',  # MYSQL
    'bigint unsigned': 'bigint_unsigned',  # MYSQL
    # Unsigned like every other '... unsigned' entry (was the signed 'integer').
    'mediumint unsigned': 'integer_unsigned',  # MYSQL
}
2016-05-05 01:58:53 +03:00
if failOnParse:
    # OneOrMore makes parseFile raise ParseException when not a single
    # CREATE TABLE statement can be recognised in the input.
    ddl = OneOrMore(Suppress(SkipTo(createTable, False)) + createTable)
    ddl.ignore(ddlComment)
    try:
        tableCreations = ddl.parseFile(pathToDdl)
    except ParseException:
        print(parseError + '. Exiting [-no-fail-on-parse]')
        sys.exit(ERROR_STRANGE_PARSING)
else:
    # ZeroOrMore never raises here: unrecognised input is skipped, which can
    # silently yield no tables at all.
    ddl = ZeroOrMore(Suppress(SkipTo(createTable, False)) + createTable)
    ddl.ignore(ddlComment)
    tableCreations = ddl.parseFile(pathToDdl)
    # Only warn when parsing actually came up empty; the warning used to be
    # printed unconditionally, even after a successful parse.
    if warnOnParse and not tableCreations:
        print(parseError + '. Continuing [-no-warn-on-parse]')
2016-05-05 01:58:53 +03:00
# C++ namespaces opened in every generated header, outermost first, e.g. 'a::b' -> ['a', 'b'].
nsList = namespace.split('::')
2016-05-05 01:58:53 +03:00
def escape_if_reserved(name):
    """Prefix name with '!' when it is one of a few reserved SQL words.

    Only BEGIN, END, GROUP and ORDER are checked (case-insensitively); any
    other name is returned unchanged.
    NOTE(review): presumably the '!' tells sqlpp11 to quote the name in
    generated SQL -- confirm against sqlpp11's name handling.
    """
    if name.upper() not in ('BEGIN', 'END', 'GROUP', 'ORDER'):
        return name
    return '!{}'.format(name)
2016-05-05 02:40:44 +03:00
# PROCESS DDL
# tableCreations comes from the parse step above (including its -fail-on-parse
# and -warn-on-parse handling); the file is not parsed a second time here.
header = None  # current output file; one per table with -split-tables
if not splitTables:
    header = beginHeader(pathToHeader, nsList)
DataTypeError = False  # set when a column type is missing from `types`
for create in tableCreations:
    sqlTableName = create.tableName
    if splitTables:
        header = beginHeader(pathToHeader + sqlTableName + '.h', nsList)
    tableClass = toClassName(sqlTableName)
    tableMember = toMemberName(sqlTableName)
    tableNamespace = tableClass + '_'
    tableTemplateParameters = tableClass
    print(' namespace ' + tableNamespace, file=header)
    print(' {', file=header)
    for column in create.columns:
        # Table-level constraints (PRIMARY KEY, INDEX, CHECK, ...) are parsed
        # as pseudo-columns; they do not produce a member.
        if column.isConstraint:
            continue
        sqlColumnName = column[0]
        columnClass = toClassName(sqlColumnName)
        tableTemplateParameters += ',\n ' + tableNamespace + '::' + columnClass
        columnMember = toMemberName(sqlColumnName)
        sqlColumnType = column[1].lower()
        if column.isUnsigned:
            sqlColumnType = sqlColumnType + ' unsigned'
        if sqlColumnType == 'timestamp' and timestampWarning:
            print("Warning: timestamp is mapped to sqlpp::time_point like datetime")
            print("Warning: You have to take care of timezones yourself")
            print("You can disable this warning using -no-timestamp-warning")
        print(' struct ' + columnClass, file=header)
        print(' {', file=header)
        print(' struct _alias_t', file=header)
        print(' {', file=header)
        print(' static constexpr const char _literal[] = "' + escape_if_reserved(sqlColumnName) + '";', file=header)
        print(' using _name_t = sqlpp::make_char_sequence<sizeof(_literal), _literal>;', file=header)
        print(' template<typename T>', file=header)
        print(' struct _member_t', file=header)
        print(' {', file=header)
        print(' T ' + columnMember + ';', file=header)
        print(' T& operator()() { return ' + columnMember + '; }', file=header)
        print(' const T& operator()() const { return ' + columnMember + '; }', file=header)
        print(' };', file=header)
        print(' };', file=header)
        try:
            traitslist = [NAMESPACE + '::' + types[sqlColumnType]]
        except KeyError:
            print('Error: datatype "' + sqlColumnType + '" is not supported.')
            DataTypeError = True
            # Keep going so every unsupported type gets reported (the script
            # exits non-zero below). Start from an empty list so the previous
            # column's traits are not reused and the first column cannot hit
            # a NameError.
            traitslist = []
        requireInsert = True
        # Local flag instead of assigning onto the pyparsing result object.
        hasAutoValue = bool(column.hasAutoValue) or (autoId and sqlColumnName == 'id')
        if hasAutoValue:
            traitslist.append(NAMESPACE + '::tag::must_not_insert')
            traitslist.append(NAMESPACE + '::tag::must_not_update')
            requireInsert = False
        if not column.notNull:
            traitslist.append(NAMESPACE + '::tag::can_be_null')
            requireInsert = False
        if column.hasDefaultValue:
            requireInsert = False
        if requireInsert:
            traitslist.append(NAMESPACE + '::tag::require_insert')
        print(' using _traits = ' + NAMESPACE + '::make_traits<' + ', '.join(traitslist) + '>;', file=header)
        print(' };', file=header)
    print(' } // namespace ' + tableNamespace, file=header)
    print('', file=header)
    print(' struct ' + tableClass + ': ' + NAMESPACE + '::table_t<' + tableTemplateParameters + '>', file=header)
    print(' {', file=header)
    print(' struct _alias_t', file=header)
    print(' {', file=header)
    print(' static constexpr const char _literal[] = "' + sqlTableName + '";', file=header)
    print(' using _name_t = sqlpp::make_char_sequence<sizeof(_literal), _literal>;', file=header)
    print(' template<typename T>', file=header)
    print(' struct _member_t', file=header)
    print(' {', file=header)
    print(' T ' + tableMember + ';', file=header)
    print(' T& operator()() { return ' + tableMember + '; }', file=header)
    print(' const T& operator()() const { return ' + tableMember + '; }', file=header)
    print(' };', file=header)
    print(' };', file=header)
    print(' };', file=header)
    if splitTables:
        endHeader(header, nsList)
if not splitTables:
    endHeader(header, nsList)

if DataTypeError:
    print("Error: unsupported datatypes.")
    print("Possible solutions:")
    print("A) Implement this datatype (examples: sqlpp11/data_types)")
    print("B) Extend/upgrade ddl2cpp (edit types map)")
    print("C) Raise an issue on github")
    sys.exit(ERROR_DATA_TYPE)  # non-zero exit code (10) so automation can detect the failure