-
Notifications
You must be signed in to change notification settings - Fork 40
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Updating main script. Needs further testing
- Loading branch information
Showing
6 changed files
with
167 additions
and
268 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,78 +1,65 @@ | ||
-------------------------------------------------------------------------------- | ||
-- Prepare data model | ||
-------------------------------------------------------------------------------- | ||
paths.mkdir(opt.save) | ||
|
||
trainCache = paths.concat(opt.save_base,'trainCache.t7') | ||
testCache = paths.concat(opt.save_base,'testCache.t7') | ||
local trainCache = paths.concat(rundir,'trainCache.t7') | ||
--testCache = paths.concat(opt.save_base,'testCache.t7') | ||
|
||
local pooler | ||
local feat_dim | ||
--[[ | ||
if opt.algo == 'SPP' then | ||
local conv_list = features:findModules(opt.backend..'.SpatialConvolution') | ||
local num_chns = conv_list[#conv_list].nOutputPlane | ||
pooler = model:get(2):clone():float() | ||
local pyr = torch.Tensor(pooler.pyr):t() | ||
local pooled_size = pyr[1]:dot(pyr[2]) | ||
feat_dim = {num_chns*pooled_size} | ||
elseif opt.algo == 'RCNN' then | ||
feat_dim = {3,227,227} | ||
end | ||
--]] | ||
|
||
image_transformer = nnf.ImageTransformer{mean_pix=image_mean} | ||
local config = paths.dofile('config.lua') | ||
|
||
image_transformer = nnf.ImageTransformer(config.image_transformer_params) | ||
|
||
local FP = nnf[opt.algo] | ||
local fp_params = config.algo[opt.algo].fp_params | ||
local bp_params = config.algo[opt.algo].bp_params | ||
local BP = config.algo[opt.algo].bp | ||
|
||
if paths.filep(trainCache) then | ||
print('Loading train metadata from cache') | ||
batch_provider = torch.load(trainCache) | ||
feat_provider = batch_provider.feat_provider | ||
ds_train = feat_provider.dataset | ||
feat_provider.model = features | ||
else | ||
ds_train = nnf.DataSetPascal{image_set='trainval',classes=classes,year=opt.year, | ||
datadir=opt.datadir,roidbdir=opt.roidbdir} | ||
|
||
|
||
feat_provider = FP(fp_params) | ||
batch_provider = BP(bp_params) | ||
batch_provider:setupData() | ||
local train_params = config.train_params | ||
|
||
torch.save(trainCache,batch_provider) | ||
feat_provider.model = features | ||
-- add common parameters | ||
fp_params.image_transformer = image_transformer | ||
for k,v in pairs(train_params) do | ||
bp_params[k] = v | ||
end | ||
|
||
if paths.filep(testCache) then | ||
print('Loading test metadata from cache') | ||
batch_provider_test = torch.load(testCache) | ||
feat_provider_test = batch_provider_test.feat_provider | ||
ds_test = feat_provider_test.dataset | ||
feat_provider_test.model = features | ||
else | ||
ds_test = nnf.DataSetPascal{image_set='test',classes=classes,year=opt.year, | ||
datadir=opt.datadir,roidbdir=opt.roidbdir} | ||
------------------------------------------------------------------------------- | ||
-- Create structures | ||
-------------------------------------------------------------------------------- | ||
|
||
ds_train = nnf.DataSetPascal{ | ||
image_set='trainval', | ||
year=2007,--opt.year, | ||
datadir=config.datasetDir, | ||
roidbdir=config.roidbDir | ||
} | ||
|
||
feat_provider_test = FP(fp_params) | ||
-- disable flip ? | ||
bp_params.do_flip = false | ||
batch_provider_test = BP(bp_params) | ||
feat_provider = FP(fp_params) | ||
feat_provider:training() | ||
|
||
batch_provider_test:setupData() | ||
|
||
torch.save(testCache,batch_provider_test) | ||
feat_provider_test.model = features | ||
bp_params.dataset = ds_train | ||
bp_params.feat_provider = feat_provider | ||
batch_provider = BP(bp_params) | ||
|
||
if paths.filep(trainCache) then | ||
print('Loading train metadata from cache') | ||
local metadata = torch.load(trainCache) | ||
batch_provider.bboxes = metadata | ||
else | ||
batch_provider:setupData() | ||
torch.save(trainCache, batch_provider.bboxes) | ||
end | ||
|
||
-- compute feature cache | ||
-- test | ||
ds_test = nnf.DataSetPascal{ | ||
image_set='test', | ||
year=2007,--opt.year, | ||
datadir=config.datasetDir, | ||
roidbdir=config.roidbDir | ||
} | ||
|
||
features = nil | ||
model = nil | ||
-- only needed because of SPP | ||
-- could be the same as the one for training | ||
--feat_provider_test = FP(fp_params) | ||
--feat_provider_test:evaluate() | ||
|
||
collectgarbage() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,58 +1,55 @@ | ||
require 'nnf' | ||
require 'cunn' | ||
--require 'cunn' | ||
require 'optim' | ||
require 'trepl' | ||
|
||
local opts = paths.dofile('opts.lua') | ||
opt = opts.parse(arg) | ||
print(opt) | ||
|
||
if opt.seed ~= 0 then | ||
torch.manualSeed(opt.seed) | ||
cutorch.manualSeed(opt.seed) | ||
if opt.gpu > 0 then | ||
cutorch.manualSeed(opt.seed) | ||
end | ||
end | ||
|
||
cutorch.setDevice(opt.gpu) | ||
torch.setnumthreads(opt.numthreads) | ||
|
||
-------------------------------------------------------------------------------- | ||
-- Select target classes | ||
-------------------------------------------------------------------------------- | ||
|
||
if opt.classes == 'all' then | ||
classes={'aeroplane','bicycle','bird','boat','bottle','bus','car', | ||
'cat','chair','cow','diningtable','dog','horse','motorbike', | ||
'person','pottedplant','sheep','sofa','train','tvmonitor'} | ||
local tensor_type | ||
if opt.gpu > 0 then | ||
require 'cunn' | ||
cutorch.setDevice(opt.gpu) | ||
tensor_type = 'torch.CudaTensor' | ||
print('Using GPU mode on device '..opt.gpu) | ||
else | ||
classes = {opt.classes} | ||
require 'nn' | ||
tensor_type = 'torch.FloatTensor' | ||
print('Using CPU mode') | ||
end | ||
|
||
-------------------------------------------------------------------------------- | ||
|
||
model, criterion = paths.dofile('model.lua') | ||
model:type(tensor_type) | ||
criterion:type(tensor_type) | ||
|
||
paths.dofile('model.lua') | ||
-- prepate training and test data | ||
paths.dofile('data.lua') | ||
|
||
-------------------------------------------------------------------------------- | ||
-- Prepare training model | ||
-------------------------------------------------------------------------------- | ||
-- Do training | ||
paths.dofile('train.lua') | ||
|
||
ds_train.roidb = nil | ||
collectgarbage() | ||
collectgarbage() | ||
|
||
-------------------------------------------------------------------------------- | ||
-- Do full evaluation | ||
-------------------------------------------------------------------------------- | ||
|
||
print('==> Evaluation') | ||
if opt.algo == 'FRCNN' then | ||
tester = nnf.Tester_FRCNN(model,feat_provider_test) | ||
else | ||
tester = nnf.Tester(classifier,feat_provider_test) | ||
end | ||
tester.cachefolder = paths.concat(opt.save,'evaluation',ds_test.dataset_name) | ||
-- evaluation | ||
print('==> Evaluating') | ||
-- add softmax to classifier, because we were using nn.CrossEntropyCriterion | ||
local softmax = nn.SoftMax() | ||
softmax:type(tensor_type) | ||
model:add(softmax) | ||
|
||
feat_provider:evaluate() | ||
|
||
-- define the class to test the model on the full dataset | ||
tester = nnf.Tester(model, feat_provider, ds_test) | ||
tester.cachefolder = rundir | ||
tester:test(opt.num_iter) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,61 +1,26 @@ | ||
require 'nn' | ||
require 'inn' | ||
require 'cudnn' | ||
local reshapeLastLinearLayer = paths.dofile('utils.lua').reshapeLastLinearLayer | ||
local convertCaffeModelToTorch = paths.dofile('utils.lua').convertCaffeModelToTorch | ||
--require 'inn' | ||
--require 'cudnn' | ||
|
||
-- 1.1. Create Network | ||
local config = opt.netType | ||
local createModel = paths.dofile('models/' .. config .. '.lua') | ||
print('=> Creating model from file: models/' .. config .. '.lua') | ||
model = createModel(opt.backend) | ||
local createModel = paths.dofile('models/' .. opt.netType .. '.lua') | ||
print('=> Creating model from file: models/' .. opt.netType .. '.lua') | ||
local model = createModel() | ||
|
||
-- convert to accept inputs in the range 0-1 RGB format | ||
convertCaffeModelToTorch(model,{1,1}) | ||
local criterion = nn.CrossEntropyCriterion() | ||
|
||
reshapeLastLinearLayer(model,#classes+1) | ||
image_mean = {128/255,128/255,128/255} | ||
|
||
if opt.algo == 'RCNN' then | ||
classifier = model | ||
elseif opt.algo == 'SPP' then | ||
features = model:get(1) | ||
classifier = model:get(3) | ||
elseif opt.algo == 'FRCNN' then | ||
local temp = nn.Sequential() | ||
local features = model:get(1) | ||
local classifier = model:get(3) | ||
local prl = nn.ParallelTable() | ||
prl:add(features) | ||
prl:add(nn.Identity()) | ||
temp:add(prl) | ||
temp:add(nnf.ROIPooling(7,7)) | ||
temp:add(nn.View(-1):setNumInputDims(3)) | ||
temp:add(classifier) | ||
end | ||
|
||
-- 2. Create Criterion | ||
criterion = nn.CrossEntropyCriterion() | ||
|
||
print('=> Model') | ||
print('Model:') | ||
print(model) | ||
|
||
print('=> Criterion') | ||
print('Criterion:') | ||
print(criterion) | ||
|
||
-- 3. If preloading option is set, preload weights from existing models appropriately | ||
-- If preloading option is set, preload weights from existing models appropriately | ||
if opt.retrain ~= 'none' then | ||
assert(paths.filep(opt.retrain), 'File not found: ' .. opt.retrain) | ||
print('Loading model from file: ' .. opt.retrain); | ||
classifier = torch.load(opt.retrain) | ||
model = torch.load(opt.retrain) | ||
end | ||
|
||
-- 4. Convert model to CUDA | ||
print('==> Converting model to CUDA') | ||
model = model:cuda() | ||
criterion:cuda() | ||
|
||
collectgarbage() | ||
|
||
|
||
return model, criterion | ||
|
Oops, something went wrong.