博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
ResNet 训练：在自己的数据集上微调与验证
阅读量:4114 次
发布时间:2019-05-25

本文共 11172 字，阅读大约需要 37 分钟。

使用 ResNet-50 神经网络，在预训练模型的基础上用自己的数据集进行微调，最后在自己的验证集和测试集上检验效果。

可以借鉴部分:

  • 数据加载
  • ResNet 模型的使用方法
  • 预训练模型加载、模型保存
  • 训练/验证/测试步骤
  • 使用 CPU/GPU 进行模型训练

需要修改部分:

  • 预训练模型加载路径
  • 自己数据集路径
  • 保存路径
import math
import pickle

# Several of these imports (plt, np's sibling helpers, F, torchvision, xlwt, Image,
# Variable) are not used by the code below; they are kept so the module's
# import-time behaviour matches the original script.
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import xlwt
from PIL import Image
from torch.autograd import Variable
from torchvision import datasets, models, transforms

__all__ = ['ResNet', 'resnet50']

# Paths -- edit these for your environment.
# NOTE(review): WEIGHT_FILE and BEST_MODEL_PATH live under TRAIN_DIR. ImageFolder treats
# every sub-directory of its root as a class, so a `resNet` folder inside TRAIN_DIR
# would become an extra class (or make ImageFolder raise if it holds no images).
# Keep weights/checkpoints outside the dataset root.
WEIGHT_FILE = '/train/resNet/resnet50_scratch_weight.pkl'
TRAIN_DIR = '/train'
VAL_DIR = '/val'
TEST_DIR = '/test'
BEST_MODEL_PATH = '/train/resNet/resnet_s_best.pkl'


def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 convolution with padding 1."""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    """Two 3x3 conv/BN layers with an identity (or projected) shortcut."""

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        # Optional module applied to the input so the shortcut matches the output shape.
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)
        return out


class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 residual block whose output has ``planes * 4`` channels.

    The spatial stride is applied by the first 1x1 convolution. This differs from
    torchvision's ResNet-50, which strides the 3x3 convolution. It is presumably
    chosen to match the Caffe-converted weights read by ``load_state_dict`` --
    confirm.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)
        return out


class ResNet(nn.Module):
    """ResNet with an unpadded, ceil-mode max-pool stem and a fixed 7x7 average pool.

    Args:
        block: residual block class (``BasicBlock`` or ``Bottleneck``).
        layers: number of blocks in each of the four stages.
        num_classes: output size of the final fully connected layer.
        include_top: if False, ``forward`` returns the pooled feature map
            without flattening or applying ``fc``.
    """

    def __init__(self, block, layers, num_classes=15, include_top=True):
        super().__init__()
        self.inplanes = 64  # channel count fed into the next stage; updated by _make_layer
        self.include_top = include_top

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0, ceil_mode=True)

        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        # For 224x224 inputs the final feature map is 7x7, so this pools it to 1x1.
        # Other input sizes will give a different output size.
        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # He initialisation: std = sqrt(2 / fan_out), fan_out = k_h * k_w * out_channels.
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Stack ``blocks`` residual blocks.

        A projection shortcut is added to the first block when the stride or
        the channel count changes.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        if not self.include_top:
            return x

        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x


def resnet50(**kwargs):
    """Construct a ResNet-50 (Bottleneck, [3, 4, 6, 3]); kwargs are passed to ``ResNet``."""
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    return model


def load_state_dict(model, fname):
    """Copy parameters from a pickled checkpoint into *model*, in place.

    Intended for the Caffe-converted parameters provided by the VGGFace2 authors;
    see https://www.robots.ox.ac.uk/~vgg/data/vgg_face2/.

    Args:
        model: the module to update. Values are copied into the tensors returned
            by ``model.state_dict()``, which share storage with the model.
        fname: path to a pickle file holding a dict of name -> numpy array.

    Raises:
        KeyError: the checkpoint contains a name the model has no entry for.
        RuntimeError: a checkpoint array could not be copied, typically because
            of a shape mismatch. The original error is chained.
    """
    # SECURITY: pickle.load can execute arbitrary code -- only load trusted files.
    with open(fname, 'rb') as f:
        weights = pickle.load(f, encoding='latin1')

    own_state = model.state_dict()
    for name, param in weights.items():
        if name not in own_state:
            raise KeyError('unexpected key "{}" in state_dict'.format(name))
        try:
            own_state[name].copy_(torch.from_numpy(param))
        except Exception as err:
            # np.shape works on ndarrays: their .size is an int attribute, not a
            # method, so calling param.size() here would itself raise TypeError.
            raise RuntimeError(
                'While copying the parameter named {}, whose dimensions in the model are {} and whose '
                'dimensions in the checkpoint are {}.'.format(name, own_state[name].size(), np.shape(param))
            ) from err


def _build_finetune_model(weight_file, num_classes):
    """Load the pretrained ResNet-50 checkpoint and swap in a fresh ``num_classes`` head."""
    # NOTE(review): this uses torchvision's resnet50, not this file's `resnet50`, whose
    # stride placement and max-pool settings differ. Parameter shapes match, so loading
    # succeeds, but the forward pass is not the network the checkpoint was trained with.
    # Confirm which architecture is intended.
    model = models.resnet50()
    model.avgpool = nn.AvgPool2d(kernel_size=7, stride=1, padding=0)
    # The checkpoint's classifier has 8631 outputs (presumably the VGGFace2 training
    # identities). Build a matching head so every key loads, then replace it.
    model.fc = nn.Linear(2048, 8631)
    load_state_dict(model, weight_file)
    model.fc = nn.Linear(2048, num_classes)
    return model


def _make_loader(root, transform, batch_size, shuffle):
    """Return ``(dataset, loader)`` for an ImageFolder directory at *root*."""
    dataset = datasets.ImageFolder(root, transform=transform)
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
    return dataset, loader


def _train_one_epoch(model, loader, loss_func, optimizer, device):
    """Run one optimisation pass.

    Returns:
        ``(summed per-sample loss, number of correct top-1 predictions)``.
    """
    model.train()
    total_loss = 0.0
    correct = 0
    for batch_x, batch_y in loader:
        batch_x, batch_y = batch_x.to(device), batch_y.to(device)
        out = model(batch_x)
        loss = loss_func(out, batch_y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # CrossEntropyLoss averages over the batch; re-weight by batch size so the
        # epoch total divided by the dataset size is a true per-sample mean.
        total_loss += loss.item() * batch_x.size(0)
        correct += (out.argmax(dim=1) == batch_y).sum().item()
    return total_loss, correct


def _evaluate(model, loader, device):
    """Return the number of correct top-1 predictions on *loader*, without tracking gradients."""
    model.eval()
    correct = 0
    with torch.no_grad():
        for batch_x, batch_y in loader:
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
            pred = model(batch_x).argmax(dim=1)
            correct += (pred == batch_y).sum().item()
    return correct


def _test(model, loader, device, log_path='an_est2.txt'):
    """Count correct top-1 predictions and append each batch's predictions and labels to *log_path*."""
    model.eval()
    correct = 0
    with torch.no_grad(), open(log_path, 'a') as log:
        for batch_x, batch_y in loader:
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
            pred = model(batch_x).argmax(dim=1)
            correct += (pred == batch_y).sum().item()
            log.write(str(pred.cpu().numpy()) + '\n')
            log.write(str(batch_y.cpu().numpy()) + '\n')
    return correct


def main():
    """Fine-tune the pretrained ResNet-50 on TRAIN_DIR, select on VAL_DIR, report on TEST_DIR."""
    batch_size = 16
    num_epochs = 200
    learning_rate = 0.001

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model_ft = _build_finetune_model(WEIGHT_FILE, num_classes=15).to(device)

    train_transforms = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        # Replicate a single grey channel three times so the 3-channel stem accepts it.
        transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),  # no Normalize step: inputs stay in [0, 1]
    ])
    # Same as training, minus the random flip.
    val_transforms = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),
    ])

    train_datasets, train_dataloader = _make_loader(TRAIN_DIR, train_transforms, batch_size, shuffle=True)
    # Evaluation order does not affect accuracy, so the val/test loaders do not shuffle.
    val_datasets, val_dataloader = _make_loader(VAL_DIR, val_transforms, batch_size, shuffle=False)
    print(train_datasets.class_to_idx)

    optimizer = optim.SGD((p for p in model_ft.parameters() if p.requires_grad), lr=learning_rate)
    # Multiply the LR by 0.1 every 10 epochs.
    # NOTE(review): over 200 epochs this drives the LR down to about 1e-22, so
    # training effectively stops after a few decays -- consider a larger step_size.
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
    loss_func = nn.CrossEntropyLoss()
    max_acc = 0

    for epoch in range(num_epochs):
        print('epoch= ', epoch)

        train_loss, train_correct = _train_one_epoch(model_ft, train_dataloader, loss_func, optimizer, device)
        n_train = len(train_datasets)
        with open('s_loss.txt', 'a') as f:
            f.write(f"{epoch}  {train_loss / n_train}    {train_correct / n_train * 100}\n")

        eval_correct = _evaluate(model_ft, val_dataloader, device)
        eval_acc = eval_correct / len(val_datasets) * 100
        print('eval_acc ', eval_acc)
        with open('s_eval_acc.txt', 'a') as f:
            f.write(f"{eval_correct}  {eval_acc}\n")
        if eval_acc > max_acc:
            max_acc = eval_acc
            torch.save(model_ft.state_dict(), BEST_MODEL_PATH)

        scheduler.step()

    # NOTE(review): this tests the last-epoch weights, not the best checkpoint saved to
    # BEST_MODEL_PATH. Reload that checkpoint first if the best-validation model should
    # be the one reported.
    test_datasets, test_dataloader = _make_loader(TEST_DIR, val_transforms, batch_size, shuffle=False)
    test_acc = _test(model_ft, test_dataloader, device)
    leng = len(test_datasets)
    tmpp = test_acc / leng * 100
    print('eval_acc2 ', tmpp, test_acc, leng)
    with open('s_eval_test2.txt', 'a') as f:
        f.write(f"{test_acc}  {tmpp}\n")


if __name__ == '__main__':
    main()

转载地址:http://qkgsi.baihongyu.com/

你可能感兴趣的文章
c# ==与equals有什么区别
查看>>
Golang面试考题记录 ━━ 两数之和 ,能一遍循环就一遍循环
查看>>
Golang面试考题记录 ━━ 旋转图像~~二维数组旋转90度
查看>>
Golang面试考题记录 ━━ 有效的数独,没发现什么特别好的算法,就是暴力,结果也差不多
查看>>
Golang面试考题记录 ━━ 反转字符串,一种思路几种细节的不同结果
查看>>
Golang面试考题记录 ━━ 整数反转 解答及扩展的三个知识点
查看>>
Golang面试考题记录 ━━ 字符串中的第一个唯一字符 ,拓展:ASCII和strings字符串查找的用法
查看>>
Golang面试考题记录 ━━ 有效的字母异位词,久违的双100%,拓展reflect.DeepEqual()用法和[26]int{}的值
查看>>
Golang面试考题记录 ━━ 验证回文串,多种方法涉及双指针、strings、unicode和regexp
查看>>
Golang面试考题记录 ━━ 字符串转换整数 (atoi),知识点ascii、rune、uint8、string、char等转换
查看>>
Golang面试考题记录 ━━ 实现 strStr() 函数,截然不同三种方案,效率都差不多,双100%
查看>>
Golang面试考题记录 ━━ 外观数列 , 了解递归、bytes.Buffer和闭包
查看>>
学习日志 ━━ 理解递归(使用go语法举例)
查看>>
Golang面试考题记录 ━━ 最长公共前缀,字符串就是切片,复习[]byte、[]rune、[]uint8、[]int32和单引号
查看>>
Golang学习日志 ━━ 单向链表
查看>>
Golang面试考题记录 ━━ 删除链表中的节点,首先明白什么是链表,其次语文要好能看懂题
查看>>
买股票就是为卖好价钱 十种不应下单的情况
查看>>
用.NET建立Office Add-in
查看>>
数码相片冲印尺寸对照表
查看>>
用Photoshop制作1寸和2寸的照片
查看>>