Implementing a Ternary Tree in Python

Submitted by a user, 2022-11-29

I recently implemented a ternary tree in Python and thought it was fun, so I'm sharing it here:

import numpy as np


class Node:
    """The Node class. You should not change this!"""

    def __init__(self, ID, data):
        self.ID = ID
        self.data = data
        self.children = []  # This is a list of other Node objects. At first it is empty.


def getID(tree):
    return tree.ID


def getData(tree):
    return tree.data


# IDs of every node printed so far, in print order.
PRINT_NODE_LIST = []


def printNode(tree):
    print(getID(tree))
    print(getData(tree))
    PRINT_NODE_LIST.append(getID(tree))

def buildTree(data, split_list, root_id):
    """
    Build a tree using any tree data structure that you choose.

    The root node of this (sub)tree should store all of data.

    If the data is empty or the split_list is empty,
    len(data) == 0 or len(split_list) == 0, then this root node will have
    no children, and it should still store the id and the data (even though
    the data is empty).

    If the data is not empty and the split_list is not empty, the root node
    of this (sub)tree should have three children. The children should all be
    tree structures (just like the root). The three children should each be
    given a subset of data. Let s = split_list[0], the first integer in
    split_list.

    Child index 0:
        Should contain a NumPy array with shape (N_0, M) containing all
        points in data where data[i,s] < -1. N_0 is the number of points
        that fit this criterion. If N_0 is zero, this child should store an
        empty NumPy array with shape (0, M), e.g. np.zeros((0,M)).
        This child should have id = root_id*10
    Child index 1:
        The same as child index 0 except:
        Child data is -1 <= data[i,s] <= 1
        Child id = root_id*10 + 1
    Child index 2:
        The same as child index 0 except:
        Child data is data[i,s] > 1
        Child id = root_id*10 + 2

    The tree should continue growing where the children split their data
    based on split_list[1], the grandchildren split their data based on
    split_list[2], and so on.

    Input:
        data: NumPy ndarray of shape (N, M) representing the M-dimensional
            coordinates of N data points. The data may be an empty NumPy
            array, i.e. len(data) == 0.
        split_list: List of integers. Each integer in the list is in the
            range [0,M). The list will have at most M entries. The list may
            be empty, [], in which case this tree will not have any children.
        root_id: Positive integer representing the ID for the root of this
            (sub)tree. The ID for the root of any child subtree should be
            root_id*10 + index_of_that_child. So if the root_id is 7 and
            there are three children, the IDs of the children should be
            70, 71, 72.

    Return:
        A data structure of your choosing that represents the resulting tree.
    """
    ### BEGIN YOUR CODE ###
    # The root stores all of the data, even when the data is empty.
    root = Node(ID=root_id, data=data)
    if len(data) == 0 or len(split_list) == 0:
        return root

    # Grow the tree one layer at a time: layer k splits on split_list[k].
    curr_layer = [root]
    for index in split_list:
        next_layer = []
        for cur_root in curr_layer:
            cur_data = cur_root.data
            # A node with no data gets no children.
            if len(cur_data) == 0:
                continue
            # Sort row indices into three buckets by the value in the split column.
            buckets = [[], [], []]
            for i, item in enumerate(cur_data[:, index]):
                if item < -1:
                    buckets[0].append(i)
                elif -1 <= item <= 1:
                    buckets[1].append(i)
                else:
                    buckets[2].append(i)
            # Create the three child nodes; an empty bucket stores a (0, M) array.
            for i, rows in enumerate(buckets):
                if rows:
                    child_data = cur_data[rows, :]
                else:
                    child_data = np.zeros((0, cur_data.shape[-1]))
                child = Node(cur_root.ID * 10 + i, child_data)
                cur_root.children.append(child)
                next_layer.append(child)
        curr_layer = next_layer
    return root
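The per-row bucketing loop in buildTree can also be written with NumPy boolean masks. The following is only a minimal sketch of that alternative, not the code from this post: `split_three_ways` is a hypothetical helper name, and it assumes `data` is a 2-D array and `s` is a valid column index.

```python
import numpy as np


def split_three_ways(data, s):
    """Partition the rows of `data` by column `s` (hypothetical helper).

    Returns three arrays: rows with value < -1, rows with -1 <= value <= 1,
    and rows with value > 1.
    """
    column = data[:, s]
    low = data[column < -1]
    mid = data[(column >= -1) & (column <= 1)]
    high = data[column > 1]
    return low, mid, high


# Sample points taken from the test case in this post.
data1 = np.array([
    [-1.5, -0.5, -0.2],
    [0.3, 1.3, 0.0],
    [-1.3, -1.4, -2.1],
    [0.9, 1.5, -0.6],
])

low, mid, high = split_three_ways(data1, 0)
print(low.shape, mid.shape, high.shape)  # (2, 3) (2, 3) (0, 3)
```

Indexing with a mask that matches no rows already yields an array of shape (0, M), so this variant needs no special case for an empty child.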

def printTreeBF(tree):
    # Breadth-first: visit nodes level by level using a FIFO queue.
    list_node = [tree]
    while len(list_node) > 0:
        cur_node = list_node.pop(0)
        printNode(cur_node)
        for item in cur_node.children:
            list_node.append(item)


def printTreeDF(tree):
    # Depth-first (pre-order): print a node, then recurse into each child.
    printNode(tree)
    for item in tree.children:
        printTreeDF(item)
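Calling `pop(0)` on a list shifts every remaining element, while `collections.deque.popleft()` removes from the front in constant time. The sketch below shows the two traversal orders with a deque-based queue. It is an illustration rather than the post's code: `SimpleNode`, `ids_breadth_first`, and `ids_depth_first` are hypothetical names, and the functions return the IDs instead of printing them so the orders are easy to compare.

```python
from collections import deque


class SimpleNode:
    """Hypothetical stand-in for the post's Node class (ID and children only)."""

    def __init__(self, ID):
        self.ID = ID
        self.children = []


def ids_breadth_first(tree):
    """Return node IDs level by level, left to right."""
    order = []
    queue = deque([tree])
    while queue:
        node = queue.popleft()
        order.append(node.ID)
        queue.extend(node.children)
    return order


def ids_depth_first(tree):
    """Return node IDs in pre-order: a node first, then each child subtree."""
    order = [tree.ID]
    for child in tree.children:
        order.extend(ids_depth_first(child))
    return order


# Small tree using the post's ID scheme: the children of 1 are 10, 11, 12.
root = SimpleNode(1)
root.children = [SimpleNode(10), SimpleNode(11), SimpleNode(12)]
root.children[0].children = [SimpleNode(100), SimpleNode(101), SimpleNode(102)]

print(ids_breadth_first(root))  # [1, 10, 11, 12, 100, 101, 102]
print(ids_depth_first(root))    # [1, 10, 100, 101, 102, 11, 12]
```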

I found this quite fun. Only the core code is shown here.

Here is a test case as well:

def getChildren(tree):
    # Not part of the core code shown above; assumed to return the node's list of children.
    return tree.children


def task1_testC(data):
    tree1C = buildTree(data, [0, 1, 2], 1)
    print("tree1C root")
    printNode(tree1C)
    tree1C_firstChild = getChildren(tree1C)[0]
    print("tree1C first child")
    printNode(tree1C_firstChild)
    tree1C_firstGrandchild = getChildren(tree1C_firstChild)[0]
    print("tree1C first grandchild")
    printNode(tree1C_firstGrandchild)


data1 = np.array([
    [-1.5, -0.5, -0.2],
    [0.3, 1.3, 0.0],
    [-1.3, -1.4, -2.1],
    [0.9, 1.5, -0.6]])

task1_testC(data1)
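With a root ID of 1, the nodes this test prints have IDs 1, 10, and 100, because each child ID is its parent's ID times 10 plus the child index. The digits after the root ID therefore spell out the path from the root. A small sketch of decoding that path follows; `child_path` is a hypothetical helper, not part of the original code, and it assumes every ID was built with this scheme.

```python
def child_path(node_id, root_id):
    """Return the child indices leading from the root to `node_id`.

    Hypothetical helper: assumes IDs were built as parent_id * 10 + child_index
    with child_index in {0, 1, 2}.
    """
    node_digits, root_digits = str(node_id), str(root_id)
    if not node_digits.startswith(root_digits):
        raise ValueError("node_id does not descend from root_id")
    # Each digit after the root's own digits is one step down the tree.
    return [int(d) for d in node_digits[len(root_digits):]]


print(child_path(1, 1))    # []
print(child_path(100, 1))  # [0, 0]
print(child_path(712, 7))  # [1, 2]
```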
