#!/usr/bin/env python

##
# Copyright (c) 2013-2014, Roland Bock
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification,
# are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice,
#   this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
# INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE
# OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED
# OF THE POSSIBILITY OF SUCH DAMAGE.
## from __future__ import print_function import sys import re import os from pyparsing import CaselessLiteral, SkipTo, restOfLine, oneOf, ZeroOrMore, Optional, \ WordStart, WordEnd, Word, alphas, alphanums, nums, QuotedString, nestedExpr, MatchFirst, OneOrMore, delimitedList, Or, Group INCLUDE = 'sqlpp11' NAMESPACE = 'sqlpp' # HELPERS def get_include_guard_name(namespace, inputfile): val = re.sub("[^A-Za-z]+", "_", namespace + '_' + os.path.basename(inputfile)) return val.upper() def repl_func(m): if (m.group(1) == '_'): return m.group(2).upper() else: return m.group(1) + m.group(2).upper() def toClassName(s): return re.sub("(^|\s|[_0-9])(\S)", repl_func, s) def toMemberName(s): return re.sub("(\s|_|[0-9])(\S)", repl_func, s) # PARSER def ddlWord(string): return WordStart(alphanums + "_") + CaselessLiteral(string) + WordEnd(alphanums + "_") ddlString = Or([QuotedString("'"), QuotedString("\"", escQuote='""'), QuotedString("`")]) ddlNum = Word(nums + ".") ddlTerm = Word(alphas, alphanums + "_$") ddlArguments = "(" + delimitedList(Or([ddlString, ddlTerm, ddlNum])) + ")" ddlNotNull = Group(ddlWord("NOT") + ddlWord("NULL")).setResultsName("notNull") ddlDefaultValue = ddlWord("DEFAULT").setResultsName("hasDefaultValue"); ddlAutoValue = ddlWord("AUTO_INCREMENT").setResultsName("hasAutoValue"); ddlColumnComment = Group(ddlWord("COMMENT") + ddlString).setResultsName("comment") ddlConstraint = Or([ ddlWord("CONSTRAINT"), ddlWord("PRIMARY"), ddlWord("FOREIGN"), ddlWord("KEY"), ddlWord("INDEX"), ddlWord("UNIQUE"), ]) ddlColumn = Group(Optional(ddlConstraint).setResultsName("isConstraint") + OneOrMore(MatchFirst([ddlNotNull, ddlAutoValue, ddlDefaultValue, ddlTerm, ddlNum, ddlColumnComment, ddlString, ddlArguments]))) createTable = Group(ddlWord("CREATE") + ddlWord("TABLE") + ddlTerm.setResultsName("tableName") + "(" + Group(delimitedList(ddlColumn)).setResultsName("columns") + ")").setResultsName("create") ddl = ZeroOrMore(SkipTo(createTable, True)) ddlComment = oneOf(["--", 
"#"]) + restOfLine ddl.ignore(ddlComment) # MAP SQL TYPES types = { 'tinyint': 'tinyint', 'smallint': 'smallint', 'integer': 'integer', 'int': 'integer', 'bigint': 'bigint', 'char': 'char_', 'varchar': 'varchar', 'text': 'text', 'tinyblob': 'blob', 'blob': 'blob', 'mediumblob': 'blob', 'longblob': 'blob', 'bool': 'boolean', 'double': 'floating_point', 'float': 'floating_point', } # PROCESS DDL if (len(sys.argv) != 4): print('Usage: ddl2cpp ') sys.exit(1) pathToDdl = sys.argv[1] pathToHeader = sys.argv[2] + '.h' namespace = sys.argv[3] ddlFile = open(pathToDdl, 'r') header = open(pathToHeader, 'w') print('#ifndef '+get_include_guard_name(namespace, pathToHeader), file=header) print('#define '+get_include_guard_name(namespace, pathToHeader), file=header) print('', file=header) print('#include <' + INCLUDE + '/table.h>', file=header) print('#include <' + INCLUDE + '/column_types.h>', file=header) print('', file=header) print('namespace ' + namespace, file=header) print('{', file=header) tableCreations = ddl.parseFile(pathToDdl) for tableCreation in tableCreations: sqlTableName = tableCreation.create.tableName tableClass = toClassName(sqlTableName) tableMember = toMemberName(sqlTableName) tableNamespace = tableClass + '_' tableTemplateParameters = tableClass print(' namespace ' + tableNamespace, file=header) print(' {', file=header) for column in tableCreation.create.columns: if column.isConstraint: continue sqlColumnName = column[0] columnClass = toClassName(sqlColumnName) tableTemplateParameters += ',\n ' + tableNamespace + '::' + columnClass columnMember = toMemberName(sqlColumnName) sqlColumnType = column[1].lower() columnCanBeNull = not column.notNull print(' struct ' + columnClass, file=header) print(' {', file=header) print(' struct _name_t', file=header) print(' {', file=header) print(' static constexpr const char* _get_name() { return "' + sqlColumnName + '"; }', file=header) print(' template', file=header) print(' struct _member_t', file=header) print(' {', 
file=header) print(' T ' + columnMember + ';', file=header) print(' T& operator()() { return ' + columnMember + '; }', file=header) print(' const T& operator()() const { return ' + columnMember + '; }', file=header) print(' };', file=header) print(' };', file=header) #print(sqlColumnType) print(' using _value_type = ' + NAMESPACE + '::' + types[sqlColumnType] + ';', file=header) print(' struct _column_type', file=header) print(' {', file=header) requireInsert = True if column.hasAutoValue: print(' using _must_not_insert = std::true_type;', file=header) print(' using _must_not_update = std::true_type;', file=header) requireInsert = False if not column.notNull: print(' using _can_be_null = std::true_type;', file=header) requireInsert = False if column.hasDefaultValue: requireInsert = False if requireInsert: print(' using _require_insert = std::true_type;', file=header) print(' };', file=header) print(' };', file=header) print(' }', file=header) print('', file=header) print(' struct ' + tableClass + ': ' + NAMESPACE + '::table_t<' + tableTemplateParameters + '>', file=header) print(' {', file=header) print(' using _value_type = ' + NAMESPACE + '::no_value_t;', file=header) print(' struct _name_t', file=header) print(' {', file=header) print(' static constexpr const char* _get_name() { return "' + sqlTableName + '"; }', file=header) print(' template', file=header) print(' struct _member_t', file=header) print(' {', file=header) print(' T ' + tableMember + ';', file=header) print(' T& operator()() { return ' + tableMember + '; }', file=header) print(' const T& operator()() const { return ' + tableMember + '; }', file=header) print(' };', file=header) print(' };', file=header) print(' };', file=header) print('}', file=header) print('#endif', file=header)