c语言sscanf函数的用法是什么
251
2022-09-21
Virtual Adversarial Training的pytorch实现
def vat_loss(embedder, encoder, clf, batch, perturb_norm_length=5.0, small_constant_for_finite_diff=1e-1, Ip=1, p_logit=None): embedded = embedder(batch) # [seq_len,batch,hidden] d = torch.randn(embedded.shape).type(batch_utils.FLOAT_TYPE) d = d.transpose(0, 1).contiguous() d = get_normalized_vector(d).transpose(0, 1).contiguous() # [seq_len,batch,hidden] for ip in range(Ip): x_d = Variable(embedded.data + (small_constant_for_finite_diff * d), requires_grad=True) x_d.retain_grad() p_d_logit = clf(encoder(x_d, batch)[0]) kl_loss = kl_categorical(Variable(p_logit.data), p_d_logit) kl_loss.backward() d = x_d.grad.data.transpose(0, 1).contiguous() d = get_normalized_vector(d).transpose(0, 1).contiguous() x_adv = embedded + (perturb_norm_length * Variable(d)) p_adv_logit = clf(encoder(x_adv, batch)[0]) return kl_categorical(Variable(p_logit.data), p_adv_logit)
版权声明:本文内容由网络用户投稿,版权归原作者所有,本站不拥有其著作权,亦不承担相应法律责任。如果您发现本站中有涉嫌抄袭或描述失实的内容,请联系我们jiasou666@gmail.com 处理,核实后本网站将在24小时内删除侵权内容。
发表评论
暂时没有评论,来抢沙发吧~