본문 바로가기
computer vision/dna_study

pytorch 저장한 모델을 불러오고 testset 확인하기

by dohunNewte 2023. 3. 17.
반응형
1.학습된 모델의 가중치를 불러와서 weights변수에 넣어줍니다.
 
import torch
state_dict = torch.load('resnet50.pth')
best_epoch = state_dict['epoch']
best_test_acc = state_dict['test_acc']
weights = state_dict['net']

print(f'최종적으로 {best_epoch}번째 에포크에서 test셋 기준으로 {best_test_acc}를 달성하였습니다.')#변수로 바꾸기기

 

2. 학습한 가중치를 모델에 불러와줍니다.(finetuning)

import torch
device = 'cuda' if torch.cuda.is_available() else 'cpu' # device 배정
torch.manual_seed(42)
if device == 'cuda':
  torch.cuda.manual_seed_all(42)
device
 
from torchvision import models # 모델 라이브러리 함수

resnet_50 = models.resnet50(pretrained=False).to(device) # 선행학습 여부 , finetunig한 부분이 있으니까까

# finetuning
import torch.nn as nn # 파이토치 뉴럴네트워크 layer 라이브러리
resnet_50.fc = nn.Linear(resnet_50.fc.in_features, 3).to(device)
#학습한 가중치 적용완료료
resnet_50.load_state_dict(weights)
 
3. test셋의 최종 성능 확인하기
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from dataset import Custom_dataset as C

root_path = '/content/drive/MyDrive/Colab Notebooks/dna/week6/original-1'

test_transforms = A.Compose([
    A.Resize(224,224),
    A.Normalize(mean=(0.4850.4560.406), std=(0.2290.2240.225), max_pixel_value=255.0, always_apply=False, p=1.0), # 텐서타입은 안해줌
    ToTensorV2() # Normalize를 먼저하고 tensor화를 진행해야한다.
])
test_class = C(root_path=root_path, mode='test', transforms=test_transforms)
test_loader = DataLoader(test_class, batch_size=4, shuffle = False, num_workers=0)

 

model = resnet_50

model.eval()
test_loss = 0
correct = 0
criterion =  nn.CrossEntropyLoss(reduction='sum'#add all samples in a mini-batch
with torch.no_grad():
  for test_img, test_label in test_loader:
    test_img, test_label = test_img.to(device), test_label.to(device)
    output = resnet_50(test_img) #모델에 입력
    loss = criterion(output, test_label)
    test_loss +=  loss.item()
    pred = output.argmax(dim=1, keepdim=True# get the index of the max log-probability
    correct += pred.eq(test_label.view_as(pred)).sum().item()

test_loss /= len(test_loader.dataset)

print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
  test_loss, correct, len(test_loader.dataset),
  100. * correct / len(test_loader.dataset)))
728x90

댓글