博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
Unet 项目部分代码学习
阅读量:5999 次
发布时间:2019-06-20

本文共 9032 字，预计阅读时间约 30 分钟。

github地址:https://github.com/orobix/retina-unet

主程序（训练脚本）：

##################################################
#   Script to:
#   - Load the images and extract the patches
#   - Define the neural network
#   - Define the training
##################################################
import numpy as np
import configparser as ConfigParser
from keras.models import Model
from keras.layers import Input, concatenate, Conv2D, MaxPooling2D, UpSampling2D, Reshape, core, Dropout
from keras.optimizers import Adam
from keras.callbacks import ModelCheckpoint, LearningRateScheduler
from keras import backend as K
from keras.utils.vis_utils import plot_model as plot
from keras.optimizers import SGD
import sys

sys.path.insert(0, './lib/')
from help_functions import *
# function to obtain data for training/testing (validation)
from extract_patches import get_data_training

# Every tensor in this script is channels-first: (batch, channels, height, width).
# All spatial layers state this explicitly so the model no longer depends on the
# global "image_data_format" setting in ~/.keras/keras.json.
_DATA_FORMAT = 'channels_first'


def _conv_block(x, filters):
    """Apply 3x3 ReLU conv -> Dropout(0.2) -> 3x3 ReLU conv, keeping the spatial size."""
    x = Conv2D(filters, (3, 3), activation='relu', padding='same', data_format=_DATA_FORMAT)(x)
    x = Dropout(0.2)(x)
    return Conv2D(filters, (3, 3), activation='relu', padding='same', data_format=_DATA_FORMAT)(x)


def _softmax_head(x, patch_height, patch_width):
    """Map features to 2 classes per pixel and return a (patch_height*patch_width, 2) softmax."""
    # NOTE(review): ReLU before softmax is kept from the original design; a linear
    # activation is more usual for logits -- confirm before changing (alters training).
    x = Conv2D(2, (1, 1), activation='relu', padding='same', data_format=_DATA_FORMAT)(x)
    x = core.Reshape((2, patch_height * patch_width))(x)  # (2, H*W)
    x = core.Permute((2, 1))(x)                            # (H*W, 2): one row per pixel
    return core.Activation('softmax')(x)


# Define the neural network
def get_unet(n_ch, patch_height, patch_width):
    """Build and compile a two-level U-Net for per-pixel 2-class segmentation.

    Args:
        n_ch: number of input channels of each patch.
        patch_height, patch_width: patch size; should be divisible by 4 so the
            upsampled maps line up with the skip connections after two 2x2 poolings.

    Returns:
        A Keras ``Model`` compiled with SGD and categorical cross-entropy, whose
        output has shape (batch, patch_height * patch_width, 2).
    """
    inputs = Input(shape=(n_ch, patch_height, patch_width))
    conv1 = _conv_block(inputs, 32)
    pool1 = MaxPooling2D((2, 2), data_format=_DATA_FORMAT)(conv1)
    #
    conv2 = _conv_block(pool1, 64)
    pool2 = MaxPooling2D((2, 2), data_format=_DATA_FORMAT)(conv2)
    #
    conv3 = _conv_block(pool2, 128)
    up1 = UpSampling2D(size=(2, 2), data_format=_DATA_FORMAT)(conv3)
    up1 = concatenate([conv2, up1], axis=1)  # axis=1 is the channel axis (channels-first)
    conv4 = _conv_block(up1, 64)
    #
    up2 = UpSampling2D(size=(2, 2), data_format=_DATA_FORMAT)(conv4)
    up2 = concatenate([conv1, up2], axis=1)
    conv5 = _conv_block(up2, 32)
    #
    outputs = _softmax_head(conv5, patch_height, patch_width)
    model = Model(inputs=inputs, outputs=outputs)
    model.compile(optimizer='sgd', loss='categorical_crossentropy', metrics=['accuracy'])
    return model


# Define the neural network G-Net.
# To use it, replace the call to get_unet(...) below with get_gnet(...).
def get_gnet(n_ch, patch_height, patch_width):
    """Build and compile the G-Net variant (upsample first, then a 3-level U-Net).

    Ported from the Keras 1 API (``Convolution2D``/``border_mode``/``merge``, which
    were never imported here) to the Keras 2 layers used by ``get_unet``.

    Args:
        n_ch: number of input channels of each patch.
        patch_height, patch_width: patch size; should be divisible by 4 so the
            skip connections line up.

    Returns:
        A Keras ``Model`` compiled with SGD and categorical cross-entropy, whose
        output has shape (batch, patch_height * patch_width, 2).
    """
    inputs = Input((n_ch, patch_height, patch_width))
    conv1 = _conv_block(inputs, 32)
    up1 = UpSampling2D(size=(2, 2), data_format=_DATA_FORMAT)(conv1)
    #
    conv2 = _conv_block(up1, 16)
    pool1 = MaxPooling2D(pool_size=(2, 2), data_format=_DATA_FORMAT)(conv2)
    #
    conv3 = _conv_block(pool1, 32)
    pool2 = MaxPooling2D(pool_size=(2, 2), data_format=_DATA_FORMAT)(conv3)
    #
    conv4 = _conv_block(pool2, 64)
    pool3 = MaxPooling2D(pool_size=(2, 2), data_format=_DATA_FORMAT)(conv4)
    #
    conv5 = _conv_block(pool3, 128)
    #
    up2 = concatenate([UpSampling2D(size=(2, 2), data_format=_DATA_FORMAT)(conv5), conv4], axis=1)
    conv6 = _conv_block(up2, 64)
    #
    up3 = concatenate([UpSampling2D(size=(2, 2), data_format=_DATA_FORMAT)(conv6), conv3], axis=1)
    conv7 = _conv_block(up3, 32)
    #
    up4 = concatenate([UpSampling2D(size=(2, 2), data_format=_DATA_FORMAT)(conv7), conv2], axis=1)
    conv8 = _conv_block(up4, 16)
    #
    pool4 = MaxPooling2D(pool_size=(2, 2), data_format=_DATA_FORMAT)(conv8)  # back to patch size
    conv9 = _conv_block(pool4, 32)
    #
    outputs = _softmax_head(conv9, patch_height, patch_width)
    model = Model(inputs=inputs, outputs=outputs)
    model.compile(optimizer='sgd', loss='categorical_crossentropy', metrics=['accuracy'])
    return model


# ========= Load settings from the config file
config = ConfigParser.RawConfigParser()
config.read('configuration.txt')
# path to the datasets
path_data = config.get('data paths', 'path_local')
# experiment name (also used as the output directory)
name_experiment = config.get('experiment name', 'name')
# training settings
N_epochs = config.getint('training settings', 'N_epochs')
batch_size = config.getint('training settings', 'batch_size')

# ============ Load the data and divide it into patches
patches_imgs_train, patches_masks_train = get_data_training(
    DRIVE_train_imgs_original=path_data + config.get('data paths', 'train_imgs_original'),
    DRIVE_train_groudTruth=path_data + config.get('data paths', 'train_groundTruth'),  # masks
    patch_height=config.getint('data attributes', 'patch_height'),
    patch_width=config.getint('data attributes', 'patch_width'),
    N_subimgs=config.getint('training settings', 'N_subimgs'),
    # select the patches only inside the FOV (default == True)
    inside_FOV=config.getboolean('training settings', 'inside_FOV'),
)

experiment_dir = './' + name_experiment + '/'

# ========= Save a sample of what you're feeding to the neural network
N_sample = min(patches_imgs_train.shape[0], 40)  # show at most 40 patches
visualize(group_images(patches_imgs_train[0:N_sample, :, :, :], 5), experiment_dir + "sample_input_imgs")
visualize(group_images(patches_masks_train[0:N_sample, :, :, :], 5), experiment_dir + "sample_input_masks")

# =========== Construct and save the model architecture
n_ch = patches_imgs_train.shape[1]          # channels per patch
patch_height = patches_imgs_train.shape[2]  # patch height
patch_width = patches_imgs_train.shape[3]   # patch width
model = get_unet(n_ch, patch_height, patch_width)  # the U-net model
print("Check: final output of the network:")
print(model.output_shape)
plot(model, to_file=experiment_dir + name_experiment + '_model.png')  # check how the model looks like
# to_json() serializes only the architecture (no weights); the model can be rebuilt from it.
json_string = model.to_json()
with open(experiment_dir + name_experiment + '_architecture.json', 'w') as json_file:
    json_file.write(json_string)

# ============ Training
# Save the weights whenever the validation loss improves.
checkpointer = ModelCheckpoint(filepath=experiment_dir + name_experiment + '_best_weights.h5',
                               verbose=1, monitor='val_loss', mode='auto', save_best_only=True)
# masks_Unet comes from help_functions; presumably it converts the masks to the
# (N, H*W, 2) layout produced by _softmax_head -- confirm against lib/help_functions.py.
patches_masks_train = masks_Unet(patches_masks_train)
model.fit(patches_imgs_train, patches_masks_train, epochs=N_epochs, batch_size=batch_size,
          verbose=2, shuffle=True, validation_split=0.1, callbacks=[checkpointer])

# ========== Save the last model
model.save_weights(experiment_dir + name_experiment + '_last_weights.h5', overwrite=True)

实验结果如下：自上而下依次为原图、ground truth（人工标注）和预测图。

 

转载于:https://www.cnblogs.com/fourmi/p/8993631.html

你可能感兴趣的文章
Android + Eclipse + PhoneGap 环境配置
查看>>
带色彩恢复的多尺度视网膜增强算法(MSRCR)的原理、实现及应用。
查看>>
从中国电信和中国移动的套餐使用查询业务浅谈数据同步
查看>>
Hadoop 类Grep源代码注释
查看>>
[置顶] Objective-C编程之道iOS设计模式单例解析(2)
查看>>
[Android开发常见问题-16] FragmentActivity cannot be resolve to a type
查看>>
专题实验 Toad 用户的创建与管理( 包括 role 等 )
查看>>
markdown 语法和工具
查看>>
当调用List Remove 失效时 [C#] .
查看>>
Linux下修改Oracle监听地址
查看>>
ie11的仿真模式
查看>>
hdu - 3049 - Data Processing(乘法逆元)
查看>>
Java程序员面试失败的5大原因
查看>>
Open vSwitch 工作原理
查看>>
(算法)两个单词的最短距离
查看>>
谈谈Ext JS的组件——布局的用法续二
查看>>
网络爬虫个人博客
查看>>
json串转对象
查看>>
HTTP代理与SPDY协议(转)
查看>>
线程初步了解 - <第一篇>
查看>>