keywordDict.py:
# -*- coding:utf-8 -*-
# CREATED BY: bohuai jiang
# CREATED ON: 2019/9/18
# LAST MODIFIED ON:
# AIM: sort key words
from service.sql_parser_graph.units import ParseUnit
# Priority of structural keywords: when several keywords are buffered
# together, the one with the highest weight is the one that is kept.
# Names that are not listed here are given the weight -1.
KEY_WEIGTH = {
    'SELECT': 10,
    'INSERT': 10,
    'WHERE': 8,
    'FROM': 9,
    'VALUES': 8,
    'AND': 5,
    ',': 0,
}


class Keywordstack:
    """Buffer that remembers only the highest-weighted unit inserted so far.

    Despite the name this is not a real stack: ``insert`` keeps a single
    candidate (on equal weights the earlier unit wins) and ``pop`` returns
    that candidate without removing it.  Call ``reset`` to empty the buffer.
    """

    def __init__(self):
        self.reset()

    def insert(self, value: 'ParseUnit') -> None:
        """Offer ``value``; it replaces the candidate only if it weighs more.

        The weight is looked up in ``KEY_WEIGTH`` by ``value.name.upper()``.
        """
        self.length += 1
        weight = KEY_WEIGTH.get(value.name.upper(), -1)
        if weight > self.weight:
            self.value = value
            self.weight = weight

    def pop(self):
        """Return the current candidate (``None`` when nothing was inserted)."""
        return self.value

    def reset(self):
        """Forget the candidate and the insertion count."""
        self.value = None
        self.weight = -float('inf')
        self.length = 0

    def is_empty(self):
        """Return True when nothing has been inserted since the last reset."""
        return self.length <= 0
SQLParser.py:
# created by bohuai jiang
# on 2019/7/23
# last modified on 2019/9/17 10:14
# -*- coding: utf-8 -*-
import sqlparse
from sqlparse.sql import Where, IdentifierList, Identifier, TokenList, Token, Parenthesis, Comment, Case, Operation, \
Function, Values
from service.sql_parser_graph.KeywrodDict import Keywordstack
from service.sql_parser_graph.units import ParseUnitList
from typing import Union, List, Optional,Tuple
import re
# Keyword tables consulted by SQLParser.lu_parse to decide how the tokens
# that follow a keyword are registered.
# Keywords after which tokens are registered as tables ('TAB'); lu_parse also
# treats any keyword containing 'JOIN' this way.
TAB_KEYWORD = ['FROM', 'LEFT JOIN', 'UPDATE', 'EXISTS', 'INNER JOIN', 'OUTER JOIN', 'JOIN', 'RIGHT JOIN', 'INTO']
# Keywords after which tokens are registered as columns ('COL').
COL_KEYWORD = ['INSERT', 'SELECT', 'WHERE', 'CASE', 'ON', 'AND', 'HAVING', 'OR', 'SET','WITH','BY','PRIOR']
# The IN operator; lu_parse switches to the 'IN' slot for the keyword itself.
IN_KEYWORD = ['IN']
# Recorded through ParseUnitList.add_order; the next token is then parsed recursively.
ORDER_KEYWORD = ['ORDER BY','GROUP BY']
# Recorded through ParseUnitList.add_like.
LIKE_KEYWORD = ['LIKE']
# Not referenced by the code in this listing.
VALUE_KEYWORD = ['VALUES']
# Recorded through ParseUnitList.add_between.
BETWEEN_KEYWORD = ['BETWEEN']
# Recorded through ParseUnitList.add_is.
IS_KEYWORD = ['IS']
# Token values that, besides Name tokens, make a WHERE clause count as
# non-empty when lu_parse sets SQLParser.has_where.
WHERE_EXCEPT = ['ROWNUMBER']
class SQLParser:
    """Parse a single SQL statement with ``sqlparse`` into a ``ParseUnitList``.

    Attributes:
        exception_list: token values that must be handled as plain names
            even when sqlparse tags them as keywords (``exception=`` kwarg).
        has_where: set to True once a WHERE clause containing at least one
            Name token (or a ``WHERE_EXCEPT`` value) has been parsed.
        elements: the ``ParseUnitList`` produced by ``re_get_elements``.
    """

    def __init__(self, sql: str, **kwargs) -> None:
        """Normalise ``sql``, parse it and build the unit graph.

        :param sql: text of exactly one SQL statement
        :param kwargs: ``exception`` - optional list of token values to be
            treated as identifiers instead of keywords
        :raises Exception: when the text contains more than one statement
        """
        self.exception_list = kwargs.get('exception', [])
        self.has_where = False
        self._origin_sql = sql
        sql = self.sql_interpreter(sql)
        tokens = sqlparse.parse(sql)
        if len(tokens) > 1:
            raise Exception("sql is not single")
        self._stmt = tokens[0]
        self.re_get_elements()
        self._sql_text = sql

    def re_get_elements(self, where_only: bool = False):
        """Rebuild ``self.elements`` from the parsed statement and return it."""
        self.elements = ParseUnitList()
        self.lu_parse(self._stmt, add_opt=(not where_only))
        self.elements.build_relation()
        return self.elements

    @property
    def tokens(self) -> 'Union[TokenList, Token]':
        """The parsed sqlparse statement."""
        return self._stmt

    def sql_statement(self):
        """
        sql statement property
        :return: upper-cased, stripped statement type reported by sqlparse
        """
        return self._stmt.get_type().upper().strip()

    def _is_function_contain_keyword(self, function: 'Token', keyword_list: List[str]) -> Optional[str]:
        """Return the upper-cased name of ``function`` if it is one of ``keyword_list``.

        Returns None when ``function`` is not a sqlparse ``Function`` or its
        first token is not listed.
        """
        if isinstance(function, Function):
            head = function.tokens[0].value.upper()
            if head in keyword_list:
                return head
        return None

    def get_fist_keyword(self, statement: 'Token') -> str:
        """Return the first keyword (upper-cased) found in ``statement``, or ''."""
        for token in self.token_walk(statement, True, False):
            if token.is_keyword:
                return token.value.upper()
        return ''

    def lu_parse(self, statement: 'TokenList', level: int = 0, t_idx: int = 3, **kwargs) -> None:
        """Walk ``statement`` and register its tokens in ``self.elements``.

        :param statement: token list to walk (a single token is wrapped)
        :param level: nesting level stored on every created unit; increased
            by one for a parenthesis whose first keyword is SELECT
        :param t_idx: index into ``['COL', 'TAB', 'SUB', 'IN', 'STRUCT']``
            choosing the unit type for the next token; updated whenever a
            keyword is met
        :param kwargs:
            ``add_opt`` (default True) - only forwarded to recursive calls;
            ``parents`` (default []) - parent unit ids for created units;
            ``build_relation`` (default True) - call
            ``elements.build_relation()`` before descending into the
            leftover tokens returned by ``elements.add``;
            ``keyword`` (default '') - keyword currently in effect;
            ``is_where`` (default False) - mark units as part of WHERE
        """
        type_name = ['COL', 'TAB', 'SUB', 'IN', 'STRUCT']
        add_opt = kwargs.get('add_opt', True)
        parents = kwargs.get('parents', [])
        build_relation = kwargs.get('build_relation', True)
        keyword_capture = kwargs.get('keyword', '')
        is_where = kwargs.get('is_where', False)
        order_by_loop = False
        if not isinstance(statement, TokenList):
            statement = [statement]
        keyword_pool = COL_KEYWORD + TAB_KEYWORD + IN_KEYWORD + ORDER_KEYWORD
        for t in statement:
            v = self._is_function_contain_keyword(t, keyword_pool)
            if t.value in self.exception_list:
                # caller asked for this word to be an identifier, not a keyword
                t.is_keyword = False
                t.ttype = sqlparse.tokens.Name
                t = Identifier([t])
            if t.is_keyword or v is not None:
                keyword_capture = t.normalized
                if v is not None:
                    keyword_capture = v
                if keyword_capture in COL_KEYWORD:
                    t_idx = 0
                elif keyword_capture in TAB_KEYWORD or 'JOIN' in keyword_capture:
                    t_idx = 1
                elif keyword_capture in IN_KEYWORD:
                    t_idx = 3
                elif keyword_capture in ORDER_KEYWORD:
                    t_idx = 0
                    self.elements.add_order(tokens=t, key=keyword_capture, parents=parents, level=level,
                                            is_where=is_where)
                    # the next non-blank token is the ordering expression
                    order_by_loop = True
                    continue
                elif keyword_capture in BETWEEN_KEYWORD:
                    t_idx = 0
                    self.elements.add_between(tokens=t, key=keyword_capture, parents=parents, level=level,
                                              is_where=is_where)
                    continue
                elif keyword_capture in LIKE_KEYWORD:
                    t_idx = 0
                    self.elements.add_like(tokens=t, key=keyword_capture, parents=parents, level=level,
                                           is_where=is_where)
                    continue
                elif keyword_capture in IS_KEYWORD:
                    t_idx = 0
                    self.elements.add_is(tokens=t, key=keyword_capture, parents=parents, level=level,
                                         is_where=is_where)
                    continue
                else:
                    t_idx = 4
            # whitespace and comments never become units
            if t.is_whitespace or isinstance(t, Comment) or 'Comment' in str(t.ttype):
                continue
            if isinstance(t, Case) or isinstance(t, Operation) or order_by_loop or isinstance(t, Values):
                order_by_loop = False
                # 'parents' used to be misspelled here, which silently dropped it
                self.lu_parse(t, t_idx=t_idx, build_relation=True, add_opt=True, parents=parents, level=level,
                              keyword=keyword_capture)
                continue
            if isinstance(t, Where):
                count_valid = 0
                for tt in self.token_walk(t, yield_current_token=False):
                    if str(tt.ttype) == 'Token.Name' or tt.value.upper() in WHERE_EXCEPT:
                        count_valid += 1
                if count_valid > 0:
                    self.has_where = True
                # 'parents' used to be misspelled here, which silently dropped it
                self.lu_parse(t, t_idx=t_idx, build_relation=True, add_opt=True, parents=parents, level=level,
                              is_where=True)
                continue
            if isinstance(t, IdentifierList):
                # comma separated list: handle every member with the same settings
                self.lu_parse(t, t_idx=t_idx, build_relation=False, parents=parents, add_opt=add_opt,
                              keyword=keyword_capture)
                continue
            rest = self.elements.add(t, type_name[t_idx], parents=parents, key=keyword_capture,
                                     is_where=is_where, level=level)
            if t_idx == 3:
                t_idx = 0
            if keyword_capture == 'INTO' and \
                    self.elements.by_id[len(self.elements.by_id) - 1].type == 'TAB':
                # after the INSERT INTO target table, what follows is registered as columns
                t_idx = 0
            # -- subquery -- #
            if rest is None:
                continue
            for rest_v in rest:
                sub_parents = rest_v['parents']
                rest_tokens = rest_v['tokens']
                if build_relation:
                    self.elements.build_relation()
                for rest_t in rest_tokens:
                    if isinstance(rest_t, Parenthesis):
                        first_keyword = self.get_fist_keyword(rest_t)
                        level_ = level + 1 if 'SELECT' == first_keyword else level
                        self.lu_parse(statement=rest_t, parents=sub_parents, t_idx=t_idx,
                                      add_opt=add_opt, build_relation=True, level=level_, is_where=is_where)

    def sql_reconstruct(self):
        """Unfinished placeholder that does nothing; see ``reconstruct``.

        The previous body called the ``by_id`` dict like a function and
        therefore always raised TypeError.
        """
        return None

    def display_elements(self) -> None:
        """Print every unit of ``self.elements``."""
        for v in self.elements:
            print(v)

    # ---------#
    def get_table_name(self, alise_on=False) -> Union[Tuple[List[str], List[str]], List[str]]:
        """Return the names of all real tables referenced by the statement.

        Units whose name contains '(' or 'DUAL' are skipped.

        :param alise_on: when True also return the alias of every table
            (the table name itself where the alias is 'DUMMY')
        :return: ``tab_names`` or ``(tab_names, as_names)``
        """
        tab_names = []
        as_names = []
        for tab in self.elements.by_type['TAB']:
            if '(' in tab.name or 'DUAL' in tab.name:
                continue
            tab_names.append(tab.name)
            as_names.append(tab.as_name if tab.as_name != 'DUMMY' else tab.name)
        if alise_on:
            return tab_names, as_names
        return tab_names

    def token_walk(self, token, topdown=True, yield_current_token=True):
        """
        walk all token
        :param token: iterable of tokens to walk
        :param topdown: descend into children that have a ``tokens``
            attribute; when False such children are yielded as they are
        :param yield_current_token: also yield ``token`` itself (and every
            container that is descended into) before its children
        :return: generator of tokens
        """
        if yield_current_token:
            yield token
        for tk in token:
            if topdown and hasattr(tk, "tokens"):
                yield from self.token_walk(tk, topdown, yield_current_token)
            else:
                yield tk

    def sql_interpreter(self, sql: str) -> str:
        """Return a cleaned, upper-cased, single-line version of ``sql``.

        Steps: strip trailing blanks/semicolons, drop the ``(+)`` join
        marker, remove ``/* */`` comments and lines starting with ``--``,
        collapse whitespace, then re-emit every token in upper case.  A
        ',' is inserted before the first '(' that follows ``INSERT INTO``.
        """
        # -- 1. remove comments --#
        sql = sql.strip()
        # keep the text unchanged if it consists only of blanks/semicolons
        sql = sql.rstrip('\n\t ;') or sql
        sql = sql.replace('(+)', '')
        sql = re.sub(r"/\*.*?\*/", "", sql, flags=re.DOTALL)
        sql = re.sub(r'(?m)^ *--.*\n?', '', sql)
        sql = re.sub(r'\s+', ' ', sql).strip()
        tokens = sqlparse.parse(sql)
        # -- 2. split sql --#
        pieces = []
        pre_keyword = None
        comma_pending = False
        for t in self.token_walk(tokens, True, False):
            if t.value == '(' and comma_pending:
                pieces.append(',')
                comma_pending = False
            pieces.append(t.value.upper())
            if t.is_keyword:
                if t.value.upper() == 'INTO' and pre_keyword == 'INSERT':
                    comma_pending = True
                pre_keyword = t.value.upper()
        return ''.join(pieces)

    @staticmethod
    def _append_piece(out: str, text: str, level: int) -> str:
        """Append ``text`` to ``out`` using the layout rules of ``reconstruct``."""
        n_line_elem = ['SELECT', 'FROM', 'WHERE', 'ORDER BY']
        without_space = [',', '(', ')']
        if text in n_line_elem:
            text = '\n' + '\t' * level + text
        if text in without_space:
            # endswith() is safe on an empty string, out[-1] was not
            if out.endswith(' '):
                out = out[0:-1]
            return out + text
        return out + text + ' '

    def reconstruct(self) -> str:
        """Render the units of ``self.elements`` back into SQL text.

        Units are emitted in id order and 'SUB' units are skipped.
        SELECT/FROM/WHERE/ORDER BY start a new line indented by one tab per
        level; ',', '(' and ')' are attached without a preceding space.
        Consecutive 'STRUCT' units (other than parentheses) are collected in
        a ``Keywordstack`` and only its winner is written when the next
        other unit arrives.
        NOTE(review): keywords still buffered when the loop ends are not
        written, and the buffered keyword is indented with the level of the
        unit that follows it - confirm both are intended.
        """
        out = ''
        buffer = Keywordstack()
        for i in sorted(self.elements.by_id.keys()):
            unit = self.elements.by_id[i]
            if unit.type == 'SUB':
                continue
            value = unit.token.value
            if unit.type == 'STRUCT' and value not in [')', '(']:
                buffer.insert(unit)
                continue
            if not buffer.is_empty():
                unit_ = buffer.pop()
                buffer.reset()
                out = self._append_piece(out, unit_.name, unit.level)
            out = self._append_piece(out, value, unit.level)
        return out
if __name__ == "__main__":
    # Small manual demo: parse one query, dump the units, then print the
    # reconstructed SQL, the WHERE flag and the referenced table names.
    sql_text = '''
select * from tab where rownum > 1000
'''
    sp = SQLParser(sql_text, exception=['SORT'])
    sp.display_elements()
    out = sp.reconstruct()
    print(out)
    print(sp.has_where)
    print('table_names :', sp.get_table_name())
units.py:
# created by bohuai jiang
# on 2019/7/23
# last modified on 2019/9/17 10:14
# -*- coding: utf-8 -*-
from sqlparse.sql import Statement, Comment, Where, Identifier, IdentifierList, Parenthesis, Function, \
Comparison, Operation, Token, TokenList, Values
from typing import Union, List, Tuple, Optional, Set
class ParseUnit:
    """One node of the parsed-SQL graph (table, column, operator, keyword ...).

    The setters of ``name``, ``as_name``, ``from_name`` and ``keyword``
    upper-case their value.  The string ``'DUMMY'`` is the "not set" marker
    for ``as_name`` and ``from_name`` (see ``show`` and ``inherit``).
    ``parent`` and ``edges`` are sets of unit ids.
    """

    _TYPES = ('COL', 'TAB', 'SUB', 'OPT', 'FUNC', 'STRUCT', 'VALUE')
    _STATEMENTS = ('WHERE', 'OTHER')

    def __init__(self):
        self.id = None
        self._name = None  # sql code name
        self._as_name = None  # as what name
        self._from_name = None  # from where
        self._type = None  # TAB-table , COL-column, SUB-subquery ,OPT- >,<,=.., FUNC-MAX,SUM..
        self._keyword = None
        self._in_statement = 'OTHER'
        self._parent = set()
        self._edges = set()
        self._level = 0
        self.token = None

    @property
    def in_statement(self) -> str:
        """'WHERE' when the unit belongs to a WHERE clause, else 'OTHER'."""
        return self._in_statement

    @in_statement.setter
    def in_statement(self, state: str):
        if state == 'ORTHER':
            # misspelling that the previous validation accepted instead of 'OTHER'
            state = 'OTHER'
        if state not in self._STATEMENTS:
            raise ValueError('in_statement must be either one of following [WHERE, OTHER]')
        self._in_statement = state

    @property
    def level(self) -> int:
        """Nesting level assigned by the parser."""
        return self._level

    @level.setter
    def level(self, level: int):
        self._level = level

    @property
    def keyword(self) -> str:
        """Keyword that was in effect when the unit was created (upper case)."""
        return self._keyword

    @keyword.setter
    def keyword(self, key: str):
        self._keyword = key.upper()

    @property
    def name(self) -> str:
        """Name of the unit; strings are stored upper-cased."""
        return self._name

    @name.setter
    def name(self, name: Optional[str]):
        self._name = name.upper() if isinstance(name, str) else name

    @property
    def as_name(self) -> str:
        """Alias of the unit (upper case), 'DUMMY' when there is none."""
        return self._as_name

    @as_name.setter
    def as_name(self, as_name: str):
        self._as_name = as_name.upper()

    @property
    def from_name(self) -> str:
        """Qualifier the unit comes from (upper case), 'DUMMY' when unknown."""
        return self._from_name

    @from_name.setter
    def from_name(self, from_name: str):
        self._from_name = from_name.upper()

    @property
    def parent(self) -> set:
        """Ids of the parent units."""
        return self._parent

    @parent.setter
    def parent(self, parent: Set['ParseUnit']):
        self._parent = parent

    @property
    def type(self) -> str:
        """One of COL, TAB, SUB, OPT, FUNC, STRUCT, VALUE."""
        return self._type

    @type.setter
    def type(self, value: str):
        if value not in self._TYPES:
            raise ValueError('type must be either one of following [COL, TAB, SUB, OPT, FUNC, STRUCT, VALUE]')
        self._type = value

    @property
    def edges(self) -> set:
        """Ids of the units this unit points to."""
        return self._edges

    @edges.setter
    def edges(self, edges: Set['ParseUnit']):
        self._edges = edges

    def overwrite(self, unit: 'ParseUnit'):
        """Copy every attribute of ``unit`` that is not None onto this unit.

        ``parent`` and ``edges`` are taken over by reference, so both units
        share the same set objects afterwards.
        """
        if unit.name is not None:
            self._name = unit.name
        if unit.as_name is not None:
            self._as_name = unit.as_name
        if unit.from_name is not None:
            self._from_name = unit.from_name
        if unit.parent is not None:
            self._parent = unit.parent
        if unit.type is not None:
            self._type = unit.type
        # was "if not unit.edges", which only ever copied an empty set
        if unit.edges is not None:
            self._edges = unit.edges

    def inherit(self, unit: 'ParseUnit', update_edges: bool = False):
        """Take name, alias and type from ``unit``.

        ``from_name`` is copied only when it is not 'DUMMY'.  With
        ``update_edges`` the id of ``unit`` is added to this unit's edges.
        """
        self._name = unit.name
        self._as_name = unit.as_name
        if unit.from_name != 'DUMMY':
            self._from_name = unit.from_name
        self._type = unit.type
        if update_edges:
            self._edges.add(unit.id)

    def show(self) -> str:
        """Render the unit as ``[from.]name[ as alias]``.

        Parts that are None or 'DUMMY' are left out.
        """
        out = ''
        if self._from_name not in (None, 'DUMMY'):
            out += self._from_name + '.'
        out += self._name or ''
        if self._as_name not in (None, 'DUMMY'):
            out += ' as ' + self._as_name
        return out

    def add_parents(self, parents: Union[List[int], Set[int]]) -> None:
        """Add every id in ``parents`` to the parent set."""
        self._parent.update(parents)

    def __repr__(self):
        from_label = '\tfrom' + (' tab ' if self.type == 'COL' else '')
        parts = [
            '%s\n' % str(self.id),
            '\ttype:%s\n' % self.type,
            '\tname:%s\n' % self.name,
            '\tkeyword:%s\n' % self.keyword,
            '\tstatement:%s\n' % self.in_statement,
            '\tlevel:%s\n' % self.level,
            '\tas_name:%s\n' % self.as_name,
            from_label + ':%s\n' % self.from_name,
            '\tparent:%s\n' % str(self.parent),
            '\tedges:%s\n' % str(self.edges),
        ]
        return ''.join(parts)
class ParseUnitList:
    """Registry of ``ParseUnit`` nodes and the parent/edge graph between them.

    Every unit is stored twice: ``by_id`` maps the integer id to the unit
    and ``by_type`` groups the same objects by ``unit.type``.  Ids are
    assigned as ``len(self.by_id)`` on insertion, and several ``add``
    helpers predict the ids of neighbouring units from that length.
    NOTE(review): this relies on ids staying contiguous; after ``remove``
    a later insertion can be handed an id that is still in use.
    """

    _TYPES = ('COL', 'TAB', 'SUB', 'OPT', 'FUNC', 'STRUCT', 'VALUE')

    def __init__(self) -> None:
        # -- tab col relation -- #
        self.by_type = {key: [] for key in self._TYPES}
        self.by_id = dict()  # G
        self._allow_sub_has_table = False

    def __insert(self, unit: 'ParseUnit') -> int:
        """Assign the next id to ``unit``, store it in both indexes, return the id."""
        unit_id = len(self.by_id)
        unit.id = unit_id
        self.by_type[unit.type].append(unit)
        self.by_id[unit_id] = unit
        return unit_id

    def __update_by_type(self) -> None:
        """Write every unit held in ``by_type`` back into ``by_id``."""
        for units in self.by_type.values():
            for unit in units:
                self.by_id[unit.id] = unit

    def __update_by_id(self):
        """Rebuild ``by_type`` from the units currently in ``by_id``."""
        self.by_type = {key: [] for key in self._TYPES}
        for unit in self.by_id.values():
            self.by_type[unit.type].append(unit)

    def _new_unit(self, name, unit_type: str, key: str, level: int, token, is_where: bool,
                  parents=None) -> 'ParseUnit':
        """Create (without inserting) a unit carrying the common attributes."""
        unit = ParseUnit()
        unit.name = name
        unit.type = unit_type
        unit.keyword = key
        unit.level = level
        unit.token = token
        if is_where:
            unit.in_statement = 'WHERE'
        if parents:
            unit.add_parents(parents)
        return unit

    def _add_keyword_opt(self, name, tokens, key: str, level: int, is_where: bool, parents, edge_id: int) -> None:
        """Insert an 'OPT' unit for a keyword with one edge to ``edge_id``."""
        unit = self._new_unit(name=name, unit_type='OPT', key=key, level=level, token=tokens,
                              is_where=is_where, parents=parents)
        unit.edges.add(edge_id)
        self.__insert(unit)

    ########################################
    #           add function               #
    ########################################
    # ----------- add by token type -----------#
    def _add_Identifier(self, tokens: 'Token', type: str, key: str, level: int, is_where: bool,
                        parents: Optional[List[int]] = None) -> 'Tuple[int, Optional[Token]]':
        """Insert a unit for an identifier-like token.

        For a token list the children are scanned: Name tokens separated by
        '.' give ``from_name`` and ``name``, a child without ttype gives
        ``as_name``.  A plain token simply gets its text as name.

        :return: ``(id, abnormal)`` where ``abnormal`` is the last grouped
            child that is not an ``Identifier`` (None if there is none)
        """
        out = ParseUnit()
        # text containing a parenthesis (other than the bare '(') is a sub expression
        out.type = 'SUB' if ('(' in tokens.value and tokens.value != '(') else type
        out.keyword = key
        out.level = level
        out.token = tokens
        if is_where:
            out.in_statement = 'WHERE'
        if parents:
            out.add_parents(parents)
        dot_flag = 1
        abnormal = None
        try:
            for t in tokens:
                kind = str(t.ttype).upper()
                if kind == 'TOKEN.PUNCTUATION' and t.value == '.':
                    dot_flag += 1
                    continue
                if kind == 'TOKEN.NAME':
                    if dot_flag % 2 == 0:
                        out.name = t.value
                        dot_flag += 1
                    else:
                        out.from_name = t.value
                if t.ttype is None:
                    out.as_name = t.value
                    if not isinstance(t, Identifier):
                        abnormal = t
            if dot_flag <= 1:
                # no qualifier: the only name seen is the unit's own name
                out.name = out.from_name
                out.from_name = 'DUMMY'
        except TypeError:
            # a plain Token cannot be iterated
            out.name = tokens.value
        if out.as_name is None:
            out.as_name = 'DUMMY'
        # -- patch --#
        if out.name is None:
            out.name = abnormal.value if abnormal is not None else out.as_name
        # -- add order by or group by -- #
        if key in ('ORDER BY', 'GROUP BY'):
            # -- find nearest opt -- #
            acquire_id = None
            for candidate in reversed(range(len(self.by_id))):
                acquire_id = candidate
                unit = self.by_id[candidate]
                if unit.type == 'OPT' and unit.name == key:
                    break
            if acquire_id is not None:
                out.parent.add(acquire_id)
        # -- add to like -- #
        if key == 'LIKE':
            out.add_parents([len(self.by_id) - 1])
        return self.__insert(out), abnormal

    def _add_Comparison(self, tokens: 'Comparison', type: str, key: str, level: int, is_where: bool,
                        parents: Optional[List[int]] = None) -> Optional[List[dict]]:
        """Insert left operand, operator unit and right operand of a comparison.

        :return: the leftover token descriptions of both operands (or None)
        NOTE(review): the operands get ``len(by_id) + 1`` as parent id,
        which equals the operator's id only if the left operand produces
        exactly one unit.
        """
        # -- get opt unit --#
        opt = None
        for t in tokens:
            if str(t.ttype).upper() == 'TOKEN.OPERATOR.COMPARISON':
                opt = t.value
        # the second non-blank child is kept as the operator's token
        opt_token = None
        count = 0
        for t in tokens:
            if not t.is_whitespace:
                count += 1
            if count == 2:
                opt_token = t
                break
        unit = self._new_unit(name=opt, unit_type='OPT', key=key, level=level, token=opt_token,
                              is_where=is_where, parents=parents)
        expect_id = len(self.by_id) + 1
        operand_parents = [expect_id]
        # -- left unit -- #
        left_rest = self.add(tokens=tokens.left, type=type, key=key, level=level, parents=operand_parents,
                             is_where=is_where)
        self.__insert(unit)
        right_rest = self.add(tokens=tokens.right, type=type, key=key, level=level, parents=operand_parents,
                              is_where=is_where)
        if left_rest and right_rest:
            return left_rest + right_rest
        elif left_rest:
            return left_rest
        else:
            return right_rest

    def _add_Operation(self, tokens: 'Operation', type: str, key: str, level: int, is_where: bool,
                       parents: Optional[List[int]] = None):
        """Insert an 'OPT' unit for the operation and add each child below it."""
        unit = self._new_unit(name=tokens.value, unit_type='OPT', key=key, level=level, token=tokens,
                              is_where=is_where, parents=parents)
        expect_id = self.__insert(unit)
        for t in tokens.tokens:
            self.add(tokens=t, type=type, key=key, level=level, parents=[expect_id],
                     is_where=is_where)

    def _add_Function(self, tokens: 'Function', key: str, level: int, is_where: bool,
                      parents: Optional[List[int]] = None) -> Tuple[int, Optional[list]]:
        """Insert a 'FUNC' unit named after the first child token.

        :return: ``(id, remaining child tokens)``
        """
        head = tokens.tokens[0]
        unit = self._new_unit(name=head.value, unit_type='FUNC', key=key, level=level, token=head,
                              is_where=is_where, parents=parents)
        return self.__insert(unit), tokens.tokens[1::]

    def _add_Parenthesis(self, tokens: 'Parenthesis', key: str, level: int, is_where: bool,
                         parents: Optional[List[int]] = None) -> 'Tuple[int, Parenthesis]':
        """Insert a 'SUB' unit for a parenthesis and return ``(id, tokens)``."""
        unit = self._new_unit(name=tokens.value, unit_type='SUB', key=key, level=level, token=tokens,
                              is_where=is_where, parents=parents)
        unit.from_name = 'DUMMY'
        unit.as_name = 'DUMMY'
        return self.__insert(unit), tokens

    def add(self, tokens: 'Union[Token, TokenList]', type: str, is_where: bool, key: str, level: int,
            parents: Optional[List[int]] = None) -> Optional[List[dict]]:
        """Insert the unit(s) for ``tokens``, dispatching on the token class.

        :param type: unit type for identifiers; other tokens get a type
            derived from their class or ttype
        :return: a list of ``{'parents': [id], 'tokens': [...]}`` entries
            describing tokens the caller still has to parse, or None
        """
        if isinstance(tokens, Identifier):
            unit_id, abnormal = self._add_Identifier(tokens=tokens, type=type, parents=parents, key=key,
                                                     level=level, is_where=is_where)
            if abnormal is None:
                return None
            if isinstance(abnormal, Function):
                return self.add(tokens=abnormal, type=type, parents=[unit_id], key=key, level=level,
                                is_where=is_where)
            return [{'parents': [unit_id], 'tokens': [abnormal]}]
        elif isinstance(tokens, Comparison):
            return self._add_Comparison(tokens=tokens, type=type, parents=parents, key=key, is_where=is_where,
                                        level=level)
        elif isinstance(tokens, Function):
            unit_id, token_list = self._add_Function(tokens=tokens, parents=parents, key=key, level=level,
                                                     is_where=is_where)
            return [{'parents': [unit_id], 'tokens': token_list}]
        elif isinstance(tokens, Parenthesis):
            unit_id, token = self._add_Parenthesis(tokens=tokens, parents=parents, key=key, level=level,
                                                   is_where=is_where)
            return [{'parents': [unit_id], 'tokens': [token]}]
        elif isinstance(tokens, Values):
            return self._add_value(tokens=tokens, level=level, parents=parents, is_where=is_where)
        elif isinstance(tokens, Operation):
            self._add_Operation(tokens=tokens, type=type, parents=parents, key=key, is_where=is_where, level=level)
        elif tokens.value.upper() == 'IN':
            self._add_In(tokens=tokens, is_where=is_where, key=key, level=level, parents=parents)
        else:
            # literals become VALUE units, every other plain token a STRUCT unit
            type = 'STRUCT' if str(tokens.ttype[0]) not in ['Literal', 'Number'] else 'VALUE'
            unit_id, token_list = self._add_Identifier(tokens=tokens, type=type, parents=parents, key=key,
                                                       level=level, is_where=is_where)
            if token_list is not None:
                self.add(tokens=token_list, type=type, is_where=is_where, key=key, level=level, parents=[unit_id])
        return None

    # ----------- add by keywords ----------- #
    def _add_In(self, tokens: 'Token', key: str, level: int, is_where: bool,
                parents: Optional[List[int]] = None) -> None:
        """Insert the IN operator, linked to the previous and the next unit id."""
        # acquire id
        cur_id = len(self.by_id)
        left_id = cur_id - 1
        right_id = cur_id + 1
        # --build in Node -#
        unit = self._new_unit(name='IN', unit_type='OPT', key=key, level=level, token=tokens,
                              is_where=is_where, parents=parents)
        unit.edges = {left_id, right_id}
        self.__insert(unit)
        self.by_id[left_id].parent.add(cur_id)

    def add_order(self, tokens: 'Token', key: str, level: int, is_where: bool,
                  parents: Optional[List[int]] = None) -> Optional[List[dict]]:
        """Insert an ORDER BY / GROUP BY operator with an edge to the next id."""
        self._add_keyword_opt(tokens.value, tokens, key, level, is_where, parents,
                              edge_id=len(self.by_id) + 1)
        return None

    def add_like(self, tokens: 'Token', key: str, level: int, is_where: bool,
                 parents: Optional[List[int]] = None) -> Optional[List[dict]]:
        """Insert a LIKE operator with an edge to the previous unit."""
        self._add_keyword_opt(tokens.value, tokens, key, level, is_where, parents,
                              edge_id=len(self.by_id) - 1)
        return None

    def add_between(self, tokens: 'Token', key: str, level: int, is_where: bool,
                    parents: Optional[List[int]] = None) -> Optional[List[dict]]:
        """Insert a BETWEEN operator with an edge to the previous unit."""
        self._add_keyword_opt('BETWEEN', tokens, key, level, is_where, parents,
                              edge_id=len(self.by_id) - 1)
        return None

    def _add_value(self, tokens: 'Token', level: int, is_where: bool,
                   parents: Optional[List[int]] = None) -> Optional[List[dict]]:
        """Insert a VALUES operator and one 'SUB' unit per value parenthesis.

        :return: one ``{'parents': [id], 'tokens': [parenthesis]}`` entry
            per parenthesis found after the first child token
        """
        self._allow_sub_has_table = True
        col_id = len(self.by_id) - 1
        unit = self._new_unit(name='VALUES', unit_type='OPT', key='VALUES', level=level, token=tokens,
                              is_where=is_where, parents=parents)
        unit.edges = {col_id, col_id + 2}
        self.__insert(unit)
        out = []
        for t in tokens.tokens[1::]:
            if isinstance(t, Parenthesis):
                # _add_Parenthesis already inserts the unit and returns its id;
                # the id must not be handed to __insert a second time
                sub_id, sub_tokens = self._add_Parenthesis(tokens=t, key='VALUES', level=level,
                                                           parents=[col_id + 1], is_where=is_where)
                out.append({'parents': [sub_id], 'tokens': [sub_tokens]})
        return out

    def add_is(self, tokens: 'Token', key: str, level: int, is_where: bool,
               parents: Optional[List[int]] = None) -> Optional[List[dict]]:
        """Insert an IS operator with an edge to the previous unit."""
        self._add_keyword_opt(tokens.value, tokens, key, level, is_where, parents,
                              edge_id=len(self.by_id) - 1)
        return None

    #########################################
    def __iter__(self):
        return iter(self.by_id.values())

    #########################################
    #       build relation function         #
    #########################################
    def build_relation(self):
        """Derive parent and edge links between the registered units.

        1. Tables are indexed by alias and name; a column (and a 'SUB' unit
           right after ``_add_value``) with a qualifier gets the matching
           tables as parents, one without gets every table of its level
           (and adopts the table's alias/name when there is exactly one).
        2. The three units following a BETWEEN on the same level become its
           children, and every unit becomes a parent of its edge targets.
        3. Every unit is added to the edges of each of its parents.
        """
        # --- build parents ---#
        symbol_idx = dict()  # {as_name/name: [id]}
        check_keys = ['COL'] if not self._allow_sub_has_table else ['COL', 'SUB']
        # -- buil tab col relation --#
        for key in ['SUB', 'TAB', 'COL']:
            for unit in self.by_type[key]:
                key_i = unit.type
                if key_i == 'TAB':
                    symbol_idx.setdefault(unit.as_name, []).append(unit.id)
                    symbol_idx.setdefault(unit.name, []).append(unit.id)
                # -- update parents --#
                if key_i not in check_keys:
                    continue
                if unit.from_name != 'DUMMY':
                    for parent in symbol_idx.get(unit.from_name, []):
                        unit.parent.add(parent)
                else:
                    all_parents = self.add_all_parents(unit.level)
                    if len(all_parents) == 1:
                        parent = self.by_id[all_parents.pop()]
                        unit.from_name = parent.as_name if parent.as_name != 'DUMMY' else parent.name
                    unit.add_parents(self.add_all_parents(unit.level))
        self.__update_by_type()
        # --- build parents ---#
        between_count = None
        blevel = None
        b_id = None
        for unit_id in self.by_id.keys():
            unit = self.by_id[unit_id]
            # -- between handler --#
            if between_count is not None:
                if blevel == unit.level:
                    between_count += 1
                    unit.parent.add(b_id)
                # lower bound, AND, upper bound; this used to test blevel,
                # so the window was practically never closed
                if between_count == 3:
                    between_count = None
                    blevel = None
                    b_id = None
            if unit.name == 'BETWEEN':
                between_count = 0
                blevel = unit.level
                b_id = unit.id
            for ed in unit.edges:
                self.by_id[ed].parent.add(unit_id)
        # --- build edges --- #
        for unit_id in self.by_id:
            for pa in self.by_id[unit_id].parent:
                self.by_id[pa].edges.add(unit_id)
        self._allow_sub_has_table = False
        self.__update_by_id()

    def add_all_parents(self, level: int) -> Set[int]:
        """Return the ids of all 'TAB' units on ``level``."""
        return {unit.id for unit in self.by_type['TAB'] if unit.level == level}

    def build_relation_by_tab_info(self):
        """Not implemented."""
        pass

    ############################################
    #              graph search                #
    ############################################
    def find_root(self, graph: 'ParseUnit', col_only: bool = False) -> Optional[List[int]]:
        """Breadth-first walk along ``edges`` starting at ``graph``.

        :param col_only: collect 'COL' units whose name has no '(' instead
            of units without edges
        :return: ids of the collected units in visiting order
        """
        root = []
        visited = set()
        q = [graph.id]
        while q:
            v = q.pop(0)
            if v in visited or not isinstance(v, int):
                continue
            visited.add(v)
            units = self.by_id[v]
            if col_only:
                if units.type == 'COL' and '(' not in units.name:
                    root.append(units.id)
            elif len(units.edges) == 0:
                root.append(units.id)
            q.extend(units.edges)
        return root

    def find_tab(self, colum: 'ParseUnit', tab_only: bool = False) -> Optional[List[int]]:
        """Breadth-first walk along ``parent`` links starting at ``colum``.

        :param tab_only: collect 'TAB' units instead of units without parent
        :return: ids of the collected units in visiting order
        """
        tabs = []
        visited = set()
        q = [colum.id]
        while q:
            v = q.pop(0)
            if v in visited:
                continue
            visited.add(v)
            units = self.by_id[v]
            wanted = units.type == 'TAB' if tab_only else len(units.parent) == 0
            if wanted and units.id not in tabs:
                tabs.append(units.id)
            q.extend(units.parent)
        return tabs

    ############################
    #       remove node        #
    ############################
    def remove(self, id_list: List[int]):
        """Delete the given units together with their connected trunks.

        Trunks of different ids may overlap, so ids are de-duplicated
        before deletion, and ``by_type`` is rebuilt afterwards so removed
        units cannot be written back by ``build_relation``.
        NOTE(review): parent/edge sets of the remaining units may still
        mention removed ids.
        """
        doomed = set()
        for unit_id in id_list:
            doomed.update(self._get_remove_trunk(self.by_id[unit_id]))
        for unit_id in doomed:
            del self.by_id[unit_id]
        self.__update_by_id()

    def _get_remove_trunk(self, unit: 'ParseUnit') -> List[int]:
        """Collect the ids reachable from ``unit`` that ``remove`` deletes.

        The walk follows parents and edges of every unit that is neither
        'TAB' nor 'SUB'.  A 'TAB'/'SUB' unit ends the walk and is collected
        only when its level is deeper than the level of ``unit``.
        """
        id_list = []
        target_level = unit.level
        visited = set()
        q = [unit.id]
        while q:
            v = q.pop(0)
            if v in visited or not isinstance(v, int):
                continue
            visited.add(v)
            units_ = self.by_id[v]
            if units_.type != 'TAB' and units_.type != 'SUB':
                id_list.append(units_.id)
                q = q + list(units_.parent) + list(units_.edges)
            elif units_.level > target_level:
                id_list.append(units_.id)
        return id_list
class SQLGrammarError(Exception):
    """Exception type for SQL grammar problems; nothing in this module raises it yet."""
知识兔