from functools import partial
from collections import OrderedDict
%config InlineBackend.figure_format = 'retina'
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torchvision as tv
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
import requests
import io
def get_weights(bit_variant):
    response = requests.get(f'https://storage.googleapis.com/bit_models/{bit_variant}.npz')
    response.raise_for_status()
    return np.load(io.BytesIO(response.content))
weights = get_weights('BiT-M-R50x1')
# You could use other variants, such as R101x3 or R152x4 here.
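The downloaded `.npz` archive is a flat mapping from TensorFlow variable names to arrays; the `load_from` methods defined below look weights up by exactly those names. A quick, illustrative way to inspect the naming scheme:

for k in weights.files[:5]:
    print(k, weights[k].shape)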
class StdConv2d(nn.Conv2d):
    def forward(self, x):
        w = self.weight
        # Standardize each output filter to zero mean / unit variance over its
        # input-channel and spatial dimensions (Weight Standardization).
        v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False)
        w = (w - m) / torch.sqrt(v + 1e-10)
        return F.conv2d(x, w, self.bias, self.stride, self.padding, self.dilation, self.groups)
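BiT replaces BatchNorm with GroupNorm plus Weight Standardization, which behaves identically at any batch size and so transfers cleanly across tasks. A small illustrative check that the math in `forward` really yields zero-mean, unit-variance filters:

_conv = StdConv2d(3, 8, kernel_size=3, padding=1)
_v, _m = torch.var_mean(_conv.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
_w = (_conv.weight - _m) / torch.sqrt(_v + 1e-10)
print(_w.mean(dim=[1, 2, 3]))                 # all ~0, one value per output filter
print(_w.var(dim=[1, 2, 3], unbiased=False))  # all ~1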
def conv3x3(cin, cout, stride=1, groups=1, bias=False):
    return StdConv2d(cin, cout, kernel_size=3, stride=stride, padding=1, bias=bias, groups=groups)

def conv1x1(cin, cout, stride=1, bias=False):
    return StdConv2d(cin, cout, kernel_size=1, stride=stride, padding=0, bias=bias)
def tf2th(conv_weights):
    """Possibly convert HWIO to OIHW."""
    if conv_weights.ndim == 4:
        conv_weights = np.transpose(conv_weights, [3, 2, 0, 1])
    return torch.from_numpy(conv_weights)
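TensorFlow stores conv kernels as HWIO (height, width, in-channels, out-channels) while PyTorch expects OIHW, hence the transpose. Illustrative shape check:

hwio = np.zeros((3, 3, 64, 128), dtype=np.float32)
print(tf2th(hwio).shape)  # torch.Size([128, 64, 3, 3])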
class PreActBottleneck(nn.Module):
    """
    Follows the implementation of "Identity Mappings in Deep Residual Networks" here:
    https://github.com/KaimingHe/resnet-1k-layers/blob/master/resnet-pre-act.lua

    Except it puts the stride on the 3x3 conv when available.
    """
    def __init__(self, cin, cout=None, cmid=None, stride=1):
        super().__init__()
        cout = cout or cin
        cmid = cmid or cout//4

        self.gn1 = nn.GroupNorm(32, cin)
        self.conv1 = conv1x1(cin, cmid)
        self.gn2 = nn.GroupNorm(32, cmid)
        self.conv2 = conv3x3(cmid, cmid, stride)  # Original ResNetv2 has it on conv1!!
        self.gn3 = nn.GroupNorm(32, cmid)
        self.conv3 = conv1x1(cmid, cout)
        self.relu = nn.ReLU(inplace=True)

        if (stride != 1 or cin != cout):
            # Projection also with pre-activation according to paper.
            self.downsample = conv1x1(cin, cout, stride)

    def forward(self, x):
        # Conv'ed branch
        out = self.relu(self.gn1(x))

        # Residual branch
        residual = x
        if hasattr(self, 'downsample'):
            residual = self.downsample(out)

        # The first block has already applied pre-act before splitting, see Appendix.
        out = self.conv1(out)
        out = self.conv2(self.relu(self.gn2(out)))
        out = self.conv3(self.relu(self.gn3(out)))

        return out + residual

    def load_from(self, weights, prefix=''):
        with torch.no_grad():
            self.conv1.weight.copy_(tf2th(weights[prefix + 'a/standardized_conv2d/kernel']))
            self.conv2.weight.copy_(tf2th(weights[prefix + 'b/standardized_conv2d/kernel']))
            self.conv3.weight.copy_(tf2th(weights[prefix + 'c/standardized_conv2d/kernel']))
            self.gn1.weight.copy_(tf2th(weights[prefix + 'a/group_norm/gamma']))
            self.gn2.weight.copy_(tf2th(weights[prefix + 'b/group_norm/gamma']))
            self.gn3.weight.copy_(tf2th(weights[prefix + 'c/group_norm/gamma']))
            self.gn1.bias.copy_(tf2th(weights[prefix + 'a/group_norm/beta']))
            self.gn2.bias.copy_(tf2th(weights[prefix + 'b/group_norm/beta']))
            self.gn3.bias.copy_(tf2th(weights[prefix + 'c/group_norm/beta']))
            if hasattr(self, 'downsample'):
                self.downsample.weight.copy_(tf2th(weights[prefix + 'a/proj/standardized_conv2d/kernel']))
        return self
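A small shape check (illustrative): a strided unit halves the spatial resolution, and the pre-activated projection shortcut fixes up both channels and stride.

_block = PreActBottleneck(cin=256, cout=512, cmid=128, stride=2)
print(_block(torch.randn(2, 256, 16, 16)).shape)  # torch.Size([2, 512, 8, 8])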
class ResNetV2(nn.Module):
    BLOCK_UNITS = {
        'r50': [3, 4, 6, 3],
        'r101': [3, 4, 23, 3],
        'r152': [3, 8, 36, 3],
    }

    def __init__(self, block_units, width_factor, head_size=21843, zero_head=False):
        super().__init__()
        wf = width_factor  # shortcut 'cause we'll use it a lot.

        self.root = nn.Sequential(OrderedDict([
            ('conv', StdConv2d(3, 64*wf, kernel_size=7, stride=2, padding=3, bias=False)),
            ('padp', nn.ConstantPad2d(1, 0)),
            ('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=0)),
            # The following is subtly not the same!
            # ('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
        ]))

        self.body = nn.Sequential(OrderedDict([
            ('block1', nn.Sequential(OrderedDict(
                [('unit01', PreActBottleneck(cin=64*wf, cout=256*wf, cmid=64*wf))] +
                [(f'unit{i:02d}', PreActBottleneck(cin=256*wf, cout=256*wf, cmid=64*wf)) for i in range(2, block_units[0] + 1)],
            ))),
            ('block2', nn.Sequential(OrderedDict(
                [('unit01', PreActBottleneck(cin=256*wf, cout=512*wf, cmid=128*wf, stride=2))] +
                [(f'unit{i:02d}', PreActBottleneck(cin=512*wf, cout=512*wf, cmid=128*wf)) for i in range(2, block_units[1] + 1)],
            ))),
            ('block3', nn.Sequential(OrderedDict(
                [('unit01', PreActBottleneck(cin=512*wf, cout=1024*wf, cmid=256*wf, stride=2))] +
                [(f'unit{i:02d}', PreActBottleneck(cin=1024*wf, cout=1024*wf, cmid=256*wf)) for i in range(2, block_units[2] + 1)],
            ))),
            ('block4', nn.Sequential(OrderedDict(
                [('unit01', PreActBottleneck(cin=1024*wf, cout=2048*wf, cmid=512*wf, stride=2))] +
                [(f'unit{i:02d}', PreActBottleneck(cin=2048*wf, cout=2048*wf, cmid=512*wf)) for i in range(2, block_units[3] + 1)],
            ))),
        ]))

        self.zero_head = zero_head
        self.head = nn.Sequential(OrderedDict([
            ('gn', nn.GroupNorm(32, 2048*wf)),
            ('relu', nn.ReLU(inplace=True)),
            ('avg', nn.AdaptiveAvgPool2d(output_size=1)),
            ('conv', nn.Conv2d(2048*wf, head_size, kernel_size=1, bias=True)),
        ]))

    def forward(self, x):
        x = self.head(self.body(self.root(x)))
        assert x.shape[-2:] == (1, 1)  # We should have no spatial shape left.
        return x[..., 0, 0]

    def load_from(self, weights, prefix='resnet/'):
        with torch.no_grad():
            self.root.conv.weight.copy_(tf2th(weights[f'{prefix}root_block/standardized_conv2d/kernel']))
            self.head.gn.weight.copy_(tf2th(weights[f'{prefix}group_norm/gamma']))
            self.head.gn.bias.copy_(tf2th(weights[f'{prefix}group_norm/beta']))
            if self.zero_head:
                nn.init.zeros_(self.head.conv.weight)
                nn.init.zeros_(self.head.conv.bias)
            else:
                self.head.conv.weight.copy_(tf2th(weights[f'{prefix}head/conv2d/kernel']))
                self.head.conv.bias.copy_(tf2th(weights[f'{prefix}head/conv2d/bias']))
            for bname, block in self.body.named_children():
                for uname, unit in block.named_children():
                    unit.load_from(weights, prefix=f'{prefix}{bname}/{uname}/')
        return self
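As an end-to-end sanity check (illustrative, with randomly initialized weights), a 128x128 input — the crop size used for CIFAR-10 below — flows through the whole network to one logit vector per image:

_net = ResNetV2(ResNetV2.BLOCK_UNITS['r50'], width_factor=1, head_size=10)
print(_net(torch.randn(2, 3, 128, 128)).shape)  # torch.Size([2, 10])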
from IPython.display import HTML, display
def progress(value, max=100):
    return HTML("""
        <progress value='{value}' max='{max}' style='width: 100%'>{value}</progress>
    """.format(value=value, max=max))
def stairs(s, v, *svs):
    """Implements a typical "stairs" schedule for learning-rates.

    Best explained by example:
        stairs(s, 0.1, 10, 0.01, 20, 0.001)
    will return 0.1 if s < 10, 0.01 if 10 <= s < 20, and 0.001 if 20 <= s.
    """
    for s0, v0 in zip(svs[::2], svs[1::2]):
        if s < s0:
            break
        v = v0
    return v
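The docstring's example, spelled out (illustrative):

for s in [0, 5, 10, 15, 20, 25]:
    print(s, stairs(s, 0.1, 10, 0.01, 20, 0.001))
# prints 0.1 for s < 10, 0.01 for 10 <= s < 20, and 0.001 from s = 20 on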
def rampup(s, peak_s, peak_lr):
    if s < peak_s:  # Warmup
        return s/peak_s * peak_lr
    else:
        return peak_lr

def schedule(s):
    # stairs() falls through to None after step 500; a training loop can use
    # that as a stop signal.
    step_lr = stairs(s, 3e-3, 200, 3e-4, 300, 3e-5, 400, 3e-6, 500, None)
    return rampup(s, 100, step_lr)
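A quick visual check of the resulting schedule (illustrative): linear warmup to 3e-3 over the first 100 steps, then steps down at 200/300/400.

_steps = range(1, 500)
plt.plot(_steps, [schedule(s) for s in _steps])
plt.xlabel('step'); plt.ylabel('learning rate'); plt.yscale('log')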
import PIL
preprocess_train = tv.transforms.Compose([
    tv.transforms.Resize((160, 160), interpolation=PIL.Image.BILINEAR),  # It's the default, just being explicit for the reader.
    tv.transforms.RandomCrop((128, 128)),
    tv.transforms.RandomHorizontalFlip(),
    tv.transforms.ToTensor(),
    tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # Get data into [-1, 1]
])

preprocess_eval = tv.transforms.Compose([
    tv.transforms.Resize((128, 128), interpolation=PIL.Image.BILINEAR),
    tv.transforms.ToTensor(),
    tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
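Illustrative check that the train transform produces 3x128x128 tensors in [-1, 1] (an all-black input maps to exactly -1 everywhere):

from PIL import Image
_img = Image.fromarray(np.zeros((32, 32, 3), dtype=np.uint8))
_t = preprocess_train(_img)
print(_t.shape, _t.min().item(), _t.max().item())  # torch.Size([3, 128, 128]) -1.0 -1.0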
trainset = tv.datasets.CIFAR10(root='./data', train=True, download=True, transform=preprocess_train)
testset = tv.datasets.CIFAR10(root='./data', train=False, download=True, transform=preprocess_eval)
weights_cifar10 = get_weights('BiT-M-R50x1-CIFAR10')
model = ResNetV2(ResNetV2.BLOCK_UNITS['r50'], width_factor=1, head_size=10) # NOTE: No new head.
model.load_from(weights_cifar10)
model.to(device);
def eval_cifar10(model, bs=100, progressbar=True):
    loader_test = torch.utils.data.DataLoader(testset, batch_size=bs, shuffle=False, num_workers=2)

    model.eval()
    if progressbar is True:
        progressbar = display(progress(0, len(loader_test)), display_id=True)

    preds = []
    with torch.no_grad():
        for i, (x, t) in enumerate(loader_test):
            x, t = x.to(device), t.numpy()
            logits = model(x)
            _, y = torch.max(logits.data, 1)
            preds.extend(y.cpu().numpy() == t)
            if progressbar:  # guard so progressbar=False doesn't crash
                progressbar.update(progress(i+1, len(loader_test)))
    return np.mean(preds)
print("Expected: 97.61%")
print(f"Accuracy: {eval_cifar10(model):.2%}")