c语言sscanf函数的用法是什么
253
2022-09-14
Python 调试 Pruned Dependency Trees（剪枝依存树）的构建
今天在调试 Graph Convolution over Pruned Dependency Trees Improves Relation Extraction 的代码时，想搞清楚依存树（dependency tree）是怎么构建的，于是特地给 Tree.py 写了一个测试用例，代码如下：
"""Operations on trees."""
from collections import defaultdict, deque

import numpy as np


class Tree:
    """
    Reused tree object from stanfordnlp/treelstm.

    Attributes:
        parent: parent ``Tree`` node, or None for a root.
        num_children: number of children added through ``add_child``.
        children: child ``Tree`` nodes in insertion order.
        idx, dist: set by ``head_to_tree`` -- the 0-based token index of the
            node and its distance to the dependency path (a -1 filler when
            the tree is not pruned).
    """

    def __init__(self):
        self.parent = None
        self.num_children = 0
        self.children = []
        # Caches for size()/depth(), initialised here so the lookups below
        # never raise AttributeError. They are NOT invalidated by
        # add_child(), so only query them once the tree is fully built.
        self._size = None
        self._depth = None

    def add_child(self, child):
        """Attach ``child`` as the last child of this node."""
        child.parent = self
        self.num_children += 1
        self.children.append(child)

    def size(self):
        """Return the number of nodes in the subtree rooted at this node."""
        if self._size is None:
            self._size = 1 + sum(child.size() for child in self.children)
        return self._size

    def depth(self):
        """Return the subtree height: 0 for a leaf, else 1 + deepest child."""
        if self._depth is None:
            if self.children:
                self._depth = 1 + max(child.depth() for child in self.children)
            else:
                self._depth = 0
        return self._depth

    def __iter__(self):
        """Yield every node of the subtree in pre-order (self first)."""
        yield self
        for c in self.children:
            for x in c:
                yield x


def _ancestor_chain(head, i):
    """Return ``[i, parent(i), ..., root]`` as 0-based indexes.

    ``head`` holds 1-based head indexes where 0 marks the root.
    """
    chain = [i]
    h = head[i]
    while h > 0:
        chain.append(h - 1)
        h = head[h - 1]
    return chain


def _lowest_common_ancestor(head, cas):
    """Return the member of the common-ancestor set ``cas`` with no child in ``cas``."""
    if len(cas) == 1:
        return next(iter(cas))
    child_count = {k: 0 for k in cas}
    for ca in cas:
        if head[ca] > 0 and head[ca] - 1 in cas:
            child_count[head[ca] - 1] += 1
    # The LCA is the deepest common ancestor, i.e. it has no child in the CA set.
    for ca in cas:
        if child_count[ca] == 0:
            return ca
    raise ValueError("entity tokens have no common ancestor in the dependency tree")


def _distances_to_path(head, len_, path_nodes):
    """Return, per token, the number of head hops needed to reach ``path_nodes``.

    Tokens whose head chain ends at the root without touching the path get
    10000 (used as infinity).
    """
    unreachable = int(1e4)  # aka infinity
    dist = [0 if i in path_nodes else -1 for i in range(len_)]
    for i in range(len_):
        if dist[i] >= 0:
            continue
        stack = [i]
        # Walk up through the heads until the path is hit or we step past
        # the root (head 0 -> index -1).
        while stack[-1] >= 0 and stack[-1] not in path_nodes:
            stack.append(head[stack[-1]] - 1)
        if stack[-1] in path_nodes:
            # Distance grows by one per step away from the path node reached.
            for d, j in enumerate(reversed(stack)):
                dist[j] = d
        else:
            for j in stack:
                if j >= 0 and dist[j] < 0:
                    dist[j] = unreachable
    return dist


def head_to_tree(head, tokens, len_, prune, subj_pos, obj_pos):
    """
    Convert a sequence of head indexes into a tree object.

    Args:
        head: 1-based head index of each token; 0 marks the root.
        tokens: per-token sequence; it is truncated to ``len_`` but not
            otherwise used (kept for interface compatibility).
        len_: number of leading entries of ``head``/``subj_pos``/``obj_pos``
            that belong to the sentence.
        prune: if negative, build the full tree. Otherwise keep only tokens
            at most ``prune`` hops from the dependency path between the
            subject and object tokens, rooted at their lowest common ancestor.
        subj_pos, obj_pos: relative position sequences (see get_positions);
            entries equal to 0 mark the subject / object tokens.

    Returns:
        The root ``Tree`` node. Every node has ``idx`` and ``dist`` set.

    Raises:
        ValueError: when pruning is requested but no token is marked as
            subject or object, or the marked tokens share no ancestor.
    """
    tokens = tokens[:len_]
    head = head[:len_]
    root = None

    if prune < 0:
        nodes = [Tree() for _ in head]
        for i, h in enumerate(head):
            nodes[i].idx = i
            nodes[i].dist = -1  # just a filler
            if h == 0:
                root = nodes[i]
            else:
                nodes[h - 1].add_child(nodes[i])
    else:
        subj_idx = [i for i in range(len_) if subj_pos[i] == 0]
        obj_idx = [i for i in range(len_) if obj_pos[i] == 0]
        if not subj_idx and not obj_idx:
            raise ValueError(
                "pruning needs at least one subject or object token (position 0)")

        # Common ancestors (CA) of every entity token; each token counts as
        # its own ancestor. Seeded from whichever entity has tokens first.
        cas = None
        subj_ancestors = set(subj_idx)
        obj_ancestors = set(obj_idx)
        for positions, ancestors in ((subj_idx, subj_ancestors),
                                     (obj_idx, obj_ancestors)):
            for p in positions:
                chain = _ancestor_chain(head, p)
                ancestors.update(chain)
                if cas is None:
                    cas = set(chain)
                else:
                    cas.intersection_update(chain)

        lca = _lowest_common_ancestor(head, cas)

        # Dependency path = entity ancestors strictly below the LCA, plus the LCA.
        path_nodes = subj_ancestors.union(obj_ancestors).difference(cas)
        path_nodes.add(lca)

        dist = _distances_to_path(head, len_, path_nodes)

        nodes = [Tree() if dist[i] <= prune else None for i in range(len_)]
        for i, node in enumerate(nodes):
            if node is None:
                continue
            h = head[i]
            node.idx = i
            node.dist = dist[i]
            # The LCA becomes the root, so its own head edge is dropped.
            if h > 0 and i != lca:
                assert nodes[h - 1] is not None
                nodes[h - 1].add_child(node)
        root = nodes[lca]

    assert root is not None
    return root


def tree_to_adj(sent_len, tree, directed=True, self_loop=False):
    """
    Convert a tree object to an (numpy) adjacency matrix.

    Returns a float32 ``(sent_len, sent_len)`` matrix with ``ret[p, c] = 1``
    for every parent ``p`` / child ``c`` edge. With ``directed=False`` the
    matrix is symmetrised. With ``self_loop=True`` every node in the tree
    gets a 1 on the diagonal. Tokens outside the tree keep all-zero rows and
    columns.
    """
    ret = np.zeros((sent_len, sent_len), dtype=np.float32)
    queue = deque([tree])
    idx = []
    while queue:
        t = queue.popleft()
        idx.append(t.idx)
        for c in t.children:
            ret[t.idx, c.idx] = 1
        queue.extend(t.children)

    if not directed:
        ret = ret + ret.T

    if self_loop:
        for i in idx:
            ret[i, i] = 1

    return ret


def tree_to_dist(sent_len, tree):
    """Return an int64 array with each tree node's ``dist`` at its ``idx``; -1 elsewhere."""
    ret = -1 * np.ones(sent_len, dtype=np.int64)
    for node in tree:
        ret[node.idx] = node.dist
    return ret


def get_positions(start_idx, end_idx, length):
    """
    Get subj/obj position sequence.

    Tokens in ``[start_idx, end_idx]`` get 0. Tokens before the span get
    negative offsets and tokens after it get positive offsets, e.g.
    ``get_positions(1, 1, 4) == [-1, 0, 1, 2]``.
    """
    return list(range(-start_idx, 0)) + [0] * (end_idx - start_idx + 1) + \
        list(range(1, length - end_idx))


if __name__ == "__main__":
    prune = 1
    head = ["2", "3", "0", "8", "7", "7", "8", "3", "3", "3", "13", "13", "20",
            "17", "17", "17", "13", "20", "20", "3", "23", "23", "20", "3"]
    # Dependency relation labels of each token. head_to_tree only truncates
    # ``tokens``, so any per-token sequence works here.
    words = ["neg", "nsubj", "ROOT", "advmod", "compound", "compound", "nsubj",
             "ccomp", "punct", "cc", "det", "amod", "nsubjpass", "case", "det",
             "compound", "nmod", "aux", "auxpass", "conj", "case", "nmod:poss",
             "nmod", "punct"]
    head = [int(x) for x in head]
    subj_pos = get_positions(21, 21, len(head))  # subject = token 21
    obj_pos = get_positions(1, 1, len(head))  # object = token 1
    length = len(head)

    tree = head_to_tree(head, words, length, prune, subj_pos, obj_pos)
    print(tree)
    print(subj_pos)
    print(obj_pos)

    maxlen = len(head)
    # Undirected adjacency matrix with a leading batch dimension of 1.
    adj = tree_to_adj(maxlen, tree, directed=False, self_loop=False).reshape(1, maxlen, maxlen)
    print(adj.shape)
    # For a batch: trees = [head_to_tree(head[i], words[i], lengths[i], prune,
    #                                    subj_pos[i], obj_pos[i]) for i in range(len(lengths))]
它主要是先构建一个 Tree 对象，再把这个 Tree 对象转换成邻接矩阵。运行后的输出如下（注意 subj_pos 和 obj_pos 两个数组的取值）：
<__main__.Tree object at 0x7f167dafb240>
[-21, -20, -19, -18, -17, -16, -15, -14, -13, -12, -11, -10, -9, -8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2]
[-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22]
(1, 24, 24)
其中实体所在的位置为 0，其他位置按与实体的相对距离编号：实体左边为负数，右边为正数。构建 tree 时，正是用值为 0 的位置来定位主语和宾语的，是不是很巧妙？细节读者可以自己琢磨，一步一步地 debug 就行了。
版权声明:本文内容由网络用户投稿,版权归原作者所有,本站不拥有其著作权,亦不承担相应法律责任。如果您发现本站中有涉嫌抄袭或描述失实的内容,请联系我们jiasou666@gmail.com 处理,核实后本网站将在24小时内删除侵权内容。
发表评论
暂时没有评论,来抢沙发吧~