這篇文章將為大家詳細講解有關 pytorch 如何實現 inception_v3，小編覺得挺實用的，因此分享給大家做個參考，希望大家閱讀完這篇文章後可以有所收獲。
如下所示:
from __future__ import print_function
from __future__ import division

import argparse
import copy
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets, models, transforms

# Report the library versions in use; handy when comparing training runs.
print("PyTorch Version: ", torch.__version__)
print("Torchvision Version: ", torchvision.__version__)

# Top level data directory. Here we assume the format of the directory conforms
# to the ImageFolder structure
數據集路徑：路徑下的數據集分為訓練集和測試集，也就是 train 以及 val；train 下分為兩類數據 1、2，val 集同理。
data_dir = "/home/dell/Desktop/data/切割圖像"

# Models to choose from [resnet, alexnet, vgg, squeezenet, densenet, inception]
model_name = "inception"

# Number of classes in the dataset (two classes of data, labelled 1 and 2)
num_classes = 2

# Batch size for training (change depending on how much memory you have);
# pick a suitable value, otherwise training runs out of memory.
batch_size = 32

# Number of epochs to train for
num_epochs = 1000

# Flag for feature extracting. When False, we finetune the whole model,
# when True we only update the reshaped layer params
feature_extract = True

# Command-line options, so that the paths below can be overridden from the
# shell in the usual Linux command-line style.
parser = argparse.ArgumentParser(description='PyTorch inception')
# Folder where results (checkpoints) are saved.
parser.add_argument('--outf', default='/home/dell/Desktop/dj/inception/',
                    help='folder to output images and model checkpoints')
# Path of the model to load when resuming training.
parser.add_argument('--net', default='/home/dell/Desktop/dj/inception/inception.pth',
                    help="path to net (to continue training)")
args = parser.parse_args()
訓練函數
def train_model(model, dataloaders, criterion, optimizer, num_epochs=25, is_inception=False):
    """Train ``model``, validating after every epoch, and keep the best weights.

    Relies on two module-level names that must exist before the call:
    ``device`` (batches are moved to it) and ``args.outf`` (checkpoint folder).

    Args:
        model: network to optimise; it is updated in place.
        dataloaders: mapping with ``'train'`` and ``'val'`` entries, each an
            iterable of ``(inputs, labels)`` batches exposing ``.dataset``.
        criterion: loss callable used as ``criterion(outputs, labels)``.
        optimizer: optimizer that steps the trainable parameters.
        num_epochs: number of epochs; each runs a train and a val phase.
        is_inception: when True, the training-phase forward pass is unpacked
            as ``(outputs, aux_outputs)`` and the auxiliary loss is added
            with weight 0.4.

    Returns:
        ``(model, val_acc_history)``: the model reloaded with the weights of
        the best validation epoch, and the list of per-epoch validation
        accuracies (fractions in [0, 1]).

    Side effects:
        Writes ``acc.txt``, ``log.txt`` and ``best_acc.txt`` in the current
        working directory and saves a checkpoint into ``args.outf`` every
        50 epochs.
    """
    since = time.time()

    val_acc_history = []

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    # Create the checkpoint folder up front so that a missing directory is
    # reported immediately instead of failing the first save at epoch 50.
    os.makedirs(args.outf, exist_ok=True)

    print("Start Training, InceptionV3!")
    with open("acc.txt", "w") as f1, open("log.txt", "w") as f2:
        for epoch in range(num_epochs):
            print('Epoch {}/{}'.format(epoch + 1, num_epochs))
            print('*' * 10)

            # Each epoch has a training and validation phase
            for phase in ['train', 'val']:
                if phase == 'train':
                    model.train()  # Set model to training mode
                else:
                    model.eval()   # Set model to evaluate mode

                running_loss = 0.0
                running_corrects = 0

                # Iterate over data.
                for inputs, labels in dataloaders[phase]:
                    inputs = inputs.to(device)
                    labels = labels.to(device)

                    # zero the parameter gradients
                    optimizer.zero_grad()

                    # forward
                    # track history if only in train
                    with torch.set_grad_enabled(phase == 'train'):
                        if is_inception and phase == 'train':
                            # From https://discuss.pytorch.org/t/how-to-optimize-inception-model-with-auxiliary-classifiers/7958
                            outputs, aux_outputs = model(inputs)
                            loss1 = criterion(outputs, labels)
                            loss2 = criterion(aux_outputs, labels)
                            loss = loss1 + 0.4 * loss2
                        else:
                            outputs = model(inputs)
                            loss = criterion(outputs, labels)

                        _, preds = torch.max(outputs, 1)

                        # backward + optimize only if in training phase
                        if phase == 'train':
                            loss.backward()
                            optimizer.step()

                    # statistics (loss is a batch mean, so weight it by batch size)
                    running_loss += loss.item() * inputs.size(0)
                    running_corrects += torch.sum(preds == labels)

                dataset_size = len(dataloaders[phase].dataset)
                epoch_loss = running_loss / dataset_size
                epoch_acc = running_corrects.double() / dataset_size

                summary = '{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc)
                print(summary)
                f2.write(summary)
                f2.write('\n')
                f2.flush()

                if phase == 'val':
                    # Periodic checkpoint of the current weights.
                    if (epoch + 1) % 50 == 0:
                        torch.save(
                            model.state_dict(),
                            os.path.join(args.outf, 'inception_%03d.pth' % (epoch + 1)),
                        )
                    # NOTE(review): nesting reconstructed from a collapsed
                    # source; accuracy is logged once per validation epoch.
                    # epoch_acc is a fraction, so scale it to match the
                    # "%" suffix of the format string.
                    f1.write("EPOCH=%03d,Accuracy= %.3f%%" % (epoch + 1, 100.0 * float(epoch_acc)))
                    f1.write('\n')
                    f1.flush()

                # deep copy the model whenever validation accuracy improves
                if phase == 'val' and epoch_acc > best_acc:
                    with open("best_acc.txt", "w") as f3:
                        f3.write("EPOCH=%d,best_acc= %.3f%%" % (epoch + 1, 100.0 * float(epoch_acc)))
                    best_acc = epoch_acc
                    best_model_wts = copy.deepcopy(model.state_dict())
                if phase == 'val':
                    val_acc_history.append(epoch_acc)

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model, val_acc_history


def set_parameter_requires_grad(model, feature_extracting):
    """Freeze all parameters of ``model`` when ``feature_extracting`` is truthy.

    When ``feature_extracting`` is falsy the model is left untouched, so
    every parameter keeps its current ``requires_grad`` setting.
    """
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False


def initialize_model(model_name, num_classes, feature_extract, use_pretrained=True):
    """Build a torchvision model whose classifier outputs ``num_classes`` logits.

    Args:
        model_name: one of ``"resnet"``, ``"alexnet"``, ``"vgg"``,
            ``"squeezenet"``, ``"densenet"`` or ``"inception"``.
        num_classes: size of the replacement output layer.
        feature_extract: when truthy, the loaded parameters are frozen before
            the new output layer is attached, so only that layer trains.
        use_pretrained: forwarded to torchvision as ``pretrained=``.

    Returns:
        ``(model_ft, input_size)``: the model and the input image side length
        used for it (224, or 299 for inception).

    Raises:
        SystemExit: if ``model_name`` is not one of the supported names.
    """
    # Initialize these variables which will be set in this if statement. Each of these
    # variables is model specific.
    model_ft = None
    input_size = 0

    if model_name == "resnet":
        # Resnet18
        model_ft = models.resnet18(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    elif model_name == "alexnet":
        # Alexnet
        model_ft = models.alexnet(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    elif model_name == "vgg":
        # VGG11_bn
        model_ft = models.vgg11_bn(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    elif model_name == "squeezenet":
        # Squeezenet: the classifier is a 1x1 convolution, not a Linear layer.
        model_ft = models.squeezenet1_0(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1))
        model_ft.num_classes = num_classes
        input_size = 224

    elif model_name == "densenet":
        # Densenet
        model_ft = models.densenet121(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier.in_features
        model_ft.classifier = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    elif model_name == "inception":
        # Inception v3
        # Be careful, expects (299,299) sized images and has auxiliary output
        model_ft = models.inception_v3(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        # Handle the auxiliary net
        num_ftrs = model_ft.AuxLogits.fc.in_features
        model_ft.AuxLogits.fc = nn.Linear(num_ftrs, num_classes)
        # Handle the primary net
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, num_classes)
        input_size = 299

    else:
        print("Invalid model name, exiting...")
        # Same effect as the interactive-only exit() helper, without relying
        # on the site module having injected it.
        raise SystemExit

    return model_ft, input_size


# Initialize the model for this run
model_ft, input_size = initialize_model(model_name, num_classes, feature_extract, use_pretrained=True)

# Print the model we just instantiated
#print(model_ft)

# Prepare the data: random crop + flip for training, deterministic
# resize + center crop for validation, both with the same normalization.
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(input_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(input_size),
        transforms.CenterCrop(input_size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

print("Initializing Datasets and Dataloaders...")

# Create training and validation datasets
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
                  for x in ['train', 'val']}
# Create training and validation dataloaders
dataloaders_dict = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size,
                                                   shuffle=True, num_workers=0)
                    for x in ['train', 'val']}

# Detect if we have a GPU available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Disabled: optionally load a previously trained model before continuing.
# we='/home/dell/Desktop/dj/inception_050.pth'
# model_ft.load_state_dict(torch.load(we))

# Send the model to GPU
model_ft = model_ft.to(device)

# When feature extracting, only parameters that still require gradients (the
# freshly created output layers) are handed to the optimizer; otherwise all
# model parameters are.
params_to_update = model_ft.parameters()
print("Params to learn:")
if feature_extract:
    params_to_update = []
    for name, param in model_ft.named_parameters():
        if param.requires_grad:
            params_to_update.append(param)
            print("\t", name)
else:
    for name, param in model_ft.named_parameters():
        if param.requires_grad:
            print("\t", name)

# Optimize the parameters selected above
optimizer_ft = optim.SGD(params_to_update, lr=0.001, momentum=0.9)

# Disabled LR schedule (would multiply the LR by 0.95 every 30 epochs):
#exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=30, gamma=0.95)

# Setup the loss fxn
criterion = nn.CrossEntropyLoss()

# Train and evaluate
model_ft, hist = train_model(model_ft, dataloaders_dict, criterion, optimizer_ft,
                             num_epochs=num_epochs, is_inception=(model_name == "inception"))

# ---------------------------------------------------------------------------
# Disabled: training from random initialization, for comparison.
# ---------------------------------------------------------------------------
# # Initialize the non-pretrained version of the model used for this run
# scratch_model,_ = initialize_model(model_name, num_classes, feature_extract=False, use_pretrained=False)
# scratch_model = scratch_model.to(device)
# scratch_optimizer = optim.SGD(scratch_model.parameters(), lr=0.001, momentum=0.9)
# scratch_criterion = nn.CrossEntropyLoss()
# _,scratch_hist = train_model(scratch_model, dataloaders_dict, scratch_criterion, scratch_optimizer, num_epochs=num_epochs, is_inception=(model_name=="inception"))
#
# # Plot the training curves of validation accuracy vs. number
# # of training epochs for the transfer learning method and
# # the model trained from scratch
# ohist = []
# shist = []
#
# ohist = [h.cpu().numpy() for h in hist]
# shist = [h.cpu().numpy() for h in scratch_hist]
#
# plt.title("Validation Accuracy vs. Number of Training Epochs")
# plt.xlabel("Training Epochs")
# plt.ylabel("Validation Accuracy")
# plt.plot(range(1,num_epochs+1),ohist,label="Pretrained")
# plt.plot(range(1,num_epochs+1),shist,label="Scratch")
# plt.ylim((0,1.))
# plt.xticks(np.arange(1, num_epochs+1, 1.0))
# plt.legend()
# plt.show()
1. PyTorch 是相當簡潔且高效快速的框架；2. 設計追求最少的封裝；3. 設計符合人類思維，它讓用戶盡可能地專注於實現自己的想法；4. 與 Google 的 TensorFlow 類似，FAIR 的支持足以確保 PyTorch 獲得持續的開發更新；5. PyTorch 作者親自維護的論壇，供用戶交流和求教問題；6. 入門簡單。
關於“pytorch 如何實現 inception_v3”這篇文章就分享到這裡了，希望以上內容可以對大家有一定的幫助，使各位可以學到更多知識，如果覺得文章不錯，請把它分享出去讓更多的人看到。
另外有需要云服務(wù)器可以了解下創(chuàng)新互聯(lián)scvps.cn,海內(nèi)外云服務(wù)器15元起步,三天無(wú)理由+7*72小時(shí)售后在線,公司持有idc許可證,提供“云服務(wù)器、裸金屬服務(wù)器、高防服務(wù)器、香港服務(wù)器、美國(guó)服務(wù)器、虛擬主機(jī)、免備案服務(wù)器”等云主機(jī)租用服務(wù)以及企業(yè)上云的綜合解決方案,具有“安全穩(wěn)定、簡(jiǎn)單易用、服務(wù)可用性高、性價(jià)比高”等特點(diǎn)與優(yōu)勢(shì),專為企業(yè)上云打造定制,能夠滿足用戶豐富、多元化的應(yīng)用場(chǎng)景需求。
分享文章:pytorch如何實(shí)現(xiàn)inception_v3-創(chuàng)新互聯(lián)
文章起源:http://www.rwnh.cn/article14/jogde.html
成都網(wǎng)站建設(shè)公司_創(chuàng)新互聯(lián),為您提供做網(wǎng)站、企業(yè)網(wǎng)站制作、品牌網(wǎng)站設(shè)計(jì)、定制開(kāi)發(fā)、App設(shè)計(jì)、響應(yīng)式網(wǎng)站
聲明:本網(wǎng)站發(fā)布的內(nèi)容(圖片、視頻和文字)以用戶投稿、用戶轉(zhuǎn)載內(nèi)容為主,如果涉及侵權(quán)請(qǐng)盡快告知,我們將會(huì)在第一時(shí)間刪除。文章觀點(diǎn)不代表本網(wǎng)站立場(chǎng),如需處理請(qǐng)聯(lián)系客服。電話:028-86922220;郵箱:631063699@qq.com。內(nèi)容未經(jīng)允許不得轉(zhuǎn)載,或轉(zhuǎn)載時(shí)需注明來(lái)源: 創(chuàng)新互聯(lián)
猜你還喜歡下面的內(nèi)容