-
Notifications
You must be signed in to change notification settings - Fork 0
/
predict.py
51 lines (42 loc) · 1.85 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import torch
from torchvision import transforms
from torch.utils.data import DataLoader
from torch.autograd import Variable
import numpy as np
import pandas as pd
import time
from tqdm import tqdm
from utils import spilt_train_vaild_test,get_models
from dataset import ImageDataSet
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Both sides of the network input are 320 pixels.
image_width = image_height = 320

# Per-channel mean / standard deviation handed to transforms.Normalize.
norm_mean = (0.63790344, 0.56811579, 0.5704457)
norm_std = (0.24307405, 0.2520139, 0.25256122)
# Split the data; the loaders below are built from `test` and `vaild`.
train, vaild, test = spilt_train_vaild_test(
    contains_chusai_test=True,
    fusai=True,
)

# Evaluation-time preprocessing: resize, convert to a tensor, normalise.
# NOTE(review): torchvision's Resize reads a 2-tuple as (height, width).
# image_width and image_height are equal in this file, so the argument order
# has no effect today, but it would matter for non-square inputs.
test_transform = transforms.Compose(
    [
        transforms.Resize((image_width, image_height)),
        transforms.ToTensor(),
        transforms.Normalize(norm_mean, norm_std),
    ]
)


def _sequential_loader(samples):
    """Wrap *samples* in an unshuffled, one-image-per-batch DataLoader."""
    return DataLoader(
        ImageDataSet(samples, test_transform),
        batch_size=1,
        shuffle=False,
        num_workers=32,
    )


test_dataloader = _sequential_loader(test)
vaild_dataloader = _sequential_loader(vaild)
# Load the ensemble members. The inference loop later unpacks every value of
# this mapping as a (model, acc) pair.
models_mapping = get_models()

# Report how many models were loaded ("模型数量") and list their names
# ("模型如下"), joined with the Chinese enumeration comma.
model_count = len(models_mapping)
model_names = '、'.join(models_mapping.keys())
print('模型数量:', model_count)
print('模型如下:', model_names)

# Brief pause before the progress bar starts — presumably so the two lines
# above are visible before tqdm begins drawing; confirm before removing.
time.sleep(0.9)
# Ensemble inference over the test set.
#
# Every model scores each test image; the scores are combined as a weighted
# average, using the `acc` value stored next to each model as its weight, and
# the arg-max index is recorded as the image's category. The predictions are
# finally written to submit.csv with the columns image_id and category_id.
if not models_mapping:
    # Without this guard the weighted average below divides by zero, turns
    # into NaN, and every image silently receives category 0.
    raise RuntimeError('get_models() returned no models; nothing to predict with')

# Put every model into evaluation mode once, up front, instead of once per
# image inside the prediction loop.
# NOTE(review): assumes get_models() already placed the models on `device`.
for model, _ in models_mapping.values():
    model.eval()

# The normalising weight is identical for every image, so compute it once.
total_acc_all = sum(acc for _, acc in models_mapping.values())

result_all = []
with torch.no_grad():
    for images, _labels, _orders, image_id in tqdm(test_dataloader):
        # Reshape and move the input to the target device once per image;
        # all models consume the same tensor.
        batch = images.reshape(-1, 3, image_width, image_height).to(device)

        # One accumulator slot per model output score; 137 is presumably the
        # number of categories — confirm against the model heads.
        pred_array_all = np.zeros(137, dtype=np.float32)
        for model, acc in models_mapping.values():
            predict_label = model(batch)
            # Tensors are used directly: the Variable wrapper and `.data`
            # have been unnecessary since PyTorch 0.4.
            predict = predict_label.detach().cpu().numpy().reshape(-1) * acc
            pred_array_all += predict

        pred_all = np.argmax(pred_array_all / total_acc_all)
        # test_dataloader yields one image per batch, hence index 0.
        result_all.append({'image_id': image_id[0], 'category_id': pred_all})

submit = pd.DataFrame(result_all)
submit.to_csv('submit.csv', index=False)