반응형
1.학습된 모델의 가중치를 불러와서 weights변수에 넣어줍니다.
import torch
state_dict = torch.load('resnet50.pth')
best_epoch = state_dict['epoch']
best_test_acc = state_dict['test_acc']
weights = state_dict['net']
print(f'최종적으로 {best_epoch}번째 에포크에서 test셋 기준으로 {best_test_acc}를 달성하였습니다.')#변수로 바꾸기기
2. 학습한 가중치를 모델에 불러와줍니다.(finetuning)
import torch
device = 'cuda' if torch.cuda.is_available() else 'cpu' # device 배정
torch.manual_seed(42)
if device == 'cuda':
torch.cuda.manual_seed_all(42)
device
from torchvision import models # 모델 라이브러리 함수
resnet_50 = models.resnet50(pretrained=False).to(device) # 선행학습 여부 , finetunig한 부분이 있으니까까
# finetuning
import torch.nn as nn # 파이토치 뉴럴네트워크 layer 라이브러리
resnet_50.fc = nn.Linear(resnet_50.fc.in_features, 3).to(device)
#학습한 가중치 적용완료료
resnet_50.load_state_dict(weights)
3. test셋의 최종 성능 확인하기
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from dataset import Custom_dataset as C
root_path = '/content/drive/MyDrive/Colab Notebooks/dna/week6/original-1'
test_transforms = A.Compose([
A.Resize(224,224),
A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), max_pixel_value=255.0, always_apply=False, p=1.0), # 텐서타입은 안해줌
ToTensorV2() # Normalize를 먼저하고 tensor화를 진행해야한다.
])
test_class = C(root_path=root_path, mode='test', transforms=test_transforms)
test_loader = DataLoader(test_class, batch_size=4, shuffle = False, num_workers=0)
model = resnet_50
model.eval()
test_loss = 0
correct = 0
criterion = nn.CrossEntropyLoss(reduction='sum') #add all samples in a mini-batch
with torch.no_grad():
for test_img, test_label in test_loader:
test_img, test_label = test_img.to(device), test_label.to(device)
output = resnet_50(test_img) #모델에 입력
loss = criterion(output, test_label)
test_loss += loss.item()
pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
correct += pred.eq(test_label.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
test_loss, correct, len(test_loader.dataset),
100. * correct / len(test_loader.dataset)))
728x90
'computer vision > dna_study' 카테고리의 다른 글
sementic segmentation에서 multi class일때 cross entropy (0) | 2023.04.02 |
---|---|
pytorch 모델을 불러와서 학습시키고 모델 저장 (1) | 2023.03.17 |
pytorch) Augmentation이 적용된 이미지를 시각화하기 (0) | 2023.03.17 |
python albumentation 라이브러리 설명 (0) | 2023.03.17 |
pytorch 데이터셋 augmentation(transformer 적용) (0) | 2023.03.17 |
댓글