pytorch版本CSNet运行octa数据集的问题

网友投稿 292 2022-08-27

pytorch版本CSNet运行octa数据集的问题

今天跑了一下CSNet的pytorch的代码,train.py基本没多大改动:

"""Training script for CS-Net"""import osimport torchimport torch.nn as nnfrom torch import optimfrom torch.utils.data import DataLoaderimport visdomimport numpy as npfrom model.csnet import CSNetfrom dataloader.octa import Datafrom utils.train_metrics import metricsfrom utils.visualize import init_visdom_line, update_lines# os.environ["CUDA_VISIBLE_DEVICES"] = "1"args = { 'root' : '', 'data_path' : 'dataset/octa/', 'epochs' : 1000, 'lr' : 0.0001, 'snapshot' : 100, 'test_step' : 1, 'ckpt_path' : 'checkpoint/', 'batch_size': 8,}# # Visdom---------------------------------------------------------X, Y = 0, 0.5 # for visdomx_acc, y_acc = 0, 0x_sen, y_sen = 0, 0env, panel = init_visdom_line(X, Y, title='Train Loss', xlabel="iters", ylabel="loss")env1, panel1 = init_visdom_line(x_acc, y_acc, title="Accuracy", xlabel="iters", ylabel="accuracy")env2, panel2 = init_visdom_line(x_sen, y_sen, title="Sensitivity", xlabel="iters", ylabel="sensitivity")# # ---------------------------------------------------------------def save_ckpt(net, iter): if not os.path.exists(args['ckpt_path']): os.makedirs(args['ckpt_path']) torch.save(net, args['ckpt_path'] + 'CS_Net_DRIVE_' + str(iter) + '.pkl') print('--->saved model:{}<--- '.format(args['root'] + args['ckpt_path']))# adjust learning rate (poly)def adjust_lr(optimizer, base_lr, iter, max_iter, power=0.9): lr = base_lr * (1 - float(iter) / max_iter) ** power for param_group in optimizer.param_groups: param_group['lr'] = lrdef train(): # set the channels to 3 when the format is RGB, otherwise 1. 
net = CSNet(classes=1, channels=1).cuda() net = nn.DataParallel(net, device_ids=[0]).cuda() optimizer = optim.Adam(net.parameters(), lr=args['lr'], weight_decay=0.0005) critrion = nn.MSELoss().cuda() # critrion = nn.CrossEntropyLoss().cuda() print("---------------start training------------------") # load train dataset train_data = Data(args['data_path'], train=True) batchs_data = DataLoader(train_data, batch_size=args['batch_size'], num_workers=2, shuffle=True) iters = 1 accuracy = 0. sensitivty = 0. for epoch in range(args['epochs']): net.train() for idx, batch in enumerate(batchs_data): image = batch[0].cuda() label = batch[1].cuda() optimizer.zero_grad() pred = net(image) # pred = pred.squeeze_(1) print(pred.shape) loss = critrion(pred, label) loss.backward() optimizer.step() acc, sen = metrics(pred, label, pred.shape[0]) print('[{0:d}:{1:d}] --- loss:{2:.10f}\tacc:{3:.4f}\tsen:{4:.4f}'.format(epoch + 1, iters, loss.item(), acc / pred.shape[0], sen / pred.shape[0])) iters += 1 # # ---------------------------------- visdom -------------------------------------------------- X, x_acc, x_sen = iters, iters, iters Y, y_acc, y_sen = loss.item(), acc / pred.shape[0], sen / pred.shape[0] update_lines(env, panel, X, Y) update_lines(env1, panel1, x_acc, y_acc) update_lines(env2, panel2, x_sen, y_sen) # # -------------------------------------------------------------------------------------------- adjust_lr(optimizer, base_lr=args['lr'], iter=epoch, max_iter=args['epochs'], power=0.9) if (epoch + 1) % args['snapshot'] == 0: save_ckpt(net, epoch + 1) # model eval if (epoch + 1) % args['test_step'] == 0: test_acc, test_sen = model_eval(net) print("Average acc:{0:.4f}, average sen:{1:.4f}".format(test_acc, test_sen)) if (accuracy > test_acc) & (sensitivty > test_sen): save_ckpt(net, epoch + 1 + 8888888) accuracy = test_acc sensitivty = test_sendef model_eval(net): print("Start testing model...") test_data = Data(args['data_path'], train=False) batchs_data = 
DataLoader(test_data, batch_size=1) net.eval() Acc, Sen = [], [] file_num = 0 for idx, batch in enumerate(batchs_data): image = batch[0].float().cuda() label = batch[1].float().cuda() pred_val = net(image) acc, sen = metrics(pred_val, label, pred_val.shape[0]) print("\t---\t test acc:{0:.4f} test sen:{1:.4f}".format(acc, sen)) Acc.append(acc) Sen.append(sen) file_num += 1 # for better view, add testing visdom here. return np.mean(Acc), np.mean(Sen)if __name__ == '__main__': train()

predict.py去除了crop操作:

"""Prediction script: run a trained CS-Net checkpoint over the OCTA test images."""
import torch
from torchvision import transforms
from PIL import Image, ImageOps
import numpy as np
# NOTE(review): scipy.misc.imsave was removed in scipy >= 1.2 — this script
# requires an older scipy (or port to PIL's Image.save).
import scipy.misc as misc
import os
import glob
from utils.misc import thresh_OTSU, ReScaleSize, Crop
from utils.model_eval import eval

DATABASE = './octa/'
#
args = {
    'root': './dataset/' + DATABASE,
    'test_path': './dataset/' + DATABASE + 'training/',
    'pred_path': 'assets/' + 'octa/',
    'img_size': 512,
}

if not os.path.exists(args['pred_path']):
    os.makedirs(args['pred_path'])


def rescale(img):
    """Center-crop `img` to a square whose side is min(width, height)."""
    w, h = img.size
    min_len = min(w, h)
    new_w, new_h = min_len, min_len
    scale_w = (w - new_w) // 2
    scale_h = (h - new_h) // 2
    box = (scale_w, scale_h, scale_w + new_w, scale_h + new_h)
    img = img.crop(box)
    return img


def ReScaleSize_DRIVE(image, re_size=512):
    """Center-crop to a square, then resize to re_size x re_size (DRIVE images)."""
    w, h = image.size
    min_len = min(w, h)
    new_w, new_h = min_len, min_len
    scale_w = (w - new_w) // 2
    scale_h = (h - new_h) // 2
    box = (scale_w, scale_h, scale_w + new_w, scale_h + new_h)
    image = image.crop(box)
    image = image.resize((re_size, re_size))
    return image  # , origin_w, origin_h


def ReScaleSize_STARE(image, re_size=512):
    """Zero-pad to a square, then resize to re_size x re_size (STARE images)."""
    w, h = image.size
    max_len = max(w, h)
    new_w, new_h = max_len, max_len
    delta_w = new_w - w
    delta_h = new_h - h
    padding = (delta_w // 2, delta_h // 2, delta_w - (delta_w // 2), delta_h - (delta_h // 2))
    image = ImageOps.expand(image, padding, fill=0)
    # origin_w, origin_h = w, h
    image = image.resize((re_size, re_size))
    return image  # , origin_w, origin_h


def load_nerve():
    """Collect (image, label) path pairs for the nerve dataset."""
    test_images = []
    test_labels = []
    for file in glob.glob(os.path.join(args['test_path'], 'orig', '*.tif')):
        basename = os.path.basename(file)
        file_name = basename[:-4]
        image_name = os.path.join(args['test_path'], 'orig', basename)
        label_name = os.path.join(args['test_path'], 'mask2', file_name + '_centerline_overlay.tif')
        test_images.append(image_name)
        test_labels.append(label_name)
    return test_images, test_labels


def load_drive():
    """Collect (image, label) path pairs for DRIVE (labels keyed by 3-char prefix)."""
    test_images = []
    test_labels = []
    for file in glob.glob(os.path.join(args['test_path'], 'images', '*.tif')):
        basename = os.path.basename(file)
        file_name = basename[:3]
        image_name = os.path.join(args['test_path'], 'images', basename)
        label_name = os.path.join(args['test_path'], '1st_manual', file_name + 'manual1.gif')
        test_images.append(image_name)
        test_labels.append(label_name)
    return test_images, test_labels


def load_stare():
    """Collect (image, label) path pairs for STARE (.ppm images, .ah.ppm labels)."""
    test_images = []
    test_labels = []
    for file in glob.glob(os.path.join(args['test_path'], 'images', '*.ppm')):
        basename = os.path.basename(file)
        file_name = basename[:-4]
        image_name = os.path.join(args['test_path'], 'images', basename)
        label_name = os.path.join(args['test_path'], 'labels-ah', file_name + '.ah.ppm')
        test_images.append(image_name)
        test_labels.append(label_name)
    return test_images, test_labels


def load_padova1():
    """Collect (image, label) path pairs for the Padova1 dataset."""
    test_images = []
    test_labels = []
    for file in glob.glob(os.path.join(args['test_path'], 'images', '*.tif')):
        basename = os.path.basename(file)
        file_name = basename[:-4]
        image_name = os.path.join(args['test_path'], 'images', basename)
        label_name = os.path.join(args['test_path'], 'label2', file_name + '_centerline_overlay.tif')
        test_images.append(image_name)
        test_labels.append(label_name)
    return test_images, test_labels


def load_octa():
    """Collect (image, label) path pairs for OCTA (.tif images, .png labels)."""
    test_images = []
    test_labels = []
    for file in glob.glob(os.path.join(args['test_path'], 'images', '*.tif')):
        basename = os.path.basename(file)
        file_name = basename[:-4]
        image_name = os.path.join(args['test_path'], 'images', basename)
        # label_name = os.path.join(args['test_path'], 'label', file_name + '_nerve_ann.tif')
        label_name = os.path.join(args['test_path'], 'label', file_name + '.png')
        test_images.append(image_name)
        test_labels.append(label_name)
    return test_images, test_labels


def load_net():
    """Load the full serialized model saved by train.py's save_ckpt."""
    net = torch.load('./checkpoint/CS_Net_DRIVE_200.pkl')
    return net


def save_prediction(pred, filename=''):
    """Write a (1, 1, H, W) prediction tensor as a PNG under <pred_path>/pred/."""
    save_path = args['pred_path'] + 'pred/'
    if not os.path.exists(save_path):
        os.makedirs(save_path)
        print("Make dirs success!")
    mask = pred.data.cpu().numpy() * 255
    # (1, C, H, W) -> (H, W, C) -> (H, W)
    mask = np.transpose(np.squeeze(mask, axis=0), [1, 2, 0])
    mask = np.squeeze(mask, axis=-1)
    misc.imsave(save_path + filename + '.png', mask)


def predict():
    """Segment every test image with the trained net and save the predictions."""
    net = load_net()
    # images, labels = load_nerve()
    # images, labels = load_drive()
    # images, labels = load_stare()
    # images, labels = load_padova1()
    images, labels = load_octa()
    transform = transforms.Compose([
        transforms.ToTensor()
    ])
    with torch.no_grad():
        net.eval()
        for i in range(len(images)):
            print(images[i])
            # BUG FIX: the original split on '/' to get the file stem, which
            # breaks on Windows paths; os.path handles both separators.
            index = os.path.splitext(os.path.basename(images[i]))[0]
            image = Image.open(images[i])
            # image = image.convert("RGB")
            label = Image.open(labels[i])
            # for other retinal vessel:
            # image = rescale(image); label = rescale(label)
            # image = ReScaleSize_STARE(image, re_size=args['img_size'])
            # label = ReScaleSize_DRIVE(label, re_size=args['img_size'])
            # for OCTA
            image = ReScaleSize(image)
            label = ReScaleSize(label)
            label.save('output/' + str(index) + '_pred.png')
            # if cuda
            image = transform(image).cuda()
            image = image.unsqueeze(0)
            output = net(image)
            save_prediction(output, filename=index + '_pred')
            print("output saving successfully")


if __name__ == '__main__':
    predict()
    # binarize the saved probability maps with Otsu thresholding
    thresh_OTSU(args['pred_path'] + 'pred/')

然后就是把octa.py的crop去掉就行了哈:

from __future__ import print_function, divisionimport osimport globfrom torch.utils.data import Datasetfrom torchvision import transformsfrom PIL import Image, ImageEnhance, ImageOpsimport randomimport warningswarnings.filterwarnings('ignore')def load_dataset(root_dir, train=True): labels = [] images = [] if train: sub_dir = 'training' else: sub_dir = 'test' label_path = os.path.join(root_dir, sub_dir, 'label') image_path = os.path.join(root_dir, sub_dir, 'images') for file in glob.glob(os.path.join(image_path, '*.tif')): image_name = os.path.basename(file) # label_name = image_name[:-4] + '_nerve_ann.tif' label_name = image_name[:-4] + '.png' labels.append(os.path.join(label_path, label_name)) images.append(os.path.join(image_path, image_name)) return images, labelsclass Data(Dataset): def __init__(self, root_dir, train=True, rotate=45, flip=True, random_crop=True, scale1=512): self.root_dir = root_dir self.train = train self.rotate = rotate self.flip = flip self.random_crop = random_crop self.transform = transforms.ToTensor() self.resize = scale1 self.images, self.groundtruth = load_dataset(self.root_dir, self.train) def __len__(self): return len(self.images) def RandomCrop(self, image, label, crop_size): crop_width, crop_height = crop_size w, h = image.size left = random.randint(0, w - crop_width) top = random.randint(0, h - crop_height) right = left + crop_width bottom = top + crop_height new_image = image.crop((left, top, right, bottom)) new_label = label.crop((left, top, right, bottom)) return new_image, new_label def RandomEnhance(self, image): value = random.uniform(-2, 2) random_seed = random.randint(1, 4) if random_seed == 1: img_enhanceed = ImageEnhance.Brightness(image) elif random_seed == 2: img_enhanceed = ImageEnhance.Color(image) elif random_seed == 3: img_enhanceed = ImageEnhance.Contrast(image) else: img_enhanceed = ImageEnhance.Sharpness(image) image = img_enhanceed.enhance(value) return image def Crop(self, image): left = 261 top = 1 right = 
1110 bottom = 850 image = image.crop((left, top, right, bottom)) return image def ReScaleSize(self, image, re_size=512): w, h = image.size max_len = max(w, h) new_w, new_h = max_len, max_len delta_w = new_w - w delta_h = new_h - h padding = (delta_w // 2, delta_h // 2, delta_w - (delta_w // 2), delta_h - (delta_h // 2)) image = ImageOps.expand(image, padding, fill=0) # origin_w, origin_h = w, h image = image.resize((re_size, re_size)) return image # , origin_w, origin_h def __getitem__(self, idx): img_path = self.images[idx] gt_path = self.groundtruth[idx] image = Image.open(img_path) label = Image.open(gt_path) # print(image.size) # image = self.Crop(image) # label = self.Crop(label) image = self.ReScaleSize(image, self.resize) label = self.ReScaleSize(label, self.resize) if self.train: # augumentation angel = random.randint(-self.rotate, self.rotate) image = image.rotate(angel) label = label.rotate(angel) if random.random() > 0.5: image = self.RandomEnhance(image) image, label = self.RandomCrop(image, label, crop_size=[self.resize, self.resize]) # flip if self.flip and random.random() > 0.5: image = image.transpose(Image.FLIP_LEFT_RIGHT) label = label.transpose(Image.FLIP_LEFT_RIGHT) else: img_size = image.size if img_size[0] != self.resize: image = image.resize((self.resize, self.resize)) label = label.resize((self.resize, self.resize)) image = self.transform(image) label = self.transform(label) return image,

其他地方基本没动哈。 代码的运行命令为:

python -m visdom.server
python train.py
python predict.py

然后assets/octa/pred目录就有预测出来的图片哈。

版权声明:本文内容由网络用户投稿,版权归原作者所有,本站不拥有其著作权,亦不承担相应法律责任。如果您发现本站中有涉嫌抄袭或描述失实的内容,请联系我们jiasou666@gmail.com 处理,核实后本网站将在24小时内删除侵权内容。

上一篇:快手较劲京东发力,谁是抖音真正的对手?(快手竞争不过抖音)
下一篇:脱下“营销外套”的足力健!(足力健导购)
相关文章

 发表评论

暂时没有评论,来抢沙发吧~