add ap

commit 16bfc22fe7 (parent dd6e9320af)

89  src/active_grasp/active_perception.py  (new file)
@@ -0,0 +1,89 @@
import os
import sys
import numpy as np
import torch

# sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))


path = os.path.abspath(__file__)
for i in range(2):
    path = os.path.dirname(path)
PROJECT_ROOT = path
sys.path.append(PROJECT_ROOT)

from active_perception.configs.config import ConfigManager
from active_perception.modules.pipeline import Pipeline


class InferenceEngine():
    RESULTS_DIR_NAME: str = 'results'
    LOG_DIR_NAME: str = 'log'

    def __init__(self, config_path):
        ''' Config Manager '''
        ConfigManager.load_config_with(config_path)

        ''' Pytorch Seed '''
        seed = ConfigManager.get("settings", "general", "seed")
        np.random.seed(seed)
        torch.manual_seed(seed)

        ''' Pipeline '''
        # self.pipeline_config = {'pts_encoder': 'pointnet', 'view_finder': 'gradient_field'}
        self.pipeline_config = ConfigManager.get("settings", "pipeline")
        self.device = ConfigManager.get("settings", "general", "device")
        self.pipeline = Pipeline(self.pipeline_config)
        self.parallel = ConfigManager.get("settings", "general", "parallel")
        if self.parallel and self.device == "cuda":
            self.pipeline = torch.nn.DataParallel(self.pipeline)
        self.pipeline = self.pipeline.to(self.device)

        ''' Experiment '''
        # self.model_path = '~/Downloads/full_149_241009.pth'
        self.model_path = ConfigManager.get("settings", "experiment", "model_path")
        self.load(self.model_path)

    def load(self, path):
        # map_location keeps checkpoints loadable even when the configured
        # device differs from the one the checkpoint was saved on
        state_dict = torch.load(path, map_location=self.device)
        if self.parallel:
            self.pipeline.module.load_state_dict(state_dict)
        else:
            self.pipeline.load_state_dict(state_dict)

    def inference(self, data):
        self.pipeline.eval()
        with torch.no_grad():
            output = self.pipeline(data, Pipeline.TEST_MODE)
        return output


if __name__ == "__main__":
    ''' Load Configs '''
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str,
                        default=PROJECT_ROOT + "/active_grasp/active_perception/configs/local_inference_config.yaml")
    args = parser.parse_args()

    ''' Initialize Test Data '''
    test_scene = torch.rand(1, 1024, 3).to("cuda:0")
    test_target = torch.rand(1, 1024, 3).to("cuda:0")
    test_delta_rot_6d = torch.rand(1, 6).to("cuda:0")
    a = test_delta_rot_6d[:, :3]
    b = test_delta_rot_6d[:, 3:]
    a_norm = a / a.norm(dim=1, keepdim=True)
    b_norm = b / b.norm(dim=1, keepdim=True)
    normalized_test_delta_rot_6d = torch.cat((a_norm, b_norm), dim=1)
    test_data = {
        'target_pts': test_target,
        'scene_pts': test_scene,
    }

    ''' Inference '''
    inference_engine = InferenceEngine(args.config)
    output = inference_engine.inference(test_data)
    print(output.keys())
    print(output['estimated_delta_rot_6d'])
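The test snippet above normalizes the two halves of the 6D rotation independently but never orthogonalizes them. Recovering an actual rotation matrix from this representation takes one extra Gram-Schmidt step (the standard 6D rotation parameterization of Zhou et al.); a minimal sketch for illustration, not part of this commit:

import torch

def rot_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:
    # d6: [bs, 6]; the two halves are unnormalized x/y axis predictions
    a, b = d6[:, :3], d6[:, 3:]
    x = a / a.norm(dim=1, keepdim=True)
    b = b - (x * b).sum(dim=1, keepdim=True) * x  # drop the component of b along x
    y = b / b.norm(dim=1, keepdim=True)
    z = torch.cross(x, y, dim=1)                  # right-handed third axis
    return torch.stack((x, y, z), dim=-1)         # [bs, 3, 3]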
0  src/active_grasp/active_perception/__init__.py  (new file)
7  src/active_grasp/active_perception/annotations/external_module.py  (new executable file)
@@ -0,0 +1,7 @@
EXTERNAL_FREEZE_MODULES = set()

def external_freeze(cls):
    if not hasattr(cls, 'load') or not callable(getattr(cls, 'load')):
        raise TypeError(f"external module <{cls.__name__}> must implement a 'load' method")
    EXTERNAL_FREEZE_MODULES.add(cls)
    return cls
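For reference, a usage sketch of the decorator (the class below is hypothetical): anything registered this way is guaranteed to expose a load method, so frozen external modules can be restored uniformly.

@external_freeze
class PretrainedEncoder:
    def load(self, path):
        ...  # restore weights from disk

assert PretrainedEncoder in EXTERNAL_FREEZE_MODULES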
8  src/active_grasp/active_perception/annotations/singleton.py  (new executable file)
@@ -0,0 +1,8 @@

def singleton(cls):
    instances = {}
    def get_instance(*args, **kwargs):
        if cls not in instances:
            instances[cls] = cls(*args, **kwargs)
        return instances[cls]
    return get_instance
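A usage sketch: every call site receives the same instance, so state set once is visible everywhere. Note that constructor arguments passed on later calls are silently ignored, since the first instance is always returned.

@singleton
class Counter:
    def __init__(self):
        self.n = 0

a, b = Counter(), Counter()
a.n += 1
assert b.n == 1  # both names refer to the single shared instance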
34  src/active_grasp/active_perception/annotations/stereotype.py  (new executable file)
@@ -0,0 +1,34 @@
# --- Classes --- #

def dataset():
    pass

def module():
    pass

def pipeline():
    pass

def runner():
    pass

def factory():
    pass

# --- Functions --- #

evaluation_methods = {}
def evaluation_method(eval_type):
    def decorator(func):
        evaluation_methods[eval_type] = func
        return func
    return decorator


def loss_function():
    pass


# --- Main --- #
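The evaluation_method decorator builds a registry keyed by eval_type; a hedged usage sketch (the metric name below is made up for illustration):

@evaluation_method("delta_pose_error")
def eval_delta_pose_error(output, data):
    ...  # compute a metric from model output and ground truth

fn = evaluation_methods["delta_pose_error"]  # looked up by eval_type at runtime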
74  src/active_grasp/active_perception/configs/config.py  (new executable file)
@@ -0,0 +1,74 @@
import argparse
import os.path
import shutil
import yaml


class ConfigManager:
    config = None
    config_path = None

    @staticmethod
    def get(*args):
        result = ConfigManager.config
        for arg in args:
            result = result[arg]
        return result

    @staticmethod
    def load_config_with(config_file_path):
        ConfigManager.config_path = config_file_path
        if not os.path.exists(ConfigManager.config_path):
            raise ValueError(f"Config file <{config_file_path}> does not exist")
        with open(config_file_path, 'r') as file:
            ConfigManager.config = yaml.safe_load(file)

    @staticmethod
    def backup_config_to(target_config_dir, file_name, prefix="config"):
        file_name = f"{prefix}_{file_name}.yaml"
        target_config_file_path = str(os.path.join(target_config_dir, file_name))
        shutil.copy(ConfigManager.config_path, target_config_file_path)

    @staticmethod
    def load_config():
        parser = argparse.ArgumentParser()
        parser.add_argument('--config', type=str, default='', help='config file path')
        args = parser.parse_args()
        if args.config:
            ConfigManager.load_config_with(args.config)

    @staticmethod
    def print_config(key: str = None, group: dict = None, level=0):
        table_size = 80
        if key and group:
            value = group[key]
            if type(value) is dict:
                print("\t" * level + f"+-{key}:")
                for k in value:
                    ConfigManager.print_config(k, value, level=level + 1)
            else:
                print("\t" * level + f"| {key}: {value}")
        elif key:
            ConfigManager.print_config(key, ConfigManager.config, level=level)
        else:
            print("+" + "-" * table_size + "+")
            print(f"| Configurations in <{ConfigManager.config_path}>:")
            print("+" + "-" * table_size + "+")
            for key in ConfigManager.config:
                ConfigManager.print_config(key, level=level + 1)
            print("+" + "-" * table_size + "+")


''' ------------ Debug ------------ '''
if __name__ == "__main__":
    test_args = ['--config', 'local_train_config.yaml']
    test_parser = argparse.ArgumentParser()
    test_parser.add_argument('--config', type=str, default='', help='config file path')
    test_args = test_parser.parse_args(test_args)
    if test_args.config:
        ConfigManager.load_config_with(test_args.config)
    ConfigManager.print_config()
    print()
    pipeline = ConfigManager.get('settings', 'train', 'batch_size')
    ConfigManager.print_config('settings')
    print(pipeline)
66  src/active_grasp/active_perception/configs/local_inference_config.yaml  (new executable file)
@@ -0,0 +1,66 @@
# Inference config file

settings:
  general:
    seed: 0
    cuda_visible_devices: "0,1,2,3,4,5,6,7"
    device: cuda
    test_dir: ""
    print: True
    parallel: True

  experiment:
    name: test_inference
    root_dir: "experiments"
    model_path: "/home/zhengxiao-han/Downloads/full_149_241009.pth"
    use_cache: True
    small_batch_overfit: False

  test:
    batch_size: 96
    dataset_list:
      - name: synthetic_test_sample
        source: nbv1
        data_type: sample
        synthetic: True
        ratio: 1.0
        batch_size: 96
        num_workers: 8

  results:
    save_data_keys: ["target_name", "src_rot_mat"]
    save_output_keys: ["in_process_sample"]

  pipeline: # module_type: name
    pts_encoder: pointnet
    view_finder: gradient_field

  datasets:
    general:
      data_dir: "/mnt/d/Datasets"
      score_limit: 0.3
      target_pts_num: 1024
      scene_pts_num: 16384
      canonical: False
      rgb_feat_cache: True

  modules:
    general:
      pts_channels: 3
      feature_dim: 1024
      per_point_feature: False
    pts_encoder:
      pointnet:
      pointnet++:
        params_name: light
    view_finder:
      gradient_field:
        pose_mode: rot_matrix
        regression_head: Rx_Ry
        sample_mode: ode
        sample_repeat: 50
        sampling_steps: 500
        sde_mode: ve
    rgb_encoder:
      dinov2:
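For reference, nested keys in this file are read by ConfigManager.get, which walks the mapping one argument per level (illustrative):

ConfigManager.load_config_with("local_inference_config.yaml")
pose_mode = ConfigManager.get("settings", "modules", "view_finder", "gradient_field", "pose_mode")  # "rot_matrix"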
0  src/active_grasp/active_perception/modules/__init__.py  (new executable file)
7  src/active_grasp/active_perception/modules/func_lib/__init__.py  (new executable file)
@@ -0,0 +1,7 @@
from modules.func_lib.samplers import (
    cond_pc_sampler,
    cond_ode_sampler
)
from modules.func_lib.sde import (
    init_sde
)
282  src/active_grasp/active_perception/modules/func_lib/samplers.py  (new executable file)
@@ -0,0 +1,282 @@
import sys
import os
import torch
import numpy as np

from scipy import integrate
from utils.pose_util import PoseUtil


def global_prior_likelihood(z, sigma_max):
    """The likelihood of a Gaussian distribution with mean zero and
    standard deviation sigma."""
    # z: [bs, pose_dim]
    shape = z.shape
    N = np.prod(shape[1:])  # pose_dim
    return -N / 2. * torch.log(2 * np.pi * sigma_max ** 2) - torch.sum(z ** 2, dim=-1) / (2 * sigma_max ** 2)


def cond_ode_likelihood(
        score_model,
        data,
        prior,
        sde_coeff,
        marginal_prob_fn,
        atol=1e-5,
        rtol=1e-5,
        device='cuda',
        eps=1e-5,
        num_steps=None,
        pose_mode='quat_wxyz',
        init_x=None,
):
    pose_dim = PoseUtil.get_pose_dim(pose_mode)
    batch_size = data['pts'].shape[0]
    epsilon = prior((batch_size, pose_dim)).to(device)
    init_x = data['sampled_pose'].clone().cpu().numpy() if init_x is None else init_x
    shape = init_x.shape
    init_logp = np.zeros((shape[0],))  # [bs]
    init_inp = np.concatenate([init_x.reshape(-1), init_logp], axis=0)

    def score_eval_wrapper(data):
        """A wrapper of the score-based model for use by the ODE solver."""
        with torch.no_grad():
            score = score_model(data)
        return score.cpu().numpy().reshape((-1,))

    def divergence_eval(data, epsilon):
        """Compute the divergence of the score-based model with Skilling-Hutchinson."""
        # save ckpt of sampled_pose
        origin_sampled_pose = data['sampled_pose'].clone()
        with torch.enable_grad():
            # make sampled_pose differentiable
            data['sampled_pose'].requires_grad_(True)
            score = score_model(data)
            score_energy = torch.sum(score * epsilon)  # [, ]
            grad_score_energy = torch.autograd.grad(score_energy, data['sampled_pose'])[0]  # [bs, pose_dim]
        # reset sampled_pose
        data['sampled_pose'] = origin_sampled_pose
        return torch.sum(grad_score_energy * epsilon, dim=-1)  # [bs, 1]

    def divergence_eval_wrapper(data):
        """A wrapper for evaluating the divergence of score for the black-box ODE solver."""
        with torch.no_grad():
            # Compute likelihood.
            div = divergence_eval(data, epsilon)  # [bs, 1]
        return div.cpu().numpy().reshape((-1,)).astype(np.float64)

    def ode_func(t, inp):
        """The ODE function for use by the ODE solver."""
        # split x, logp from inp
        x = inp[:-shape[0]]
        # calc x-grad
        x = torch.tensor(x.reshape(-1, pose_dim), dtype=torch.float32, device=device)
        time_steps = torch.ones(batch_size, device=device).unsqueeze(-1) * t
        drift, diffusion = sde_coeff(torch.tensor(t))
        drift = drift.cpu().numpy()
        diffusion = diffusion.cpu().numpy()
        data['sampled_pose'] = x
        data['t'] = time_steps
        x_grad = drift - 0.5 * (diffusion ** 2) * score_eval_wrapper(data)
        # calc logp-grad
        logp_grad = drift - 0.5 * (diffusion ** 2) * divergence_eval_wrapper(data)
        # concat curr grad
        return np.concatenate([x_grad, logp_grad], axis=0)

    # Run the black-box ODE solver
    res = integrate.solve_ivp(ode_func, (eps, 1.0), init_inp, rtol=rtol, atol=atol, method='RK45')
    zp = torch.tensor(res.y[:, -1], device=device)  # [bs * (pose_dim + 1)]
    z = zp[:-shape[0]].reshape(shape)  # [bs, pose_dim]
    delta_logp = zp[-shape[0]:].reshape(shape[0])  # [bs,] logp
    _, sigma_max = marginal_prob_fn(None, torch.tensor(1.).to(device))  # we assume T = 1
    prior_logp = global_prior_likelihood(z, sigma_max)
    log_likelihoods = (prior_logp + delta_logp) / np.log(2)  # negative log-likelihoods (nlls)
    return z, log_likelihoods


def cond_pc_sampler(
        score_model,
        data,
        prior,
        sde_coeff,
        num_steps=500,
        snr=0.16,
        device='cuda',
        eps=1e-5,
        pose_mode='quat_wxyz',
        init_x=None,
):
    pose_dim = PoseUtil.get_pose_dim(pose_mode)
    batch_size = data['target_pts_feat'].shape[0]
    init_x = prior((batch_size, pose_dim)).to(device) if init_x is None else init_x
    time_steps = torch.linspace(1., eps, num_steps, device=device)
    step_size = time_steps[0] - time_steps[1]
    noise_norm = np.sqrt(pose_dim)
    x = init_x
    poses = []
    with torch.no_grad():
        for time_step in time_steps:
            batch_time_step = torch.ones(batch_size, device=device).unsqueeze(-1) * time_step
            # Corrector step (Langevin MCMC)
            data['sampled_pose'] = x
            data['t'] = batch_time_step
            grad = score_model(data)
            grad_norm = torch.norm(grad.reshape(batch_size, -1), dim=-1).mean()
            langevin_step_size = 2 * (snr * noise_norm / grad_norm) ** 2
            x = x + langevin_step_size * grad + torch.sqrt(2 * langevin_step_size) * torch.randn_like(x)

            # normalisation
            if pose_mode == 'quat_wxyz' or pose_mode == 'quat_xyzw':
                # quat, should be normalised
                x[:, :4] /= torch.norm(x[:, :4], dim=-1, keepdim=True)
            elif pose_mode == 'euler_xyz':
                pass
            else:
                # rotation (x axis, y axis), should be normalised
                x[:, :3] /= torch.norm(x[:, :3], dim=-1, keepdim=True)
                x[:, 3:6] /= torch.norm(x[:, 3:6], dim=-1, keepdim=True)

            # Predictor step (Euler-Maruyama)
            drift, diffusion = sde_coeff(batch_time_step)
            drift = drift - diffusion ** 2 * grad  # R-SDE
            mean_x = x + drift * step_size
            x = mean_x + diffusion * torch.sqrt(step_size) * torch.randn_like(x)

            # normalisation
            x[:, :-3] = PoseUtil.normalize_rotation(x[:, :-3], pose_mode)
            poses.append(x.unsqueeze(0))

    xs = torch.cat(poses, dim=0)
    xs[:, :, -3:] += data['pts_center'].unsqueeze(0).repeat(xs.shape[0], 1, 1)
    mean_x[:, -3:] += data['pts_center']
    mean_x[:, :-3] = PoseUtil.normalize_rotation(mean_x[:, :-3], pose_mode)
    # The last step does not include any noise
    return xs.permute(1, 0, 2), mean_x


def cond_ode_sampler(
        score_model,
        data,
        prior,
        sde_coeff,
        atol=1e-5,
        rtol=1e-5,
        device='cuda',
        eps=1e-5,
        T=1.0,
        num_steps=None,
        pose_mode='quat_wxyz',
        denoise=True,
        init_x=None,
):
    pose_dim = PoseUtil.get_pose_dim(pose_mode)
    batch_size = data['target_feat'].shape[0]
    init_x = prior((batch_size, pose_dim), T=T).to(device) if init_x is None \
        else init_x + prior((batch_size, pose_dim), T=T).to(device)
    shape = init_x.shape

    def score_eval_wrapper(data):
        """A wrapper of the score-based model for use by the ODE solver."""
        with torch.no_grad():
            score = score_model(data)
        return score.cpu().numpy().reshape((-1,))

    def ode_func(t, x):
        """The ODE function for use by the ODE solver."""
        x = torch.tensor(x.reshape(-1, pose_dim), dtype=torch.float32, device=device)
        time_steps = torch.ones(batch_size, device=device).unsqueeze(-1) * t
        drift, diffusion = sde_coeff(torch.tensor(t))
        drift = drift.cpu().numpy()
        diffusion = diffusion.cpu().numpy()
        data['sampled_pose'] = x
        data['t'] = time_steps
        return drift - 0.5 * (diffusion ** 2) * score_eval_wrapper(data)

    # Run the black-box ODE solver
    t_eval = None
    if num_steps is not None:
        # num_steps, from T -> eps
        t_eval = np.linspace(T, eps, num_steps)
    res = integrate.solve_ivp(ode_func, (T, eps), init_x.reshape(-1).cpu().numpy(),
                              rtol=rtol, atol=atol, method='RK45', t_eval=t_eval)
    xs = torch.tensor(res.y, device=device).T.view(-1, batch_size, pose_dim)  # [num_steps, bs, pose_dim]
    x = torch.tensor(res.y[:, -1], device=device).reshape(shape)  # [bs, pose_dim]
    # denoise, using the predictor step in P-C sampler
    if denoise:
        # Reverse diffusion predictor for denoising
        vec_eps = torch.ones((x.shape[0], 1), device=x.device) * eps
        drift, diffusion = sde_coeff(vec_eps)
        data['sampled_pose'] = x.float()
        data['t'] = vec_eps
        grad = score_model(data)
        drift = drift - diffusion ** 2 * grad  # R-SDE
        mean_x = x + drift * ((1 - eps) / (1000 if num_steps is None else num_steps))
        x = mean_x

    num_steps = xs.shape[0]
    xs = xs.reshape(batch_size * num_steps, -1)
    xs = PoseUtil.normalize_rotation(xs, pose_mode)
    xs = xs.reshape(num_steps, batch_size, -1)
    x = PoseUtil.normalize_rotation(x, pose_mode)
    return xs.permute(1, 0, 2), x


def cond_edm_sampler(
        decoder_model, data, prior_fn, randn_like=torch.randn_like,
        num_steps=18, sigma_min=0.002, sigma_max=80, rho=7,
        S_churn=0, S_min=0, S_max=float('inf'), S_noise=1,
        pose_mode='quat_wxyz', device='cuda'
):
    pose_dim = PoseUtil.get_pose_dim(pose_mode)
    batch_size = data['pts'].shape[0]
    latents = prior_fn((batch_size, pose_dim)).to(device)

    # Time step discretization. Note that sigma and t are interchangeable here.
    step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device)
    t_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (
            sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
    t_steps = torch.cat([torch.as_tensor(t_steps), torch.zeros_like(t_steps[:1])])  # t_N = 0

    def decoder_wrapper(decoder, data, x, t):
        # save temp
        x_, t_ = data['sampled_pose'], data['t']
        # init data
        data['sampled_pose'], data['t'] = x, t
        # denoise
        data, denoised = decoder(data)
        # recover data
        data['sampled_pose'], data['t'] = x_, t_
        return denoised.to(torch.float64)

    # Main sampling loop.
    x_next = latents.to(torch.float64) * t_steps[0]
    xs = []
    for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])):  # 0, ..., N-1
        x_cur = x_next

        # Increase noise temporarily.
        gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0
        t_hat = torch.as_tensor(t_cur + gamma * t_cur)
        x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * randn_like(x_cur)

        # Euler step.
        denoised = decoder_wrapper(decoder_model, data, x_hat, t_hat)
        d_cur = (x_hat - denoised) / t_hat
        x_next = x_hat + (t_next - t_hat) * d_cur

        # Apply 2nd order correction.
        if i < num_steps - 1:
            denoised = decoder_wrapper(decoder_model, data, x_next, t_next)
            d_prime = (x_next - denoised) / t_next
            x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)
        xs.append(x_next.unsqueeze(0))

    xs = torch.stack(xs, dim=0)  # [num_steps, bs, pose_dim]
    x = xs[-1]  # [bs, pose_dim]

    # post-processing
    xs = xs.reshape(batch_size * num_steps, -1)
    xs = PoseUtil.normalize_rotation(xs, pose_mode)
    xs = xs.reshape(num_steps, batch_size, -1)
    x = PoseUtil.normalize_rotation(x, pose_mode)
    return xs.permute(1, 0, 2), x
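A hedged usage sketch of the ODE sampler, wired together with init_sde from sde.py below. The score model here is a shape-checking stub, and the exact pose dimensionality for each pose_mode depends on PoseUtil in this repo:

import torch
from modules.func_lib import cond_ode_sampler, init_sde

def dummy_score_model(data):
    # a real model conditions on point-cloud features as well as on
    # data['sampled_pose'] and data['t']
    return torch.zeros_like(data['sampled_pose'])

prior_fn, marginal_prob_fn, sde_fn, eps, T = init_sde('ve')
data = {'target_feat': torch.zeros(4, 1024)}  # batch size is inferred from this entry
in_process_sample, final_pose = cond_ode_sampler(
    dummy_score_model, data, prior_fn, sde_fn,
    device='cpu', eps=eps, T=T, num_steps=50, pose_mode='rot_matrix',
)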
121  src/active_grasp/active_perception/modules/func_lib/sde.py  (new executable file)
@@ -0,0 +1,121 @@
import functools
import torch
import numpy as np


# ----- VE SDE -----
# ------------------
def ve_marginal_prob(x, t, sigma_min=0.01, sigma_max=90):
    std = sigma_min * (sigma_max / sigma_min) ** t
    mean = x
    return mean, std


def ve_sde(t, sigma_min=0.01, sigma_max=90):
    sigma = sigma_min * (sigma_max / sigma_min) ** t
    drift_coeff = torch.tensor(0)
    diffusion_coeff = sigma * torch.sqrt(torch.tensor(2 * (np.log(sigma_max) - np.log(sigma_min)), device=t.device))
    return drift_coeff, diffusion_coeff


def ve_prior(shape, sigma_min=0.01, sigma_max=90, T=1.0):
    _, sigma_max_prior = ve_marginal_prob(None, T, sigma_min=sigma_min, sigma_max=sigma_max)
    return torch.randn(*shape) * sigma_max_prior


# ----- VP SDE -----
# ------------------
def vp_marginal_prob(x, t, beta_0=0.1, beta_1=20):
    log_mean_coeff = -0.25 * t ** 2 * (beta_1 - beta_0) - 0.5 * t * beta_0
    mean = torch.exp(log_mean_coeff) * x
    std = torch.sqrt(1. - torch.exp(2. * log_mean_coeff))
    return mean, std


def vp_sde(t, beta_0=0.1, beta_1=20):
    beta_t = beta_0 + t * (beta_1 - beta_0)
    drift_coeff = -0.5 * beta_t
    diffusion_coeff = torch.sqrt(beta_t)
    return drift_coeff, diffusion_coeff


def vp_prior(shape, beta_0=0.1, beta_1=20):
    return torch.randn(*shape)


# ----- sub-VP SDE -----
# ----------------------
def subvp_marginal_prob(x, t, beta_0, beta_1):
    log_mean_coeff = -0.25 * t ** 2 * (beta_1 - beta_0) - 0.5 * t * beta_0
    mean = torch.exp(log_mean_coeff) * x
    std = 1 - torch.exp(2. * log_mean_coeff)
    return mean, std


def subvp_sde(t, beta_0, beta_1):
    beta_t = beta_0 + t * (beta_1 - beta_0)
    drift_coeff = -0.5 * beta_t
    discount = 1. - torch.exp(-2 * beta_0 * t - (beta_1 - beta_0) * t ** 2)
    diffusion_coeff = torch.sqrt(beta_t * discount)
    return drift_coeff, diffusion_coeff


def subvp_prior(shape, beta_0=0.1, beta_1=20):
    return torch.randn(*shape)


# ----- EDM SDE -----
# ------------------
def edm_marginal_prob(x, t, sigma_min=0.002, sigma_max=80):
    std = t
    mean = x
    return mean, std


def edm_sde(t, sigma_min=0.002, sigma_max=80):
    drift_coeff = torch.tensor(0)
    diffusion_coeff = torch.sqrt(2 * t)
    return drift_coeff, diffusion_coeff


def edm_prior(shape, sigma_min=0.002, sigma_max=80):
    return torch.randn(*shape) * sigma_max


def init_sde(sde_mode):
    # the SDE-related hyperparameters are copied from https://github.com/yang-song/score_sde_pytorch
    if sde_mode == 'edm':
        sigma_min = 0.002
        sigma_max = 80
        eps = 0.002
        prior_fn = functools.partial(edm_prior, sigma_min=sigma_min, sigma_max=sigma_max)
        marginal_prob_fn = functools.partial(edm_marginal_prob, sigma_min=sigma_min, sigma_max=sigma_max)
        sde_fn = functools.partial(edm_sde, sigma_min=sigma_min, sigma_max=sigma_max)
        T = sigma_max
    elif sde_mode == 've':
        sigma_min = 0.01
        sigma_max = 50
        eps = 1e-5
        marginal_prob_fn = functools.partial(ve_marginal_prob, sigma_min=sigma_min, sigma_max=sigma_max)
        sde_fn = functools.partial(ve_sde, sigma_min=sigma_min, sigma_max=sigma_max)
        T = 1.0
        prior_fn = functools.partial(ve_prior, sigma_min=sigma_min, sigma_max=sigma_max)
    elif sde_mode == 'vp':
        beta_0 = 0.1
        beta_1 = 20
        eps = 1e-3
        prior_fn = functools.partial(vp_prior, beta_0=beta_0, beta_1=beta_1)
        marginal_prob_fn = functools.partial(vp_marginal_prob, beta_0=beta_0, beta_1=beta_1)
        sde_fn = functools.partial(vp_sde, beta_0=beta_0, beta_1=beta_1)
        T = 1.0
    elif sde_mode == 'subvp':
        beta_0 = 0.1
        beta_1 = 20
        eps = 1e-3
        prior_fn = functools.partial(subvp_prior, beta_0=beta_0, beta_1=beta_1)
        marginal_prob_fn = functools.partial(subvp_marginal_prob, beta_0=beta_0, beta_1=beta_1)
        sde_fn = functools.partial(subvp_sde, beta_0=beta_0, beta_1=beta_1)
        T = 1.0
    else:
        raise NotImplementedError
    return prior_fn, marginal_prob_fn, sde_fn, eps, T
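A quick sanity check of the VE schedule as exposed by init_sde (illustrative only): the prior drawn at t = T should match the standard deviation reported by the marginal probability at T.

import torch

prior_fn, marginal_prob_fn, sde_fn, eps, T = init_sde('ve')
_, std_T = marginal_prob_fn(None, torch.tensor(T))  # sigma_min * (sigma_max / sigma_min) ** T = 50.0
x_T = prior_fn((8, 9), T=T)                         # samples approximately ~ N(0, std_T ** 2)
drift, diffusion = sde_fn(torch.tensor(0.5))        # forward-SDE coefficients at t = 0.5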
4  src/active_grasp/active_perception/modules/module_lib/__init__.py  (new executable file)
@@ -0,0 +1,4 @@
from modules.module_lib.gaussian_fourier_projection import GaussianFourierProjection
from modules.module_lib.linear import Linear
from modules.module_lib.position_embedding import PositionalEmbedding
from modules.module_lib.rot_head import RotHead
@@ -0,0 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

__version__ = "0.0.1"
@@ -0,0 +1,22 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import pathlib

from omegaconf import OmegaConf


def load_config(config_name: str):
    config_filename = config_name + ".yaml"
    return OmegaConf.load(pathlib.Path(__file__).parent.resolve() / config_filename)


dinov2_default_config = load_config("ssl_default_config")


def load_and_merge_config(config_name: str):
    default_config = OmegaConf.create(dinov2_default_config)
    loaded_config = load_config(config_name)
    return OmegaConf.merge(default_config, loaded_config)
@@ -0,0 +1,6 @@
student:
  arch: vit_base
  patch_size: 14
crops:
  global_crops_size: 518 # this is to set up the position embeddings properly
  local_crops_size: 98
@@ -0,0 +1,9 @@
student:
  arch: vit_base
  patch_size: 14
  num_register_tokens: 4
  interpolate_antialias: true
  interpolate_offset: 0.0
crops:
  global_crops_size: 518 # this is to set up the position embeddings properly
  local_crops_size: 98
@@ -0,0 +1,7 @@
student:
  arch: vit_giant2
  patch_size: 14
  ffn_layer: swiglufused
crops:
  global_crops_size: 518 # this is to set up the position embeddings properly
  local_crops_size: 98
@@ -0,0 +1,10 @@
student:
  arch: vit_giant2
  patch_size: 14
  ffn_layer: swiglufused
  num_register_tokens: 4
  interpolate_antialias: true
  interpolate_offset: 0.0
crops:
  global_crops_size: 518 # this is to set up the position embeddings properly
  local_crops_size: 98
@@ -0,0 +1,6 @@
student:
  arch: vit_large
  patch_size: 14
crops:
  global_crops_size: 518 # this is to set up the position embeddings properly
  local_crops_size: 98
@@ -0,0 +1,9 @@
student:
  arch: vit_large
  patch_size: 14
  num_register_tokens: 4
  interpolate_antialias: true
  interpolate_offset: 0.0
crops:
  global_crops_size: 518 # this is to set up the position embeddings properly
  local_crops_size: 98
@@ -0,0 +1,6 @@
student:
  arch: vit_small
  patch_size: 14
crops:
  global_crops_size: 518 # this is to set up the position embeddings properly
  local_crops_size: 98
@@ -0,0 +1,9 @@
student:
  arch: vit_small
  patch_size: 14
  num_register_tokens: 4
  interpolate_antialias: true
  interpolate_offset: 0.0
crops:
  global_crops_size: 518 # this is to set up the position embeddings properly
  local_crops_size: 98
@@ -0,0 +1,118 @@
MODEL:
  WEIGHTS: ''
compute_precision:
  grad_scaler: true
  teacher:
    backbone:
      sharding_strategy: SHARD_GRAD_OP
      mixed_precision:
        param_dtype: fp16
        reduce_dtype: fp16
        buffer_dtype: fp32
    dino_head:
      sharding_strategy: SHARD_GRAD_OP
      mixed_precision:
        param_dtype: fp16
        reduce_dtype: fp16
        buffer_dtype: fp32
    ibot_head:
      sharding_strategy: SHARD_GRAD_OP
      mixed_precision:
        param_dtype: fp16
        reduce_dtype: fp16
        buffer_dtype: fp32
  student:
    backbone:
      sharding_strategy: SHARD_GRAD_OP
      mixed_precision:
        param_dtype: fp16
        reduce_dtype: fp16
        buffer_dtype: fp32
    dino_head:
      sharding_strategy: SHARD_GRAD_OP
      mixed_precision:
        param_dtype: fp16
        reduce_dtype: fp32
        buffer_dtype: fp32
    ibot_head:
      sharding_strategy: SHARD_GRAD_OP
      mixed_precision:
        param_dtype: fp16
        reduce_dtype: fp32
        buffer_dtype: fp32
dino:
  loss_weight: 1.0
  head_n_prototypes: 65536
  head_bottleneck_dim: 256
  head_nlayers: 3
  head_hidden_dim: 2048
  koleo_loss_weight: 0.1
ibot:
  loss_weight: 1.0
  mask_sample_probability: 0.5
  mask_ratio_min_max:
  - 0.1
  - 0.5
  separate_head: false
  head_n_prototypes: 65536
  head_bottleneck_dim: 256
  head_nlayers: 3
  head_hidden_dim: 2048
train:
  batch_size_per_gpu: 64
  dataset_path: ImageNet:split=TRAIN
  output_dir: .
  saveckp_freq: 20
  seed: 0
  num_workers: 10
  OFFICIAL_EPOCH_LENGTH: 1250
  cache_dataset: true
  centering: "centering" # or "sinkhorn_knopp"
student:
  arch: vit_large
  patch_size: 16
  drop_path_rate: 0.3
  layerscale: 1.0e-05
  drop_path_uniform: true
  pretrained_weights: ''
  ffn_layer: "mlp"
  block_chunks: 0
  qkv_bias: true
  proj_bias: true
  ffn_bias: true
  num_register_tokens: 0
  interpolate_antialias: false
  interpolate_offset: 0.1
teacher:
  momentum_teacher: 0.992
  final_momentum_teacher: 1
  warmup_teacher_temp: 0.04
  teacher_temp: 0.07
  warmup_teacher_temp_epochs: 30
optim:
  epochs: 100
  weight_decay: 0.04
  weight_decay_end: 0.4
  base_lr: 0.004 # learning rate for a batch size of 1024
  lr: 0. # will be set after applying scaling rule
  warmup_epochs: 10
  min_lr: 1.0e-06
  clip_grad: 3.0
  freeze_last_layer_epochs: 1
  scaling_rule: sqrt_wrt_1024
  patch_embed_lr_mult: 0.2
  layerwise_decay: 0.9
  adamw_beta1: 0.9
  adamw_beta2: 0.999
crops:
  global_crops_scale:
  - 0.32
  - 1.0
  local_crops_number: 8
  local_crops_scale:
  - 0.05
  - 0.32
  global_crops_size: 224
  local_crops_size: 96
evaluation:
  eval_period_iterations: 12500
@@ -0,0 +1,26 @@
dino:
  head_n_prototypes: 131072
  head_bottleneck_dim: 384
ibot:
  separate_head: true
  head_n_prototypes: 131072
train:
  batch_size_per_gpu: 12
  dataset_path: ImageNet22k
  centering: sinkhorn_knopp
student:
  arch: vit_giant2
  patch_size: 14
  drop_path_rate: 0.4
  ffn_layer: swiglufused
  block_chunks: 4
teacher:
  momentum_teacher: 0.994
optim:
  epochs: 500
  weight_decay_end: 0.2
  base_lr: 2.0e-04 # learning rate for a batch size of 1024
  warmup_epochs: 80
  layerwise_decay: 1.0
crops:
  local_crops_size: 98
@@ -0,0 +1,26 @@
dino:
  head_n_prototypes: 131072
  head_bottleneck_dim: 384
ibot:
  separate_head: true
  head_n_prototypes: 131072
train:
  batch_size_per_gpu: 32
  dataset_path: ImageNet22k
  centering: sinkhorn_knopp
student:
  arch: vit_large
  patch_size: 14
  drop_path_rate: 0.4
  ffn_layer: swiglufused
  block_chunks: 4
teacher:
  momentum_teacher: 0.994
optim:
  epochs: 500
  weight_decay_end: 0.2
  base_lr: 2.0e-04 # learning rate for a batch size of 1024
  warmup_epochs: 80
  layerwise_decay: 1.0
crops:
  local_crops_size: 98
@@ -0,0 +1,6 @@
# this corresponds to the default config
train:
  dataset_path: ImageNet:split=TRAIN
  batch_size_per_gpu: 64
student:
  block_chunks: 4
@@ -0,0 +1,10 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from .adapters import DatasetWithEnumeratedTargets
from .loaders import make_data_loader, make_dataset, SamplerType
from .collate import collate_data_and_cast
from .masking import MaskingGenerator
from .augmentations import DataAugmentationDINO
@@ -0,0 +1,28 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from typing import Any, Tuple

from torch.utils.data import Dataset


class DatasetWithEnumeratedTargets(Dataset):
    def __init__(self, dataset):
        self._dataset = dataset

    def get_image_data(self, index: int) -> bytes:
        return self._dataset.get_image_data(index)

    def get_target(self, index: int) -> Tuple[Any, int]:
        target = self._dataset.get_target(index)
        return (index, target)

    def __getitem__(self, index: int) -> Tuple[Any, Tuple[Any, int]]:
        image, target = self._dataset[index]
        target = index if target is None else target
        return image, (index, target)

    def __len__(self) -> int:
        return len(self._dataset)
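DatasetWithEnumeratedTargets only tags each sample with its index (and substitutes the index when the target is missing); a toy, list-backed sketch for illustration:

pairs = [("img_a", 0), ("img_b", None)]
wrapped = DatasetWithEnumeratedTargets(pairs)  # anything indexable that yields (image, target) pairs works for __getitem__
assert wrapped[0] == ("img_a", (0, 0))
assert wrapped[1] == ("img_b", (1, 1))  # a missing target is replaced by the sample index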
@@ -0,0 +1,118 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import logging

from torchvision import transforms

from .transforms import (
    GaussianBlur,
    make_normalize_transform,
)


logger = logging.getLogger("dinov2")


class DataAugmentationDINO(object):
    def __init__(
        self,
        global_crops_scale,
        local_crops_scale,
        local_crops_number,
        global_crops_size=224,
        local_crops_size=96,
    ):
        self.global_crops_scale = global_crops_scale
        self.local_crops_scale = local_crops_scale
        self.local_crops_number = local_crops_number
        self.global_crops_size = global_crops_size
        self.local_crops_size = local_crops_size

        logger.info("###################################")
        logger.info("Using data augmentation parameters:")
        logger.info(f"global_crops_scale: {global_crops_scale}")
        logger.info(f"local_crops_scale: {local_crops_scale}")
        logger.info(f"local_crops_number: {local_crops_number}")
        logger.info(f"global_crops_size: {global_crops_size}")
        logger.info(f"local_crops_size: {local_crops_size}")
        logger.info("###################################")

        # random resized crop and flip
        self.geometric_augmentation_global = transforms.Compose(
            [
                transforms.RandomResizedCrop(
                    global_crops_size, scale=global_crops_scale, interpolation=transforms.InterpolationMode.BICUBIC
                ),
                transforms.RandomHorizontalFlip(p=0.5),
            ]
        )

        self.geometric_augmentation_local = transforms.Compose(
            [
                transforms.RandomResizedCrop(
                    local_crops_size, scale=local_crops_scale, interpolation=transforms.InterpolationMode.BICUBIC
                ),
                transforms.RandomHorizontalFlip(p=0.5),
            ]
        )

        # color distortions / blurring
        color_jittering = transforms.Compose(
            [
                transforms.RandomApply(
                    [transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)],
                    p=0.8,
                ),
                transforms.RandomGrayscale(p=0.2),
            ]
        )

        global_transfo1_extra = GaussianBlur(p=1.0)

        global_transfo2_extra = transforms.Compose(
            [
                GaussianBlur(p=0.1),
                transforms.RandomSolarize(threshold=128, p=0.2),
            ]
        )

        local_transfo_extra = GaussianBlur(p=0.5)

        # normalization
        self.normalize = transforms.Compose(
            [
                transforms.ToTensor(),
                make_normalize_transform(),
            ]
        )

        self.global_transfo1 = transforms.Compose([color_jittering, global_transfo1_extra, self.normalize])
        self.global_transfo2 = transforms.Compose([color_jittering, global_transfo2_extra, self.normalize])
        self.local_transfo = transforms.Compose([color_jittering, local_transfo_extra, self.normalize])

    def __call__(self, image):
        output = {}

        # global crops:
        im1_base = self.geometric_augmentation_global(image)
        global_crop_1 = self.global_transfo1(im1_base)

        im2_base = self.geometric_augmentation_global(image)
        global_crop_2 = self.global_transfo2(im2_base)

        output["global_crops"] = [global_crop_1, global_crop_2]

        # global crops for teacher:
        output["global_crops_teacher"] = [global_crop_1, global_crop_2]

        # local crops:
        local_crops = [
            self.local_transfo(self.geometric_augmentation_local(image)) for _ in range(self.local_crops_number)
        ]
        output["local_crops"] = local_crops
        output["offsets"] = ()

        return output
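A usage sketch with the default crop sizes (assumes Pillow is available; the blank image is just a stand-in):

from PIL import Image

augment = DataAugmentationDINO(
    global_crops_scale=(0.32, 1.0),
    local_crops_scale=(0.05, 0.32),
    local_crops_number=8,
)
out = augment(Image.new("RGB", (256, 256)))
# out["global_crops"]: 2 tensors of shape [3, 224, 224]
# out["local_crops"]:  8 tensors of shape [3, 96, 96]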
@@ -0,0 +1,49 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import torch
import random


def collate_data_and_cast(samples_list, mask_ratio_tuple, mask_probability, dtype, n_tokens=None, mask_generator=None):
    # dtype = torch.half  # TODO: Remove

    n_global_crops = len(samples_list[0][0]["global_crops"])
    n_local_crops = len(samples_list[0][0]["local_crops"])

    collated_global_crops = torch.stack([s[0]["global_crops"][i] for i in range(n_global_crops) for s in samples_list])

    collated_local_crops = torch.stack([s[0]["local_crops"][i] for i in range(n_local_crops) for s in samples_list])

    B = len(collated_global_crops)
    N = n_tokens
    n_samples_masked = int(B * mask_probability)
    probs = torch.linspace(*mask_ratio_tuple, n_samples_masked + 1)
    upperbound = 0
    masks_list = []
    for i in range(0, n_samples_masked):
        prob_min = probs[i]
        prob_max = probs[i + 1]
        masks_list.append(torch.BoolTensor(mask_generator(int(N * random.uniform(prob_min, prob_max)))))
        upperbound += int(N * prob_max)
    for i in range(n_samples_masked, B):
        masks_list.append(torch.BoolTensor(mask_generator(0)))

    random.shuffle(masks_list)

    collated_masks = torch.stack(masks_list).flatten(1)
    mask_indices_list = collated_masks.flatten().nonzero().flatten()

    masks_weight = (1 / collated_masks.sum(-1).clamp(min=1.0)).unsqueeze(-1).expand_as(collated_masks)[collated_masks]

    return {
        "collated_global_crops": collated_global_crops.to(dtype),
        "collated_local_crops": collated_local_crops.to(dtype),
        "collated_masks": collated_masks,
        "mask_indices_list": mask_indices_list,
        "masks_weight": masks_weight,
        "upperbound": upperbound,
        "n_masked_patches": torch.full((1,), fill_value=mask_indices_list.shape[0], dtype=torch.long),
    }
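A hedged sketch of how this collate function is driven; the inline generator below is a stand-in for dinov2's MaskingGenerator and simply marks the first num_masked of n_tokens patches:

import numpy as np
import torch

def toy_mask_generator(num_masked, n_tokens=256):
    mask = np.zeros(n_tokens, dtype=bool)
    mask[:num_masked] = True  # a real generator masks block-shaped patch regions
    return mask

sample = ({"global_crops": [torch.zeros(3, 224, 224)] * 2,
           "local_crops": [torch.zeros(3, 96, 96)] * 8}, None)  # (augmentation output, target)
batch = collate_data_and_cast(
    [sample, sample], mask_ratio_tuple=(0.1, 0.5), mask_probability=0.5,
    dtype=torch.float32, n_tokens=256, mask_generator=toy_mask_generator,
)
# batch["collated_global_crops"]: [4, 3, 224, 224], grouped crop-major across the batch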
@@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from .image_net import ImageNet
from .image_net_22k import ImageNet22k
@@ -0,0 +1,31 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from io import BytesIO
from typing import Any

from PIL import Image


class Decoder:
    def decode(self) -> Any:
        raise NotImplementedError


class ImageDataDecoder(Decoder):
    def __init__(self, image_data: bytes) -> None:
        self._image_data = image_data

    def decode(self) -> Image:
        f = BytesIO(self._image_data)
        return Image.open(f).convert(mode="RGB")


class TargetDecoder(Decoder):
    def __init__(self, target: Any):
        self._target = target

    def decode(self) -> Any:
        return self._target
@@ -0,0 +1,38 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from typing import Any, Tuple

from torchvision.datasets import VisionDataset

from .decoders import TargetDecoder, ImageDataDecoder


class ExtendedVisionDataset(VisionDataset):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)  # type: ignore

    def get_image_data(self, index: int) -> bytes:
        raise NotImplementedError

    def get_target(self, index: int) -> Any:
        raise NotImplementedError

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        try:
            image_data = self.get_image_data(index)
            image = ImageDataDecoder(image_data).decode()
        except Exception as e:
            raise RuntimeError(f"can not read image for sample {index}") from e
        target = self.get_target(index)
        target = TargetDecoder(target).decode()

        if self.transforms is not None:
            image, target = self.transforms(image, target)

        return image, target

    def __len__(self) -> int:
        raise NotImplementedError
@@ -0,0 +1,290 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import csv
from enum import Enum
import logging
import os
from typing import Callable, List, Optional, Tuple, Union

import numpy as np

from .extended import ExtendedVisionDataset


logger = logging.getLogger("dinov2")
_Target = int


class _Split(Enum):
    TRAIN = "train"
    VAL = "val"
    TEST = "test"  # NOTE: torchvision does not support the test split

    @property
    def length(self) -> int:
        split_lengths = {
            _Split.TRAIN: 1_281_167,
            _Split.VAL: 50_000,
            _Split.TEST: 100_000,
        }
        return split_lengths[self]

    def get_dirname(self, class_id: Optional[str] = None) -> str:
        return self.value if class_id is None else os.path.join(self.value, class_id)

    def get_image_relpath(self, actual_index: int, class_id: Optional[str] = None) -> str:
        dirname = self.get_dirname(class_id)
        if self == _Split.TRAIN:
            basename = f"{class_id}_{actual_index}"
        else:  # self in (_Split.VAL, _Split.TEST):
            basename = f"ILSVRC2012_{self.value}_{actual_index:08d}"
        return os.path.join(dirname, basename + ".JPEG")

    def parse_image_relpath(self, image_relpath: str) -> Tuple[str, int]:
        assert self != _Split.TEST
        dirname, filename = os.path.split(image_relpath)
        class_id = os.path.split(dirname)[-1]
        basename, _ = os.path.splitext(filename)
        actual_index = int(basename.split("_")[-1])
        return class_id, actual_index


class ImageNet(ExtendedVisionDataset):
    Target = Union[_Target]
    Split = Union[_Split]

    def __init__(
        self,
        *,
        split: "ImageNet.Split",
        root: str,
        extra: str,
        transforms: Optional[Callable] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        super().__init__(root, transforms, transform, target_transform)
        self._extra_root = extra
        self._split = split

        self._entries = None
        self._class_ids = None
        self._class_names = None

    @property
    def split(self) -> "ImageNet.Split":
        return self._split

    def _get_extra_full_path(self, extra_path: str) -> str:
        return os.path.join(self._extra_root, extra_path)

    def _load_extra(self, extra_path: str) -> np.ndarray:
        extra_full_path = self._get_extra_full_path(extra_path)
        return np.load(extra_full_path, mmap_mode="r")

    def _save_extra(self, extra_array: np.ndarray, extra_path: str) -> None:
        extra_full_path = self._get_extra_full_path(extra_path)
        os.makedirs(self._extra_root, exist_ok=True)
        np.save(extra_full_path, extra_array)

    @property
    def _entries_path(self) -> str:
        return f"entries-{self._split.value.upper()}.npy"

    @property
    def _class_ids_path(self) -> str:
        return f"class-ids-{self._split.value.upper()}.npy"

    @property
    def _class_names_path(self) -> str:
        return f"class-names-{self._split.value.upper()}.npy"

    def _get_entries(self) -> np.ndarray:
        if self._entries is None:
            self._entries = self._load_extra(self._entries_path)
        assert self._entries is not None
        return self._entries

    def _get_class_ids(self) -> np.ndarray:
        if self._split == _Split.TEST:
            assert False, "Class IDs are not available in TEST split"
        if self._class_ids is None:
            self._class_ids = self._load_extra(self._class_ids_path)
        assert self._class_ids is not None
        return self._class_ids

    def _get_class_names(self) -> np.ndarray:
        if self._split == _Split.TEST:
            assert False, "Class names are not available in TEST split"
        if self._class_names is None:
            self._class_names = self._load_extra(self._class_names_path)
        assert self._class_names is not None
        return self._class_names

    def find_class_id(self, class_index: int) -> str:
        class_ids = self._get_class_ids()
        return str(class_ids[class_index])

    def find_class_name(self, class_index: int) -> str:
        class_names = self._get_class_names()
        return str(class_names[class_index])

    def get_image_data(self, index: int) -> bytes:
        entries = self._get_entries()
        actual_index = entries[index]["actual_index"]

        class_id = self.get_class_id(index)

        image_relpath = self.split.get_image_relpath(actual_index, class_id)
        image_full_path = os.path.join(self.root, image_relpath)
        with open(image_full_path, mode="rb") as f:
            image_data = f.read()
        return image_data

    def get_target(self, index: int) -> Optional[Target]:
        entries = self._get_entries()
        class_index = entries[index]["class_index"]
        return None if self.split == _Split.TEST else int(class_index)

    def get_targets(self) -> Optional[np.ndarray]:
        entries = self._get_entries()
        return None if self.split == _Split.TEST else entries["class_index"]

    def get_class_id(self, index: int) -> Optional[str]:
        entries = self._get_entries()
        class_id = entries[index]["class_id"]
        return None if self.split == _Split.TEST else str(class_id)

    def get_class_name(self, index: int) -> Optional[str]:
        entries = self._get_entries()
        class_name = entries[index]["class_name"]
        return None if self.split == _Split.TEST else str(class_name)

    def __len__(self) -> int:
        entries = self._get_entries()
        assert len(entries) == self.split.length
        return len(entries)

    def _load_labels(self, labels_path: str) -> List[Tuple[str, str]]:
        labels_full_path = os.path.join(self.root, labels_path)
        labels = []

        try:
            with open(labels_full_path, "r") as f:
                reader = csv.reader(f)
                for row in reader:
                    class_id, class_name = row
                    labels.append((class_id, class_name))
        except OSError as e:
            raise RuntimeError(f'can not read labels file "{labels_full_path}"') from e

        return labels

    def _dump_entries(self) -> None:
        split = self.split
        if split == ImageNet.Split.TEST:
            dataset = None
            sample_count = split.length
            max_class_id_length, max_class_name_length = 0, 0
        else:
            labels_path = "labels.txt"
            logger.info(f'loading labels from "{labels_path}"')
            labels = self._load_labels(labels_path)

            # NOTE: Using torchvision ImageFolder for consistency
            from torchvision.datasets import ImageFolder

            dataset_root = os.path.join(self.root, split.get_dirname())
            dataset = ImageFolder(dataset_root)
            sample_count = len(dataset)
            max_class_id_length, max_class_name_length = -1, -1
            for sample in dataset.samples:
                _, class_index = sample
                class_id, class_name = labels[class_index]
                max_class_id_length = max(len(class_id), max_class_id_length)
                max_class_name_length = max(len(class_name), max_class_name_length)

        dtype = np.dtype(
            [
                ("actual_index", "<u4"),
                ("class_index", "<u4"),
                ("class_id", f"U{max_class_id_length}"),
                ("class_name", f"U{max_class_name_length}"),
            ]
        )
        entries_array = np.empty(sample_count, dtype=dtype)

        if split == ImageNet.Split.TEST:
            old_percent = -1
            for index in range(sample_count):
                percent = 100 * (index + 1) // sample_count
                if percent > old_percent:
                    logger.info(f"creating entries: {percent}%")
                    old_percent = percent

                actual_index = index + 1
                class_index = np.uint32(-1)
                class_id, class_name = "", ""
                entries_array[index] = (actual_index, class_index, class_id, class_name)
        else:
            class_names = {class_id: class_name for class_id, class_name in labels}

            assert dataset
            old_percent = -1
            for index in range(sample_count):
                percent = 100 * (index + 1) // sample_count
                if percent > old_percent:
                    logger.info(f"creating entries: {percent}%")
                    old_percent = percent

                image_full_path, class_index = dataset.samples[index]
                image_relpath = os.path.relpath(image_full_path, self.root)
                class_id, actual_index = split.parse_image_relpath(image_relpath)
                class_name = class_names[class_id]
                entries_array[index] = (actual_index, class_index, class_id, class_name)

        logger.info(f'saving entries to "{self._entries_path}"')
        self._save_extra(entries_array, self._entries_path)

    def _dump_class_ids_and_names(self) -> None:
        split = self.split
        if split == ImageNet.Split.TEST:
            return

        entries_array = self._load_extra(self._entries_path)

        max_class_id_length, max_class_name_length, max_class_index = -1, -1, -1
|
||||||
|
for entry in entries_array:
|
||||||
|
class_index, class_id, class_name = (
|
||||||
|
entry["class_index"],
|
||||||
|
entry["class_id"],
|
||||||
|
entry["class_name"],
|
||||||
|
)
|
||||||
|
max_class_index = max(int(class_index), max_class_index)
|
||||||
|
max_class_id_length = max(len(str(class_id)), max_class_id_length)
|
||||||
|
max_class_name_length = max(len(str(class_name)), max_class_name_length)
|
||||||
|
|
||||||
|
class_count = max_class_index + 1
|
||||||
|
class_ids_array = np.empty(class_count, dtype=f"U{max_class_id_length}")
|
||||||
|
class_names_array = np.empty(class_count, dtype=f"U{max_class_name_length}")
|
||||||
|
for entry in entries_array:
|
||||||
|
class_index, class_id, class_name = (
|
||||||
|
entry["class_index"],
|
||||||
|
entry["class_id"],
|
||||||
|
entry["class_name"],
|
||||||
|
)
|
||||||
|
class_ids_array[class_index] = class_id
|
||||||
|
class_names_array[class_index] = class_name
|
||||||
|
|
||||||
|
logger.info(f'saving class IDs to "{self._class_ids_path}"')
|
||||||
|
self._save_extra(class_ids_array, self._class_ids_path)
|
||||||
|
|
||||||
|
logger.info(f'saving class names to "{self._class_names_path}"')
|
||||||
|
self._save_extra(class_names_array, self._class_names_path)
|
||||||
|
|
||||||
|
def dump_extra(self) -> None:
|
||||||
|
self._dump_entries()
|
||||||
|
self._dump_class_ids_and_names()
|
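
# Editorial sketch (not part of the original commit): the extra metadata
# arrays above must be dumped once per split before the dataset is usable,
# mirroring the upstream DINOv2 README recipe (paths are placeholders):
#
#     root, extra = "<ROOT>", "<EXTRA>"
#     for split in ImageNet.Split:
#         dataset = ImageNet(split=split, root=root, extra=extra)
#         dataset.dump_extra()  # writes entries-*.npy, class-ids-*.npy, class-names-*.npy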
@ -0,0 +1,302 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from enum import Enum
from functools import lru_cache
from gzip import GzipFile
from io import BytesIO
from mmap import ACCESS_READ, mmap
import os
from typing import Any, Callable, List, Optional, Set, Tuple
import warnings

import numpy as np

from .extended import ExtendedVisionDataset


_Labels = int

_DEFAULT_MMAP_CACHE_SIZE = 16  # Warning: This can exhaust file descriptors


@dataclass
class _ClassEntry:
    block_offset: int
    maybe_filename: Optional[str] = None


@dataclass
class _Entry:
    class_index: int  # noqa: E701
    start_offset: int
    end_offset: int
    filename: str


class _Split(Enum):
    TRAIN = "train"
    VAL = "val"

    @property
    def length(self) -> int:
        return {
            _Split.TRAIN: 11_797_647,
            _Split.VAL: 561_050,
        }[self]

    def entries_path(self):
        return f"imagenet21kp_{self.value}.txt"


def _get_tarball_path(class_id: str) -> str:
    return f"{class_id}.tar"


def _make_mmap_tarball(tarballs_root: str, mmap_cache_size: int):
    @lru_cache(maxsize=mmap_cache_size)
    def _mmap_tarball(class_id: str) -> mmap:
        tarball_path = _get_tarball_path(class_id)
        tarball_full_path = os.path.join(tarballs_root, tarball_path)
        with open(tarball_full_path) as f:
            return mmap(fileno=f.fileno(), length=0, access=ACCESS_READ)

    return _mmap_tarball


class ImageNet22k(ExtendedVisionDataset):
    _GZIPPED_INDICES: Set[int] = {
        841_545,
        1_304_131,
        2_437_921,
        2_672_079,
        2_795_676,
        2_969_786,
        6_902_965,
        6_903_550,
        6_903_628,
        7_432_557,
        7_432_589,
        7_813_809,
        8_329_633,
        10_296_990,
        10_417_652,
        10_492_265,
        10_598_078,
        10_782_398,
        10_902_612,
        11_203_736,
        11_342_890,
        11_397_596,
        11_589_762,
        11_705_103,
        12_936_875,
        13_289_782,
    }
    Labels = _Labels

    def __init__(
        self,
        *,
        root: str,
        extra: str,
        transforms: Optional[Callable] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        mmap_cache_size: int = _DEFAULT_MMAP_CACHE_SIZE,
    ) -> None:
        super().__init__(root, transforms, transform, target_transform)
        self._extra_root = extra

        entries_path = self._get_entries_path(root)
        self._entries = self._load_extra(entries_path)

        class_ids_path = self._get_class_ids_path(root)
        self._class_ids = self._load_extra(class_ids_path)

        self._gzipped_indices = ImageNet22k._GZIPPED_INDICES
        self._mmap_tarball = _make_mmap_tarball(self._tarballs_root, mmap_cache_size)

    def _get_entries_path(self, root: Optional[str] = None) -> str:
        return "entries.npy"

    def _get_class_ids_path(self, root: Optional[str] = None) -> str:
        return "class-ids.npy"

    def _find_class_ids(self, path: str) -> List[str]:
        class_ids = []

        with os.scandir(path) as entries:
            for entry in entries:
                root, ext = os.path.splitext(entry.name)
                if ext != ".tar":
                    continue
                class_ids.append(root)

        return sorted(class_ids)

    def _load_entries_class_ids(self, root: Optional[str] = None) -> Tuple[List[_Entry], List[str]]:
        root = self.get_root(root)
        entries: List[_Entry] = []
        class_ids = self._find_class_ids(root)

        for class_index, class_id in enumerate(class_ids):
            path = os.path.join(root, "blocks", f"{class_id}.log")
            class_entries = []

            try:
                with open(path) as f:
                    for line in f:
                        line = line.rstrip()
                        block, filename = line.split(":")
                        block_offset = int(block[6:])
                        filename = filename[1:]

                        maybe_filename = None
                        if filename != "** Block of NULs **":
                            maybe_filename = filename
                            _, ext = os.path.splitext(filename)
                            # assert ext == ".JPEG"

                        class_entry = _ClassEntry(block_offset, maybe_filename)
                        class_entries.append(class_entry)
            except OSError as e:
                raise RuntimeError(f'can not read blocks file "{path}"') from e

            assert class_entries[-1].maybe_filename is None

            for class_entry1, class_entry2 in zip(class_entries, class_entries[1:]):
                assert class_entry1.block_offset <= class_entry2.block_offset
                start_offset = 512 * class_entry1.block_offset
                end_offset = 512 * class_entry2.block_offset
                assert class_entry1.maybe_filename is not None
                filename = class_entry1.maybe_filename
                entry = _Entry(class_index, start_offset, end_offset, filename)
                # Skip invalid image files (PIL throws UnidentifiedImageError)
                if filename == "n06470073_47249.JPEG":
                    continue
                entries.append(entry)

        return entries, class_ids

    def _load_extra(self, extra_path: str) -> np.ndarray:
        extra_root = self._extra_root
        extra_full_path = os.path.join(extra_root, extra_path)
        return np.load(extra_full_path, mmap_mode="r")

    def _save_extra(self, extra_array: np.ndarray, extra_path: str) -> None:
        extra_root = self._extra_root
        extra_full_path = os.path.join(extra_root, extra_path)
        os.makedirs(extra_root, exist_ok=True)
        np.save(extra_full_path, extra_array)

    @property
    def _tarballs_root(self) -> str:
        return self.root

    def find_class_id(self, class_index: int) -> str:
        return str(self._class_ids[class_index])

    def get_image_data(self, index: int) -> bytes:
        entry = self._entries[index]
        class_id = entry["class_id"]
        class_mmap = self._mmap_tarball(class_id)

        start_offset, end_offset = entry["start_offset"], entry["end_offset"]
        try:
            mapped_data = class_mmap[start_offset:end_offset]
            data = mapped_data[512:]  # Skip entry header block

            if len(data) >= 2 and tuple(data[:2]) == (0x1F, 0x8B):
                assert index in self._gzipped_indices, f"unexpected gzip header for sample {index}"
                with GzipFile(fileobj=BytesIO(data)) as g:
                    data = g.read()
        except Exception as e:
            raise RuntimeError(f"can not retrieve image data for sample {index} " f'from "{class_id}" tarball') from e

        return data

    def get_target(self, index: int) -> Any:
        return int(self._entries[index]["class_index"])

    def get_targets(self) -> np.ndarray:
        return self._entries["class_index"]

    def get_class_id(self, index: int) -> str:
        return str(self._entries[index]["class_id"])

    def get_class_ids(self) -> np.ndarray:
        return self._entries["class_id"]

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            return super().__getitem__(index)

    def __len__(self) -> int:
        return len(self._entries)

    def _dump_entries(self, *args, **kwargs) -> None:
        entries, class_ids = self._load_entries_class_ids(*args, **kwargs)

        max_class_id_length, max_filename_length, max_class_index = -1, -1, -1
        for entry in entries:
            class_id = class_ids[entry.class_index]
            max_class_index = max(entry.class_index, max_class_index)
            max_class_id_length = max(len(class_id), max_class_id_length)
            max_filename_length = max(len(entry.filename), max_filename_length)

        dtype = np.dtype(
            [
                ("class_index", "<u4"),
                ("class_id", f"U{max_class_id_length}"),
                ("start_offset", "<u4"),
                ("end_offset", "<u4"),
                ("filename", f"U{max_filename_length}"),
            ]
        )
        sample_count = len(entries)
        entries_array = np.empty(sample_count, dtype=dtype)
        for i, entry in enumerate(entries):
            class_index = entry.class_index
            class_id = class_ids[class_index]
            start_offset = entry.start_offset
            end_offset = entry.end_offset
            filename = entry.filename
            entries_array[i] = (
                class_index,
                class_id,
                start_offset,
                end_offset,
                filename,
            )

        entries_path = self._get_entries_path(*args, **kwargs)
        self._save_extra(entries_array, entries_path)

    def _dump_class_ids(self, *args, **kwargs) -> None:
        entries_path = self._get_entries_path(*args, **kwargs)
        entries_array = self._load_extra(entries_path)

        max_class_id_length, max_class_index = -1, -1
        for entry in entries_array:
            class_index, class_id = entry["class_index"], entry["class_id"]
            max_class_index = max(int(class_index), max_class_index)
            max_class_id_length = max(len(str(class_id)), max_class_id_length)

        class_ids_array = np.empty(max_class_index + 1, dtype=f"U{max_class_id_length}")
        for entry in entries_array:
            class_index, class_id = entry["class_index"], entry["class_id"]
            class_ids_array[class_index] = class_id
        class_ids_path = self._get_class_ids_path(*args, **kwargs)
        self._save_extra(class_ids_array, class_ids_path)

    def _dump_extra(self, *args, **kwargs) -> None:
        self._dump_entries(*args, **kwargs)
        self._dump_class_ids(*args, **kwargs)

    def dump_extra(self, root: Optional[str] = None) -> None:
        return self._dump_extra(root)
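
# Editorial note (not part of the original commit): each class is stored as a
# plain tar archive, and a companion "blocks/<class_id>.log" file records the
# offset (in 512-byte tar blocks) of every member. get_image_data() slices the
# mmap'ed tarball at 512 * block_offset and drops the first 512 bytes (the tar
# member header) to recover the raw image bytes, e.g.:
#
#     dataset = ImageNet22k(root="<ROOT>", extra="<EXTRA>")
#     jpeg_bytes = dataset.get_image_data(0)  # bytes of the first sample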
@ -0,0 +1,222 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import logging
from enum import Enum
from typing import Any, Callable, List, Optional, TypeVar

import torch
from torch.utils.data import Sampler

from .datasets import ImageNet, ImageNet22k
from .samplers import EpochSampler, InfiniteSampler, ShardedInfiniteSampler


logger = logging.getLogger("dinov2")


class SamplerType(Enum):
    DISTRIBUTED = 0
    EPOCH = 1
    INFINITE = 2
    SHARDED_INFINITE = 3
    SHARDED_INFINITE_NEW = 4


def _make_bool_str(b: bool) -> str:
    return "yes" if b else "no"


def _make_sample_transform(image_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None):
    def transform(sample):
        image, target = sample
        if image_transform is not None:
            image = image_transform(image)
        if target_transform is not None:
            target = target_transform(target)
        return image, target

    return transform


def _parse_dataset_str(dataset_str: str):
    tokens = dataset_str.split(":")

    name = tokens[0]
    kwargs = {}

    for token in tokens[1:]:
        key, value = token.split("=")
        assert key in ("root", "extra", "split")
        kwargs[key] = value

    if name == "ImageNet":
        class_ = ImageNet
        if "split" in kwargs:
            kwargs["split"] = ImageNet.Split[kwargs["split"]]
    elif name == "ImageNet22k":
        class_ = ImageNet22k
    else:
        raise ValueError(f'Unsupported dataset "{name}"')

    return class_, kwargs


def make_dataset(
    *,
    dataset_str: str,
    transform: Optional[Callable] = None,
    target_transform: Optional[Callable] = None,
):
    """
    Creates a dataset with the specified parameters.

    Args:
        dataset_str: A dataset string description (e.g. ImageNet:split=TRAIN).
        transform: A transform to apply to images.
        target_transform: A transform to apply to targets.

    Returns:
        The created dataset.
    """
    logger.info(f'using dataset: "{dataset_str}"')

    class_, kwargs = _parse_dataset_str(dataset_str)
    dataset = class_(transform=transform, target_transform=target_transform, **kwargs)

    logger.info(f"# of dataset samples: {len(dataset):,d}")

    # Aggregated datasets do not expose (yet) these attributes, so add them.
    if not hasattr(dataset, "transform"):
        setattr(dataset, "transform", transform)
    if not hasattr(dataset, "target_transform"):
        setattr(dataset, "target_transform", target_transform)

    return dataset


def _make_sampler(
    *,
    dataset,
    type: Optional[SamplerType] = None,
    shuffle: bool = False,
    seed: int = 0,
    size: int = -1,
    advance: int = 0,
) -> Optional[Sampler]:
    sample_count = len(dataset)

    if type == SamplerType.INFINITE:
        logger.info("sampler: infinite")
        if size > 0:
            raise ValueError("sampler size > 0 is invalid")
        return InfiniteSampler(
            sample_count=sample_count,
            shuffle=shuffle,
            seed=seed,
            advance=advance,
        )
    elif type in (SamplerType.SHARDED_INFINITE, SamplerType.SHARDED_INFINITE_NEW):
        logger.info("sampler: sharded infinite")
        if size > 0:
            raise ValueError("sampler size > 0 is invalid")
        # TODO: Remove support for old shuffling
        use_new_shuffle_tensor_slice = type == SamplerType.SHARDED_INFINITE_NEW
        return ShardedInfiniteSampler(
            sample_count=sample_count,
            shuffle=shuffle,
            seed=seed,
            advance=advance,
            use_new_shuffle_tensor_slice=use_new_shuffle_tensor_slice,
        )
    elif type == SamplerType.EPOCH:
        logger.info("sampler: epoch")
        if advance > 0:
            raise NotImplementedError("sampler advance > 0 is not supported")
        size = size if size > 0 else sample_count
        logger.info(f"# of samples / epoch: {size:,d}")
        return EpochSampler(
            size=size,
            sample_count=sample_count,
            shuffle=shuffle,
            seed=seed,
        )
    elif type == SamplerType.DISTRIBUTED:
        logger.info("sampler: distributed")
        if size > 0:
            raise ValueError("sampler size > 0 is invalid")
        if advance > 0:
            raise ValueError("sampler advance > 0 is invalid")
        return torch.utils.data.DistributedSampler(
            dataset=dataset,
            shuffle=shuffle,
            seed=seed,
            drop_last=False,
        )

    logger.info("sampler: none")
    return None


T = TypeVar("T")


def make_data_loader(
    *,
    dataset,
    batch_size: int,
    num_workers: int,
    shuffle: bool = True,
    seed: int = 0,
    sampler_type: Optional[SamplerType] = SamplerType.INFINITE,
    sampler_size: int = -1,
    sampler_advance: int = 0,
    drop_last: bool = True,
    persistent_workers: bool = False,
    collate_fn: Optional[Callable[[List[T]], Any]] = None,
):
    """
    Creates a data loader with the specified parameters.

    Args:
        dataset: A dataset (third party, LaViDa or WebDataset).
        batch_size: The size of batches to generate.
        num_workers: The number of workers to use.
        shuffle: Whether to shuffle samples.
        seed: The random seed to use.
        sampler_type: Which sampler to use: EPOCH, INFINITE, SHARDED_INFINITE, SHARDED_INFINITE_NEW, DISTRIBUTED or None.
        sampler_size: The number of images per epoch (when applicable) or -1 for the entire dataset.
        sampler_advance: How many samples to skip (when applicable).
        drop_last: Whether the last non-full batch of data should be dropped.
        persistent_workers: If True, keep the worker Dataset instances alive after the dataset has been consumed once.
        collate_fn: Function that performs batch collation.
    """

    sampler = _make_sampler(
        dataset=dataset,
        type=sampler_type,
        shuffle=shuffle,
        seed=seed,
        size=sampler_size,
        advance=sampler_advance,
    )

    logger.info("using PyTorch data loader")
    data_loader = torch.utils.data.DataLoader(
        dataset,
        sampler=sampler,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=drop_last,
        persistent_workers=persistent_workers,
        collate_fn=collate_fn,
    )

    try:
        logger.info(f"# of batches: {len(data_loader):,d}")
    except TypeError:  # data loader has no length
        logger.info("infinite data loader")
    return data_loader
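

if __name__ == "__main__":
    # Editorial sketch (not part of the original commit): exercise the loader
    # plumbing on a toy list dataset; the EPOCH sampler makes the loader
    # finite. Building a real dataset would instead use e.g.
    #   make_dataset(dataset_str="ImageNet:split=TRAIN:root=<ROOT>:extra=<EXTRA>")
    data_loader = make_data_loader(
        dataset=list(range(100)),
        batch_size=8,
        num_workers=0,
        shuffle=True,
        sampler_type=SamplerType.EPOCH,
    )
    for batch in data_loader:
        print(batch)  # a tensor of 8 shuffled indices
        break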
@ -0,0 +1,86 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import random
import math
import numpy as np


class MaskingGenerator:
    def __init__(
        self,
        input_size,
        num_masking_patches=None,
        min_num_patches=4,
        max_num_patches=None,
        min_aspect=0.3,
        max_aspect=None,
    ):
        if not isinstance(input_size, tuple):
            input_size = (input_size,) * 2
        self.height, self.width = input_size

        self.num_patches = self.height * self.width
        self.num_masking_patches = num_masking_patches

        self.min_num_patches = min_num_patches
        self.max_num_patches = num_masking_patches if max_num_patches is None else max_num_patches

        max_aspect = max_aspect or 1 / min_aspect
        self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))

    def __repr__(self):
        repr_str = "Generator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % (
            self.height,
            self.width,
            self.min_num_patches,
            self.max_num_patches,
            self.num_masking_patches,
            self.log_aspect_ratio[0],
            self.log_aspect_ratio[1],
        )
        return repr_str

    def get_shape(self):
        return self.height, self.width

    def _mask(self, mask, max_mask_patches):
        delta = 0
        for _ in range(10):
            target_area = random.uniform(self.min_num_patches, max_mask_patches)
            aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
            h = int(round(math.sqrt(target_area * aspect_ratio)))
            w = int(round(math.sqrt(target_area / aspect_ratio)))
            if w < self.width and h < self.height:
                top = random.randint(0, self.height - h)
                left = random.randint(0, self.width - w)

                num_masked = mask[top : top + h, left : left + w].sum()
                # Overlap
                if 0 < h * w - num_masked <= max_mask_patches:
                    for i in range(top, top + h):
                        for j in range(left, left + w):
                            if mask[i, j] == 0:
                                mask[i, j] = 1
                                delta += 1

            if delta > 0:
                break
        return delta

    def __call__(self, num_masking_patches=0):
        mask = np.zeros(shape=self.get_shape(), dtype=bool)
        mask_count = 0
        while mask_count < num_masking_patches:
            max_mask_patches = num_masking_patches - mask_count
            max_mask_patches = min(max_mask_patches, self.max_num_patches)

            delta = self._mask(mask, max_mask_patches)
            if delta == 0:
                break
            else:
                mask_count += delta

        return mask
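

if __name__ == "__main__":
    # Editorial sketch (not part of the original commit): mask roughly half of
    # a 14x14 patch grid (e.g. a 224px image with 16px patches).
    generator = MaskingGenerator(input_size=14, num_masking_patches=98)
    mask = generator(num_masking_patches=98)
    print(generator)
    print(f"masked {int(mask.sum())} / {mask.size} patches")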
@ -0,0 +1,229 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import itertools
from typing import Any, Optional
import warnings

import numpy as np
import torch
from torch.utils.data.sampler import Sampler

import dinov2.distributed as distributed


class EpochSampler(Sampler):
    def __init__(
        self,
        *,
        size: int,
        sample_count: int,
        shuffle: bool = False,
        seed: int = 0,
        start: Optional[int] = None,
        step: Optional[int] = None,
    ):
        self._size = size
        self._sample_count = sample_count
        self._shuffle = shuffle
        self._seed = seed
        self._start = distributed.get_global_rank() if start is None else start
        self._step = distributed.get_global_size() if step is None else step
        self._epoch = 0

    def __iter__(self):
        count = (self._size + self._sample_count - 1) // self._sample_count
        tiled_indices = np.tile(np.arange(self._sample_count), count)
        if self._shuffle:
            seed = self._seed * self._epoch if self._seed != 0 else self._epoch
            rng = np.random.default_rng(seed)
            iterable = rng.choice(tiled_indices, self._size, replace=False)
        else:
            iterable = tiled_indices[: self._size]

        yield from itertools.islice(iterable, self._start, None, self._step)

    def __len__(self):
        return (self._size - self._start + self._step - 1) // self._step

    def set_epoch(self, epoch):
        self._epoch = epoch


def _get_numpy_dtype(size: int) -> Any:
    return np.int32 if size <= 2**31 else np.int64


def _get_torch_dtype(size: int) -> Any:
    return torch.int32 if size <= 2**31 else torch.int64


def _generate_randperm_indices(*, size: int, generator: torch.Generator):
    """Generate the indices of a random permutation."""
    dtype = _get_torch_dtype(size)
    # This is actually matching PyTorch's CPU implementation, see: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorFactories.cpp#L900-L921
    perm = torch.arange(size, dtype=dtype)
    for i in range(size):
        j = torch.randint(i, size, size=(1,), generator=generator).item()

        # Always swap even if no-op
        value = perm[j].item()
        perm[j] = perm[i].item()
        perm[i] = value
        yield value


class InfiniteSampler(Sampler):
    def __init__(
        self,
        *,
        sample_count: int,
        shuffle: bool = False,
        seed: int = 0,
        start: Optional[int] = None,
        step: Optional[int] = None,
        advance: int = 0,
    ):
        self._sample_count = sample_count
        self._seed = seed
        self._shuffle = shuffle
        self._start = distributed.get_global_rank() if start is None else start
        self._step = distributed.get_global_size() if step is None else step
        self._advance = advance

    def __iter__(self):
        if self._shuffle:
            iterator = self._shuffled_iterator()
        else:
            iterator = self._iterator()

        yield from itertools.islice(iterator, self._advance, None)

    def _iterator(self):
        assert not self._shuffle

        while True:
            iterable = range(self._sample_count)
            yield from itertools.islice(iterable, self._start, None, self._step)

    def _shuffled_iterator(self):
        assert self._shuffle

        # Instantiate a generator here (rather than in the ctor) to keep the class
        # picklable (requirement of mp.spawn)
        generator = torch.Generator().manual_seed(self._seed)

        while True:
            iterable = _generate_randperm_indices(size=self._sample_count, generator=generator)
            yield from itertools.islice(iterable, self._start, None, self._step)


# The following function is somewhat equivalent to _new_shuffle_tensor_slice below,
# but avoids a full in-place random permutation generation.
def _shuffle_tensor_slice(
    *, tensor: torch.Tensor, start: int = 0, step: int = 1, generator: torch.Generator
) -> np.ndarray:
    stop = len(tensor)
    count = stop // step
    drop_count = stop - step * count
    if drop_count:
        warnings.warn(f"# of dropped samples: {drop_count}")

    dtype = _get_numpy_dtype(stop)
    result = np.empty(count, dtype=dtype)

    for i in range(count):
        j = torch.randint(0, i + 1, size=(1,), generator=generator).item() if i > 0 else 0

        result[i] = result[j]
        result[j] = tensor[start + i * step].item()

    return result


def _new_shuffle_tensor_slice(
    *, tensor: torch.Tensor, start: int = 0, step: int = 1, generator: torch.Generator
) -> np.ndarray:
    stop = len(tensor)
    count = stop // step
    dtype = torch.int64  # Needed for using randperm result as indices
    drop_count = stop - step * count
    if drop_count:
        warnings.warn(f"# of dropped samples: {drop_count}")
    indices = torch.randperm(count, dtype=dtype, generator=generator)
    return tensor[start::step][indices].numpy()


def _make_seed(seed: int, start: int, iter_count: int) -> int:
    # NOTE: Tried a few variants (including iter_count << 32), this one worked best.
    return seed + start + (iter_count << 24)


class ShardedInfiniteSampler(Sampler):
    def __init__(
        self,
        *,
        sample_count: int,
        shuffle: bool = False,
        seed: int = 0,
        start: Optional[int] = None,
        step: Optional[int] = None,
        advance: int = 0,
        use_new_shuffle_tensor_slice: bool = False,
    ):
        self._sample_count = sample_count
        self._seed = seed
        self._shuffle = shuffle
        self._start = distributed.get_global_rank() if start is None else start
        self._step = distributed.get_global_size() if step is None else step
        self._advance = advance
        self._iter_count = 0
        self._shuffle_tensor_slice_fn = (
            _new_shuffle_tensor_slice if use_new_shuffle_tensor_slice else _shuffle_tensor_slice
        )

    def __iter__(self):
        iter_count = self._advance // self._sample_count
        if iter_count > 0:
            self._advance -= iter_count * self._sample_count
            self._iter_count += iter_count

        if self._shuffle:
            iterator = self._shuffled_iterator()
        else:
            iterator = self._iterator()

        yield from itertools.islice(iterator, self._advance, None)

    def _iterator(self):
        assert not self._shuffle

        while True:
            iterable = range(self._sample_count)
            yield from itertools.islice(iterable, self._start, None, self._step)

    def _shuffled_iterator(self):
        assert self._shuffle

        # Instantiate a generator here (rather than in the ctor) to keep the class
        # picklable (requirement of mp.spawn)
        generator = torch.Generator()

        # Always shuffle everything first
        generator.manual_seed(self._seed)
        dtype = _get_torch_dtype(self._sample_count)
        perm = torch.randperm(self._sample_count, dtype=dtype, generator=generator)

        while True:
            # Re-seed on each iteration to allow skipping whole permutations
            seed = _make_seed(self._seed, self._start, self._iter_count)
            generator.manual_seed(seed)

            iterable = self._shuffle_tensor_slice_fn(
                tensor=perm, start=self._start, step=self._step, generator=generator
            )
            yield from iterable
            self._iter_count += 1
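

if __name__ == "__main__":
    # Editorial sketch (not part of the original commit): without distributed
    # initialization the rank/world-size fall back to 0/1, so this runs as a
    # single shard. The sampler is infinite; islice takes a finite prefix.
    sampler = ShardedInfiniteSampler(sample_count=10, shuffle=True, seed=42)
    print(list(itertools.islice(iter(sampler), 20)))  # two full permutations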
@ -0,0 +1,91 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from typing import Sequence

import torch
from torchvision import transforms


class GaussianBlur(transforms.RandomApply):
    """
    Apply Gaussian Blur to the PIL image.
    """

    def __init__(self, *, p: float = 0.5, radius_min: float = 0.1, radius_max: float = 2.0):
        # NOTE: torchvision is applying 1 - probability to return the original image
        keep_p = 1 - p
        transform = transforms.GaussianBlur(kernel_size=9, sigma=(radius_min, radius_max))
        super().__init__(transforms=[transform], p=keep_p)


class MaybeToTensor(transforms.ToTensor):
    """
    Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor, or keep as is if already a tensor.
    """

    def __call__(self, pic):
        """
        Args:
            pic (PIL Image, numpy.ndarray or torch.tensor): Image to be converted to tensor.
        Returns:
            Tensor: Converted image.
        """
        if isinstance(pic, torch.Tensor):
            return pic
        return super().__call__(pic)


# Use timm's names
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)


def make_normalize_transform(
    mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
    std: Sequence[float] = IMAGENET_DEFAULT_STD,
) -> transforms.Normalize:
    return transforms.Normalize(mean=mean, std=std)


# This roughly matches torchvision's preset for classification training:
#   https://github.com/pytorch/vision/blob/main/references/classification/presets.py#L6-L44
def make_classification_train_transform(
    *,
    crop_size: int = 224,
    interpolation=transforms.InterpolationMode.BICUBIC,
    hflip_prob: float = 0.5,
    mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
    std: Sequence[float] = IMAGENET_DEFAULT_STD,
):
    transforms_list = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)]
    if hflip_prob > 0.0:
        transforms_list.append(transforms.RandomHorizontalFlip(hflip_prob))
    transforms_list.extend(
        [
            MaybeToTensor(),
            make_normalize_transform(mean=mean, std=std),
        ]
    )
    return transforms.Compose(transforms_list)


# This matches (roughly) torchvision's preset for classification evaluation:
#   https://github.com/pytorch/vision/blob/main/references/classification/presets.py#L47-L69
def make_classification_eval_transform(
    *,
    resize_size: int = 256,
    interpolation=transforms.InterpolationMode.BICUBIC,
    crop_size: int = 224,
    mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
    std: Sequence[float] = IMAGENET_DEFAULT_STD,
) -> transforms.Compose:
    transforms_list = [
        transforms.Resize(resize_size, interpolation=interpolation),
        transforms.CenterCrop(crop_size),
        MaybeToTensor(),
        make_normalize_transform(mean=mean, std=std),
    ]
    return transforms.Compose(transforms_list)
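

if __name__ == "__main__":
    # Editorial sketch (not part of the original commit): run the evaluation
    # preset on a synthetic PIL image.
    import numpy as np
    from PIL import Image

    image = Image.fromarray(np.zeros((256, 256, 3), dtype=np.uint8))
    print(make_classification_eval_transform()(image).shape)  # torch.Size([3, 224, 224])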
@ -0,0 +1,270 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import os
import random
import re
import socket
from typing import Dict, List

import torch
import torch.distributed as dist

_LOCAL_RANK = -1
_LOCAL_WORLD_SIZE = -1


def is_enabled() -> bool:
    """
    Returns:
        True if distributed training is enabled
    """
    return dist.is_available() and dist.is_initialized()


def get_global_size() -> int:
    """
    Returns:
        The number of processes in the process group
    """
    return dist.get_world_size() if is_enabled() else 1


def get_global_rank() -> int:
    """
    Returns:
        The rank of the current process within the global process group.
    """
    return dist.get_rank() if is_enabled() else 0


def get_local_rank() -> int:
    """
    Returns:
        The rank of the current process within the local (per-machine) process group.
    """
    if not is_enabled():
        return 0
    assert 0 <= _LOCAL_RANK < _LOCAL_WORLD_SIZE
    return _LOCAL_RANK


def get_local_size() -> int:
    """
    Returns:
        The size of the per-machine process group,
        i.e. the number of processes per machine.
    """
    if not is_enabled():
        return 1
    assert 0 <= _LOCAL_RANK < _LOCAL_WORLD_SIZE
    return _LOCAL_WORLD_SIZE


def is_main_process() -> bool:
    """
    Returns:
        True if the current process is the main one.
    """
    return get_global_rank() == 0


def _restrict_print_to_main_process() -> None:
    """
    This function disables printing when not in the main process
    """
    import builtins as __builtin__

    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop("force", False)
        if is_main_process() or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print


def _get_master_port(seed: int = 0) -> int:
    MIN_MASTER_PORT, MAX_MASTER_PORT = (20_000, 60_000)

    master_port_str = os.environ.get("MASTER_PORT")
    if master_port_str is None:
        rng = random.Random(seed)
        return rng.randint(MIN_MASTER_PORT, MAX_MASTER_PORT)

    return int(master_port_str)


def _get_available_port() -> int:
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        # A "" host address means INADDR_ANY i.e. binding to all interfaces.
        # Note this is not compatible with IPv6.
        s.bind(("", 0))
        port = s.getsockname()[1]
        return port


_TORCH_DISTRIBUTED_ENV_VARS = (
    "MASTER_ADDR",
    "MASTER_PORT",
    "RANK",
    "WORLD_SIZE",
    "LOCAL_RANK",
    "LOCAL_WORLD_SIZE",
)


def _collect_env_vars() -> Dict[str, str]:
    return {env_var: os.environ[env_var] for env_var in _TORCH_DISTRIBUTED_ENV_VARS if env_var in os.environ}


def _is_slurm_job_process() -> bool:
    return "SLURM_JOB_ID" in os.environ


def _parse_slurm_node_list(s: str) -> List[str]:
    nodes = []
    # Extract "hostname", "hostname[1-2,3,4-5]," substrings
    p = re.compile(r"(([^\[]+)(?:\[([^\]]+)\])?),?")
    for m in p.finditer(s):
        prefix, suffixes = s[m.start(2) : m.end(2)], s[m.start(3) : m.end(3)]
        for suffix in suffixes.split(","):
            span = suffix.split("-")
            if len(span) == 1:
                nodes.append(prefix + suffix)
            else:
                width = len(span[0])
                start, end = int(span[0]), int(span[1]) + 1
                nodes.extend([prefix + f"{i:0{width}}" for i in range(start, end)])
    return nodes


def _check_env_variable(key: str, new_value: str):
    # Only check for difference with preset environment variables
    if key in os.environ and os.environ[key] != new_value:
        raise RuntimeError(f"Cannot export environment variables as {key} is already set")


class _TorchDistributedEnvironment:
    def __init__(self):
        self.master_addr = "127.0.0.1"
        self.master_port = 0
        self.rank = -1
        self.world_size = -1
        self.local_rank = -1
        self.local_world_size = -1

        if _is_slurm_job_process():
            return self._set_from_slurm_env()

        env_vars = _collect_env_vars()
        if not env_vars:
            # Environment is not set
            pass
        elif len(env_vars) == len(_TORCH_DISTRIBUTED_ENV_VARS):
            # Environment is fully set
            return self._set_from_preset_env()
        else:
            # Environment is partially set
            collected_env_vars = ", ".join(env_vars.keys())
            raise RuntimeError(f"Partially set environment: {collected_env_vars}")

        if torch.cuda.device_count() > 0:
            return self._set_from_local()

        raise RuntimeError("Can't initialize PyTorch distributed environment")

    # Slurm job created with sbatch, submitit, etc...
    def _set_from_slurm_env(self):
        # logger.info("Initialization from Slurm environment")
        job_id = int(os.environ["SLURM_JOB_ID"])
        node_count = int(os.environ["SLURM_JOB_NUM_NODES"])
        nodes = _parse_slurm_node_list(os.environ["SLURM_JOB_NODELIST"])
        assert len(nodes) == node_count

        self.master_addr = nodes[0]
        self.master_port = _get_master_port(seed=job_id)
        self.rank = int(os.environ["SLURM_PROCID"])
        self.world_size = int(os.environ["SLURM_NTASKS"])
        assert self.rank < self.world_size
        self.local_rank = int(os.environ["SLURM_LOCALID"])
        self.local_world_size = self.world_size // node_count
        assert self.local_rank < self.local_world_size

    # Single node job with preset environment (i.e. torchrun)
    def _set_from_preset_env(self):
        # logger.info("Initialization from preset environment")
        self.master_addr = os.environ["MASTER_ADDR"]
        self.master_port = os.environ["MASTER_PORT"]
        self.rank = int(os.environ["RANK"])
        self.world_size = int(os.environ["WORLD_SIZE"])
        assert self.rank < self.world_size
        self.local_rank = int(os.environ["LOCAL_RANK"])
        self.local_world_size = int(os.environ["LOCAL_WORLD_SIZE"])
        assert self.local_rank < self.local_world_size

    # Single node and GPU job (i.e. local script run)
    def _set_from_local(self):
        # logger.info("Initialization from local")
        self.master_addr = "127.0.0.1"
        self.master_port = _get_available_port()
        self.rank = 0
        self.world_size = 1
        self.local_rank = 0
        self.local_world_size = 1

    def export(self, *, overwrite: bool) -> "_TorchDistributedEnvironment":
        # See the "Environment variable initialization" section from
        # https://pytorch.org/docs/stable/distributed.html for the complete list of
        # environment variables required for the env:// initialization method.
        env_vars = {
            "MASTER_ADDR": self.master_addr,
            "MASTER_PORT": str(self.master_port),
            "RANK": str(self.rank),
            "WORLD_SIZE": str(self.world_size),
            "LOCAL_RANK": str(self.local_rank),
            "LOCAL_WORLD_SIZE": str(self.local_world_size),
        }
        if not overwrite:
            for k, v in env_vars.items():
                _check_env_variable(k, v)

        os.environ.update(env_vars)
        return self


def enable(*, set_cuda_current_device: bool = True, overwrite: bool = False, allow_nccl_timeout: bool = False):
    """Enable distributed mode

    Args:
        set_cuda_current_device: If True, call torch.cuda.set_device() to set the
            current PyTorch CUDA device to the one matching the local rank.
        overwrite: If True, overwrites already set variables. Else fails.
        allow_nccl_timeout: If True, set NCCL_ASYNC_ERROR_HANDLING=1 so the torch
            distributed timeout also applies to the NCCL backend.
    """

    global _LOCAL_RANK, _LOCAL_WORLD_SIZE
    if _LOCAL_RANK >= 0 or _LOCAL_WORLD_SIZE >= 0:
        raise RuntimeError("Distributed mode has already been enabled")
    torch_env = _TorchDistributedEnvironment()
    torch_env.export(overwrite=overwrite)

    if set_cuda_current_device:
        torch.cuda.set_device(torch_env.local_rank)

    if allow_nccl_timeout:
        # This allows to use torch distributed timeout in a NCCL backend
        key, value = "NCCL_ASYNC_ERROR_HANDLING", "1"
        if not overwrite:
            _check_env_variable(key, value)
        os.environ[key] = value

    dist.init_process_group(backend="nccl")
    dist.barrier()

    # Finalize setup
    _LOCAL_RANK = torch_env.local_rank
    _LOCAL_WORLD_SIZE = torch_env.local_world_size
    _restrict_print_to_main_process()
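

# Editorial sketch (not part of the original commit): typical call site at the
# start of a training script, run under torchrun or a Slurm step. It requires
# CUDA and NCCL, so it is left commented; the import path is assumed to follow
# the upstream "dinov2.distributed" layout:
#
#     import dinov2.distributed as distributed
#
#     distributed.enable(overwrite=True)
#     if distributed.is_main_process():
#         print(f"world size: {distributed.get_global_size()}")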
@ -0,0 +1,4 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
@ -0,0 +1,4 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
@ -0,0 +1,10 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from .backbones import *  # noqa: F403
from .builder import BACKBONES, DEPTHER, HEADS, LOSSES, build_backbone, build_depther, build_head, build_loss
from .decode_heads import *  # noqa: F403
from .depther import *  # noqa: F403
from .losses import *  # noqa: F403
@ -0,0 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from .vision_transformer import DinoVisionTransformer
@ -0,0 +1,16 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from mmcv.runner import BaseModule

from ..builder import BACKBONES


@BACKBONES.register_module()
class DinoVisionTransformer(BaseModule):
    """Vision Transformer."""

    def __init__(self, *args, **kwargs):
        super().__init__()
@ -0,0 +1,49 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import warnings

from mmcv.cnn import MODELS as MMCV_MODELS
from mmcv.cnn.bricks.registry import ATTENTION as MMCV_ATTENTION
from mmcv.utils import Registry

MODELS = Registry("models", parent=MMCV_MODELS)
ATTENTION = Registry("attention", parent=MMCV_ATTENTION)


BACKBONES = MODELS
NECKS = MODELS
HEADS = MODELS
LOSSES = MODELS
DEPTHER = MODELS


def build_backbone(cfg):
    """Build backbone."""
    return BACKBONES.build(cfg)


def build_neck(cfg):
    """Build neck."""
    return NECKS.build(cfg)


def build_head(cfg):
    """Build head."""
    return HEADS.build(cfg)


def build_loss(cfg):
    """Build loss."""
    return LOSSES.build(cfg)


def build_depther(cfg, train_cfg=None, test_cfg=None):
    """Build depther."""
    if train_cfg is not None or test_cfg is not None:
        warnings.warn("train_cfg and test_cfg is deprecated, " "please specify them in model", UserWarning)
    assert cfg.get("train_cfg") is None or train_cfg is None, "train_cfg specified in both outer field and model field "
    assert cfg.get("test_cfg") is None or test_cfg is None, "test_cfg specified in both outer field and model field "
    return DEPTHER.build(cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg))
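

# Editorial sketch (not part of the original commit): with mmcv's Registry,
# the modules registered above are built from config dicts; e.g., once a
# module importing the backbone has run its @BACKBONES.register_module()
# decorator:
#
#     cfg = dict(type="DinoVisionTransformer")
#     backbone = build_backbone(cfg)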
@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from .dpt_head import DPTHead
from .linear_head import BNHead
@@ -0,0 +1,225 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import copy
from abc import ABCMeta, abstractmethod

import mmcv
import numpy as np
import torch
import torch.nn as nn
from mmcv.runner import BaseModule, auto_fp16, force_fp32

from ...ops import resize
from ..builder import build_loss


class DepthBaseDecodeHead(BaseModule, metaclass=ABCMeta):
    """Base class for BaseDecodeHead.

    Args:
        in_channels (List): Input channels.
        channels (int): Channels after modules, before conv_depth.
        conv_cfg (dict|None): Config of conv layers. Default: None.
        act_cfg (dict): Config of activation layers.
            Default: dict(type='ReLU')
        loss_decode (dict): Config of decode loss.
            Default: dict(type='SigLoss').
        sampler (dict|None): The config of depth map sampler.
            Default: None.
        align_corners (bool): align_corners argument of F.interpolate.
            Default: False.
        min_depth (float): Min depth in dataset setting.
            Default: 1e-3.
        max_depth (float|None): Max depth in dataset setting.
            Default: None.
        norm_cfg (dict|None): Config of norm layers.
            Default: None.
        classify (bool): Whether to predict depth in a cls.-reg. manner.
            Default: False.
        n_bins (int): The number of bins used in the cls. step.
            Default: 256.
        bins_strategy (str): The discretization strategy used in the cls. step.
            Default: 'UD'.
        norm_strategy (str): The norm strategy on the cls. probability
            distribution. Default: 'linear'.
        scale_up (bool): Whether to predict depth in a scale-up manner.
            Default: False.
    """

    def __init__(
        self,
        in_channels,
        channels=96,
        conv_cfg=None,
        act_cfg=dict(type="ReLU"),
        loss_decode=dict(type="SigLoss", valid_mask=True, loss_weight=10),
        sampler=None,
        align_corners=False,
        min_depth=1e-3,
        max_depth=None,
        norm_cfg=None,
        classify=False,
        n_bins=256,
        bins_strategy="UD",
        norm_strategy="linear",
        scale_up=False,
    ):
        super(DepthBaseDecodeHead, self).__init__()

        self.in_channels = in_channels
        self.channels = channels
        self.conv_cfg = conv_cfg
        self.act_cfg = act_cfg
        if isinstance(loss_decode, dict):
            self.loss_decode = build_loss(loss_decode)
        elif isinstance(loss_decode, (list, tuple)):
            self.loss_decode = nn.ModuleList()
            for loss in loss_decode:
                self.loss_decode.append(build_loss(loss))
        self.align_corners = align_corners
        self.min_depth = min_depth
        self.max_depth = max_depth
        self.norm_cfg = norm_cfg
        self.classify = classify
        self.n_bins = n_bins
        self.scale_up = scale_up

        if self.classify:
            assert bins_strategy in ["UD", "SID"], "Support bins_strategy: UD, SID"
            assert norm_strategy in ["linear", "softmax", "sigmoid"], "Support norm_strategy: linear, softmax, sigmoid"

            self.bins_strategy = bins_strategy
            self.norm_strategy = norm_strategy
            self.softmax = nn.Softmax(dim=1)
            self.conv_depth = nn.Conv2d(channels, n_bins, kernel_size=3, padding=1, stride=1)
        else:
            self.conv_depth = nn.Conv2d(channels, 1, kernel_size=3, padding=1, stride=1)

        self.fp16_enabled = False
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def extra_repr(self):
        """Extra repr."""
        s = f"align_corners={self.align_corners}"
        return s

    @auto_fp16()
    @abstractmethod
    def forward(self, inputs, img_metas):
        """Placeholder of forward function."""
        pass

    def forward_train(self, img, inputs, img_metas, depth_gt, train_cfg):
        """Forward function for training.

        Args:
            inputs (list[Tensor]): List of multi-level img features.
            img_metas (list[dict]): List of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `depth/datasets/pipelines/formatting.py:Collect`.
            depth_gt (Tensor): GT depth.
            train_cfg (dict): The training config.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        depth_pred = self.forward(inputs, img_metas)
        losses = self.losses(depth_pred, depth_gt)

        log_imgs = self.log_images(img[0], depth_pred[0], depth_gt[0], img_metas[0])
        losses.update(**log_imgs)

        return losses

    def forward_test(self, inputs, img_metas, test_cfg):
        """Forward function for testing.

        Args:
            inputs (list[Tensor]): List of multi-level img features.
            img_metas (list[dict]): List of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `depth/datasets/pipelines/formatting.py:Collect`.
            test_cfg (dict): The testing config.

        Returns:
            Tensor: Output depth map.
        """
        return self.forward(inputs, img_metas)

    def depth_pred(self, feat):
        """Predict depth for each pixel."""
        if self.classify:
            logit = self.conv_depth(feat)

            if self.bins_strategy == "UD":
                bins = torch.linspace(self.min_depth, self.max_depth, self.n_bins, device=feat.device)
            elif self.bins_strategy == "SID":
                bins = torch.logspace(self.min_depth, self.max_depth, self.n_bins, device=feat.device)

            # following AdaBins; default is linear
            if self.norm_strategy == "linear":
                logit = torch.relu(logit)
                eps = 0.1
                logit = logit + eps
                logit = logit / logit.sum(dim=1, keepdim=True)
            elif self.norm_strategy == "softmax":
                logit = torch.softmax(logit, dim=1)
            elif self.norm_strategy == "sigmoid":
                logit = torch.sigmoid(logit)
                logit = logit / logit.sum(dim=1, keepdim=True)

            output = torch.einsum("ikmn,k->imn", [logit, bins]).unsqueeze(dim=1)

        else:
            if self.scale_up:
                output = self.sigmoid(self.conv_depth(feat)) * self.max_depth
            else:
                output = self.relu(self.conv_depth(feat)) + self.min_depth
        return output

    @force_fp32(apply_to=("depth_pred",))
    def losses(self, depth_pred, depth_gt):
        """Compute depth loss."""
        loss = dict()
        depth_pred = resize(
            input=depth_pred, size=depth_gt.shape[2:], mode="bilinear", align_corners=self.align_corners, warning=False
        )
        if not isinstance(self.loss_decode, nn.ModuleList):
            losses_decode = [self.loss_decode]
        else:
            losses_decode = self.loss_decode
        for loss_decode in losses_decode:
            if loss_decode.loss_name not in loss:
                loss[loss_decode.loss_name] = loss_decode(depth_pred, depth_gt)
            else:
                loss[loss_decode.loss_name] += loss_decode(depth_pred, depth_gt)
        return loss

    def log_images(self, img_path, depth_pred, depth_gt, img_meta):
        show_img = copy.deepcopy(img_path.detach().cpu().permute(1, 2, 0))
        show_img = show_img.numpy().astype(np.float32)
        show_img = mmcv.imdenormalize(
            show_img,
            img_meta["img_norm_cfg"]["mean"],
            img_meta["img_norm_cfg"]["std"],
            img_meta["img_norm_cfg"]["to_rgb"],
        )
        show_img = np.clip(show_img, 0, 255)
        show_img = show_img.astype(np.uint8)
        show_img = show_img[:, :, ::-1]
        show_img = show_img.transpose(0, 2, 1)
        show_img = show_img.transpose(1, 0, 2)

        depth_pred = depth_pred / torch.max(depth_pred)
        depth_gt = depth_gt / torch.max(depth_gt)

        depth_pred_color = copy.deepcopy(depth_pred.detach().cpu())
        depth_gt_color = copy.deepcopy(depth_gt.detach().cpu())

        return {"img_rgb": show_img, "img_depth_pred": depth_pred_color, "img_depth_gt": depth_gt_color}
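The classification branch of depth_pred above turns per-pixel bin probabilities into a continuous depth by taking an expectation over bin centres (the AdaBins formulation). A quick shape sketch of that einsum, with illustrative sizes:

import torch

logit = torch.softmax(torch.randn(2, 256, 24, 32), dim=1)  # (N, n_bins, H, W) probabilities
bins = torch.linspace(1e-3, 10.0, 256)                      # (n_bins,) depth bin centres
depth = torch.einsum("ikmn,k->imn", [logit, bins]).unsqueeze(dim=1)
assert depth.shape == (2, 1, 24, 32)                        # (N, 1, H, W) expected depth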
@@ -0,0 +1,270 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import math

import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, Linear, build_activation_layer
from mmcv.runner import BaseModule

from ...ops import resize
from ..builder import HEADS
from .decode_head import DepthBaseDecodeHead


class Interpolate(nn.Module):
    def __init__(self, scale_factor, mode, align_corners=False):
        super(Interpolate, self).__init__()
        self.interp = nn.functional.interpolate
        self.scale_factor = scale_factor
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x):
        x = self.interp(x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners)
        return x


class HeadDepth(nn.Module):
    def __init__(self, features):
        super(HeadDepth, self).__init__()
        self.head = nn.Sequential(
            nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
            Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
            nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
        )

    def forward(self, x):
        x = self.head(x)
        return x


class ReassembleBlocks(BaseModule):
    """ViTPostProcessBlock: process the cls_token in the ViT backbone output and
    rearrange the feature vectors into feature maps.

    Args:
        in_channels (int): ViT feature channels. Default: 768.
        out_channels (List): output channels of each stage.
            Default: [96, 192, 384, 768].
        readout_type (str): Type of readout operation. Default: 'ignore'.
        patch_size (int): The patch size. Default: 16.
        init_cfg (dict, optional): Initialization config dict. Default: None.
    """

    def __init__(
        self, in_channels=768, out_channels=[96, 192, 384, 768], readout_type="ignore", patch_size=16, init_cfg=None
    ):
        super(ReassembleBlocks, self).__init__(init_cfg)

        assert readout_type in ["ignore", "add", "project"]
        self.readout_type = readout_type
        self.patch_size = patch_size

        self.projects = nn.ModuleList(
            [
                ConvModule(
                    in_channels=in_channels,
                    out_channels=out_channel,
                    kernel_size=1,
                    act_cfg=None,
                )
                for out_channel in out_channels
            ]
        )

        self.resize_layers = nn.ModuleList(
            [
                nn.ConvTranspose2d(
                    in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0
                ),
                nn.ConvTranspose2d(
                    in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0
                ),
                nn.Identity(),
                nn.Conv2d(
                    in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1
                ),
            ]
        )
        if self.readout_type == "project":
            self.readout_projects = nn.ModuleList()
            for _ in range(len(self.projects)):
                self.readout_projects.append(
                    nn.Sequential(Linear(2 * in_channels, in_channels), build_activation_layer(dict(type="GELU")))
                )

    def forward(self, inputs):
        assert isinstance(inputs, list)
        out = []
        for i, x in enumerate(inputs):
            assert len(x) == 2
            x, cls_token = x[0], x[1]
            feature_shape = x.shape
            if self.readout_type == "project":
                x = x.flatten(2).permute((0, 2, 1))
                readout = cls_token.unsqueeze(1).expand_as(x)
                x = self.readout_projects[i](torch.cat((x, readout), -1))
                x = x.permute(0, 2, 1).reshape(feature_shape)
            elif self.readout_type == "add":
                x = x.flatten(2) + cls_token.unsqueeze(-1)
                x = x.reshape(feature_shape)
            else:
                pass
            x = self.projects[i](x)
            x = self.resize_layers[i](x)
            out.append(x)
        return out


class PreActResidualConvUnit(BaseModule):
    """ResidualConvUnit, pre-activated residual unit.

    Args:
        in_channels (int): number of channels in the input feature map.
        act_cfg (dict): dictionary to construct and config activation layer.
        norm_cfg (dict): dictionary to construct and config norm layer.
        stride (int): stride of the first block. Default: 1.
        dilation (int): dilation rate for conv layers. Default: 1.
        init_cfg (dict, optional): Initialization config dict. Default: None.
    """

    def __init__(self, in_channels, act_cfg, norm_cfg, stride=1, dilation=1, init_cfg=None):
        super(PreActResidualConvUnit, self).__init__(init_cfg)

        self.conv1 = ConvModule(
            in_channels,
            in_channels,
            3,
            stride=stride,
            padding=dilation,
            dilation=dilation,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            bias=False,
            order=("act", "conv", "norm"),
        )

        self.conv2 = ConvModule(
            in_channels,
            in_channels,
            3,
            padding=1,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            bias=False,
            order=("act", "conv", "norm"),
        )

    def forward(self, inputs):
        inputs_ = inputs.clone()
        x = self.conv1(inputs)
        x = self.conv2(x)
        return x + inputs_


class FeatureFusionBlock(BaseModule):
    """FeatureFusionBlock: merge feature maps from different stages.

    Args:
        in_channels (int): Input channels.
        act_cfg (dict): The activation config for ResidualConvUnit.
        norm_cfg (dict): Config dict for normalization layer.
        expand (bool): Whether to expand the channels in the post-process block.
            Default: False.
        align_corners (bool): align_corners setting for bilinear upsample.
            Default: True.
        init_cfg (dict, optional): Initialization config dict. Default: None.
    """

    def __init__(self, in_channels, act_cfg, norm_cfg, expand=False, align_corners=True, init_cfg=None):
        super(FeatureFusionBlock, self).__init__(init_cfg)

        self.in_channels = in_channels
        self.expand = expand
        self.align_corners = align_corners

        self.out_channels = in_channels
        if self.expand:
            self.out_channels = in_channels // 2

        self.project = ConvModule(self.in_channels, self.out_channels, kernel_size=1, act_cfg=None, bias=True)

        self.res_conv_unit1 = PreActResidualConvUnit(in_channels=self.in_channels, act_cfg=act_cfg, norm_cfg=norm_cfg)
        self.res_conv_unit2 = PreActResidualConvUnit(in_channels=self.in_channels, act_cfg=act_cfg, norm_cfg=norm_cfg)

    def forward(self, *inputs):
        x = inputs[0]
        if len(inputs) == 2:
            if x.shape != inputs[1].shape:
                res = resize(inputs[1], size=(x.shape[2], x.shape[3]), mode="bilinear", align_corners=False)
            else:
                res = inputs[1]
            x = x + self.res_conv_unit1(res)
        x = self.res_conv_unit2(x)
        x = resize(x, scale_factor=2, mode="bilinear", align_corners=self.align_corners)
        x = self.project(x)
        return x


@HEADS.register_module()
class DPTHead(DepthBaseDecodeHead):
    """Vision Transformers for Dense Prediction.

    This head is an implementation of `DPT <https://arxiv.org/abs/2103.13413>`_.

    Args:
        embed_dims (int): The embed dimension of the ViT backbone.
            Default: 768.
        post_process_channels (List): Out channels of post-process conv
            layers. Default: [96, 192, 384, 768].
        readout_type (str): Type of readout operation. Default: 'ignore'.
        patch_size (int): The patch size. Default: 16.
        expand_channels (bool): Whether to expand the channels in the
            post-process block. Default: False.
    """

    def __init__(
        self,
        embed_dims=768,
        post_process_channels=[96, 192, 384, 768],
        readout_type="ignore",
        patch_size=16,
        expand_channels=False,
        **kwargs
    ):
        super(DPTHead, self).__init__(**kwargs)

        self.in_channels = self.in_channels
        self.expand_channels = expand_channels
        self.reassemble_blocks = ReassembleBlocks(embed_dims, post_process_channels, readout_type, patch_size)

        self.post_process_channels = [
            channel * math.pow(2, i) if expand_channels else channel for i, channel in enumerate(post_process_channels)
        ]
        self.convs = nn.ModuleList()
        for channel in self.post_process_channels:
            self.convs.append(ConvModule(channel, self.channels, kernel_size=3, padding=1, act_cfg=None, bias=False))
        self.fusion_blocks = nn.ModuleList()
        for _ in range(len(self.convs)):
            self.fusion_blocks.append(FeatureFusionBlock(self.channels, self.act_cfg, self.norm_cfg))
        self.fusion_blocks[0].res_conv_unit1 = None
        self.project = ConvModule(self.channels, self.channels, kernel_size=3, padding=1, norm_cfg=self.norm_cfg)
        self.num_fusion_blocks = len(self.fusion_blocks)
        self.num_reassemble_blocks = len(self.reassemble_blocks.resize_layers)
        self.num_post_process_channels = len(self.post_process_channels)
        assert self.num_fusion_blocks == self.num_reassemble_blocks
        assert self.num_reassemble_blocks == self.num_post_process_channels
        self.conv_depth = HeadDepth(self.channels)

    def forward(self, inputs, img_metas):
        assert len(inputs) == self.num_reassemble_blocks
        x = [inp for inp in inputs]
        x = self.reassemble_blocks(x)
        x = [self.convs[i](feature) for i, feature in enumerate(x)]
        out = self.fusion_blocks[0](x[-1])
        for i in range(1, len(self.fusion_blocks)):
            out = self.fusion_blocks[i](out, x[-(i + 1)])
        out = self.project(out)
        out = self.depth_pred(out)
        return out
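As a rough guide to the four resize_layers in the ReassembleBlocks defined above: with the defaults (in_channels=768, out_channels=[96, 192, 384, 768]) they produce a feature pyramid at 4x, 2x, 1x and 0.5x of the token resolution. A sketch, assuming mmcv is installed and an illustrative 24x24 token grid:

import torch

blocks = ReassembleBlocks()                  # defaults from the class above
feat, cls_token = torch.randn(1, 768, 24, 24), torch.randn(1, 768)
outs = blocks([(feat, cls_token)] * 4)       # readout_type="ignore" drops cls_token
print([tuple(o.shape) for o in outs])
# [(1, 96, 96, 96), (1, 192, 48, 48), (1, 384, 24, 24), (1, 768, 12, 12)]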
@@ -0,0 +1,89 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn

from ...ops import resize
from ..builder import HEADS
from .decode_head import DepthBaseDecodeHead


@HEADS.register_module()
class BNHead(DepthBaseDecodeHead):
    """Just a batchnorm."""

    def __init__(self, input_transform="resize_concat", in_index=(0, 1, 2, 3), upsample=1, **kwargs):
        super().__init__(**kwargs)
        self.input_transform = input_transform
        self.in_index = in_index
        self.upsample = upsample
        # self.bn = nn.SyncBatchNorm(self.in_channels)
        if self.classify:
            self.conv_depth = nn.Conv2d(self.channels, self.n_bins, kernel_size=1, padding=0, stride=1)
        else:
            self.conv_depth = nn.Conv2d(self.channels, 1, kernel_size=1, padding=0, stride=1)

    def _transform_inputs(self, inputs):
        """Transform inputs for decoder.

        Args:
            inputs (list[Tensor]): List of multi-level img features.

        Returns:
            Tensor: The transformed inputs
        """

        if "concat" in self.input_transform:
            inputs = [inputs[i] for i in self.in_index]
            if "resize" in self.input_transform:
                inputs = [
                    resize(
                        input=x,
                        size=[s * self.upsample for s in inputs[0].shape[2:]],
                        mode="bilinear",
                        align_corners=self.align_corners,
                    )
                    for x in inputs
                ]
            inputs = torch.cat(inputs, dim=1)
        elif self.input_transform == "multiple_select":
            inputs = [inputs[i] for i in self.in_index]
        else:
            inputs = inputs[self.in_index]

        return inputs

    def _forward_feature(self, inputs, img_metas=None, **kwargs):
        """Forward function for feature maps before classifying each pixel with
        ``self.cls_seg`` fc.

        Args:
            inputs (list[Tensor]): List of multi-level img features.

        Returns:
            feats (Tensor): A tensor of shape (batch_size, self.channels,
                H, W) which is the feature map of the last layer of the decoder head.
        """
        # accept lists (for cls token)
        inputs = list(inputs)
        for i, x in enumerate(inputs):
            if len(x) == 2:
                x, cls_token = x[0], x[1]
                if len(x.shape) == 2:
                    x = x[:, :, None, None]
                cls_token = cls_token[:, :, None, None].expand_as(x)
                inputs[i] = torch.cat((x, cls_token), 1)
            else:
                x = x[0]
                if len(x.shape) == 2:
                    x = x[:, :, None, None]
                inputs[i] = x
        x = self._transform_inputs(inputs)
        # feats = self.bn(x)
        return x

    def forward(self, inputs, img_metas=None, **kwargs):
        """Forward function."""
        output = self._forward_feature(inputs, img_metas=img_metas, **kwargs)
        output = self.depth_pred(output)

        return output
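A note on the input contract of _forward_feature above: each level may arrive as a (patch_tokens, cls_token) pair, in which case the cls token is broadcast over the spatial grid and concatenated channel-wise, doubling the channel count before "resize_concat". An illustrative sketch; the 37x37 ViT-S/14 patch grid is an assumption:

import torch

x = torch.randn(1, 384, 37, 37)   # patch tokens from one transformer block
cls_tok = torch.randn(1, 384)     # its cls token
inputs = [(x, cls_tok)] * 4       # four selected blocks
# after the loop above: each level becomes (1, 768, 37, 37);
# "resize_concat" over in_index=(0, 1, 2, 3) then yields (1, 3072, 37, 37)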
@@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from .base import BaseDepther
from .encoder_decoder import DepthEncoderDecoder
@@ -0,0 +1,194 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from abc import ABCMeta, abstractmethod
from collections import OrderedDict

import torch
import torch.distributed as dist
from mmcv.runner import BaseModule, auto_fp16


class BaseDepther(BaseModule, metaclass=ABCMeta):
    """Base class for depther."""

    def __init__(self, init_cfg=None):
        super(BaseDepther, self).__init__(init_cfg)
        self.fp16_enabled = False

    @property
    def with_neck(self):
        """bool: whether the depther has a neck"""
        return hasattr(self, "neck") and self.neck is not None

    @property
    def with_auxiliary_head(self):
        """bool: whether the depther has an auxiliary head"""
        return hasattr(self, "auxiliary_head") and self.auxiliary_head is not None

    @property
    def with_decode_head(self):
        """bool: whether the depther has a decode head"""
        return hasattr(self, "decode_head") and self.decode_head is not None

    @abstractmethod
    def extract_feat(self, imgs):
        """Placeholder for extracting features from images."""
        pass

    @abstractmethod
    def encode_decode(self, img, img_metas):
        """Placeholder for encoding images with the backbone and decoding into a
        depth map of the same size as the input."""
        pass

    @abstractmethod
    def forward_train(self, imgs, img_metas, **kwargs):
        """Placeholder for the forward function for training."""
        pass

    @abstractmethod
    def simple_test(self, img, img_meta, **kwargs):
        """Placeholder for single image test."""
        pass

    @abstractmethod
    def aug_test(self, imgs, img_metas, **kwargs):
        """Placeholder for augmentation test."""
        pass

    def forward_test(self, imgs, img_metas, **kwargs):
        """
        Args:
            imgs (List[Tensor]): the outer list indicates test-time
                augmentations; the inner Tensor should have shape NxCxHxW,
                which contains all images in the batch.
            img_metas (List[List[dict]]): the outer list indicates test-time
                augs (multiscale, flip, etc.) and the inner list indicates
                images in a batch.
        """
        for var, name in [(imgs, "imgs"), (img_metas, "img_metas")]:
            if not isinstance(var, list):
                raise TypeError(f"{name} must be a list, but got {type(var)}")
        num_augs = len(imgs)
        if num_augs != len(img_metas):
            raise ValueError(f"num of augmentations ({len(imgs)}) != num of image meta ({len(img_metas)})")
        # all images in the same aug batch must share the same ori_shape and
        # pad_shape
        for img_meta in img_metas:
            ori_shapes = [_["ori_shape"] for _ in img_meta]
            assert all(shape == ori_shapes[0] for shape in ori_shapes)
            img_shapes = [_["img_shape"] for _ in img_meta]
            assert all(shape == img_shapes[0] for shape in img_shapes)
            pad_shapes = [_["pad_shape"] for _ in img_meta]
            assert all(shape == pad_shapes[0] for shape in pad_shapes)

        if num_augs == 1:
            return self.simple_test(imgs[0], img_metas[0], **kwargs)
        else:
            return self.aug_test(imgs, img_metas, **kwargs)

    @auto_fp16(apply_to=("img",))
    def forward(self, img, img_metas, return_loss=True, **kwargs):
        """Calls either :func:`forward_train` or :func:`forward_test` depending
        on whether ``return_loss`` is ``True``.

        Note that this setting changes the expected inputs. When
        ``return_loss=True``, img and img_meta are single-nested (i.e. Tensor
        and List[dict]), and when ``return_loss=False``, img and img_meta
        should be double-nested (i.e. List[Tensor], List[List[dict]]), with
        the outer list indicating test-time augmentations.
        """
        if return_loss:
            return self.forward_train(img, img_metas, **kwargs)
        else:
            return self.forward_test(img, img_metas, **kwargs)

    def train_step(self, data_batch, optimizer, **kwargs):
        """The iteration step during training.

        This method defines an iteration step during training, except for the
        back propagation and optimizer updating, which are done in an optimizer
        hook. Note that in some complicated cases or models, the whole process
        including back propagation and optimizer updating is also defined in
        this method, such as GAN.

        Args:
            data_batch (dict): The output of the dataloader.
            optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of
                the runner is passed to ``train_step()``. This argument is unused
                and reserved.

        Returns:
            dict: It should contain at least 3 keys: ``loss``, ``log_vars``,
                ``num_samples``.
                ``loss`` is a tensor for back propagation, which can be a
                weighted sum of multiple losses.
                ``log_vars`` contains all the variables to be sent to the
                logger.
                ``num_samples`` indicates the batch size (when the model is
                DDP, it means the batch size on each GPU), which is used for
                averaging the logs.
        """
        losses = self(**data_batch)

        # split losses and images
        real_losses = {}
        log_imgs = {}
        for k, v in losses.items():
            if "img" in k:
                log_imgs[k] = v
            else:
                real_losses[k] = v

        loss, log_vars = self._parse_losses(real_losses)

        outputs = dict(loss=loss, log_vars=log_vars, num_samples=len(data_batch["img_metas"]), log_imgs=log_imgs)

        return outputs

    def val_step(self, data_batch, **kwargs):
        """The iteration step during validation.

        This method shares the same signature as :func:`train_step`, but is used
        during val epochs. Note that the evaluation after training epochs is
        not implemented with this method, but with an evaluation hook.
        """
        output = self(**data_batch, **kwargs)
        return output

    @staticmethod
    def _parse_losses(losses):
        """Parse the raw outputs (losses) of the network.

        Args:
            losses (dict): Raw output of the network, which usually contains
                losses and other necessary information.

        Returns:
            tuple[Tensor, dict]: (loss, log_vars), loss is the loss tensor
                which may be a weighted sum of all losses, log_vars contains
                all the variables to be sent to the logger.
        """
        log_vars = OrderedDict()
        for loss_name, loss_value in losses.items():
            if isinstance(loss_value, torch.Tensor):
                log_vars[loss_name] = loss_value.mean()
            elif isinstance(loss_value, list):
                log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
            else:
                raise TypeError(f"{loss_name} is not a tensor or list of tensors")

        loss = sum(_value for _key, _value in log_vars.items() if "loss" in _key)

        log_vars["loss"] = loss
        for loss_name, loss_value in log_vars.items():
            # reduce loss when doing distributed training
            if dist.is_available() and dist.is_initialized():
                loss_value = loss_value.data.clone()
                dist.all_reduce(loss_value.div_(dist.get_world_size()))
            log_vars[loss_name] = loss_value.item()

        return loss, log_vars
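To make the reduction in _parse_losses concrete: tensors are averaged, lists of tensors are summed after averaging, and only keys containing "loss" contribute to the total. A small sketch with illustrative values:

import torch

raw = {
    "decode.loss_depth": torch.tensor(0.7),
    "decode.loss_grad": [torch.tensor(0.1), torch.tensor(0.2)],
    "decode.acc": torch.tensor(0.9),  # logged but not summed: no "loss" in the key
}
loss, log_vars = BaseDepther._parse_losses(raw)
print(loss)      # tensor(1.0000) = 0.7 + (0.1 + 0.2)
print(log_vars)  # OrderedDict with decode.loss_depth, decode.loss_grad, decode.acc, loss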
@@ -0,0 +1,236 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import torch
import torch.nn.functional as F

from ...models import builder
from ...models.builder import DEPTHER
from ...ops import resize
from .base import BaseDepther


def add_prefix(inputs, prefix):
    """Add prefix for dict.

    Args:
        inputs (dict): The input dict with str keys.
        prefix (str): The prefix to add.

    Returns:
        dict: The dict with keys updated with ``prefix``.
    """

    outputs = dict()
    for name, value in inputs.items():
        outputs[f"{prefix}.{name}"] = value

    return outputs


@DEPTHER.register_module()
class DepthEncoderDecoder(BaseDepther):
    """Encoder-decoder depther.

    EncoderDecoder typically consists of a backbone, (optionally) a neck, and a decode_head.
    """

    def __init__(self, backbone, decode_head, neck=None, train_cfg=None, test_cfg=None, pretrained=None, init_cfg=None):
        super(DepthEncoderDecoder, self).__init__(init_cfg)
        if pretrained is not None:
            assert backbone.get("pretrained") is None, "both backbone and depther set pretrained weight"
            backbone.pretrained = pretrained
        self.backbone = builder.build_backbone(backbone)
        self._init_decode_head(decode_head)

        if neck is not None:
            self.neck = builder.build_neck(neck)

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        assert self.with_decode_head

    def _init_decode_head(self, decode_head):
        """Initialize ``decode_head``"""
        self.decode_head = builder.build_head(decode_head)
        self.align_corners = self.decode_head.align_corners

    def extract_feat(self, img):
        """Extract features from images."""
        x = self.backbone(img)
        if self.with_neck:
            x = self.neck(x)
        return x

    def encode_decode(self, img, img_metas, rescale=True, size=None):
        """Encode images with the backbone and decode into a depth estimation
        map of the same size as the input."""
        x = self.extract_feat(img)
        out = self._decode_head_forward_test(x, img_metas)
        # clamp the predicted depth to the valid range
        out = torch.clamp(out, min=self.decode_head.min_depth, max=self.decode_head.max_depth)
        if rescale:
            if size is None:
                if img_metas is not None:
                    size = img_metas[0]["ori_shape"][:2]
                else:
                    size = img.shape[2:]
            out = resize(input=out, size=size, mode="bilinear", align_corners=self.align_corners)
        return out

    def _decode_head_forward_train(self, img, x, img_metas, depth_gt, **kwargs):
        """Run the forward function and calculate the loss for the decode head in
        training."""
        losses = dict()
        loss_decode = self.decode_head.forward_train(img, x, img_metas, depth_gt, self.train_cfg, **kwargs)
        losses.update(add_prefix(loss_decode, "decode"))
        return losses

    def _decode_head_forward_test(self, x, img_metas):
        """Run the forward function for the decode head in inference."""
        depth_pred = self.decode_head.forward_test(x, img_metas, self.test_cfg)
        return depth_pred

    def forward_dummy(self, img):
        """Dummy forward function."""
        depth = self.encode_decode(img, None)

        return depth

    def forward_train(self, img, img_metas, depth_gt, **kwargs):
        """Forward function for training.

        Args:
            img (Tensor): Input images.
            img_metas (list[dict]): List of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `depth/datasets/pipelines/formatting.py:Collect`.
            depth_gt (Tensor): Depth gt
                used if the architecture supports the depth estimation task.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """

        x = self.extract_feat(img)

        losses = dict()

        # the last element of x carries the info from the neck
        loss_decode = self._decode_head_forward_train(img, x, img_metas, depth_gt, **kwargs)

        losses.update(loss_decode)

        return losses

    def whole_inference(self, img, img_meta, rescale, size=None):
        """Inference with the full image."""
        depth_pred = self.encode_decode(img, img_meta, rescale, size=size)

        return depth_pred

    def slide_inference(self, img, img_meta, rescale):
        """Inference by sliding-window with overlap.

        If h_crop > h_img or w_crop > w_img, the small patch will be used to
        decode without padding.
        """

        h_stride, w_stride = self.test_cfg.stride
        h_crop, w_crop = self.test_cfg.crop_size
        batch_size, _, h_img, w_img = img.size()
        h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
        w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
        preds = img.new_zeros((batch_size, 1, h_img, w_img))
        count_mat = img.new_zeros((batch_size, 1, h_img, w_img))
        for h_idx in range(h_grids):
            for w_idx in range(w_grids):
                y1 = h_idx * h_stride
                x1 = w_idx * w_stride
                y2 = min(y1 + h_crop, h_img)
                x2 = min(x1 + w_crop, w_img)
                y1 = max(y2 - h_crop, 0)
                x1 = max(x2 - w_crop, 0)
                crop_img = img[:, :, y1:y2, x1:x2]
                depth_pred = self.encode_decode(crop_img, img_meta, rescale)
                preds += F.pad(depth_pred, (int(x1), int(preds.shape[3] - x2), int(y1), int(preds.shape[2] - y2)))

                count_mat[:, :, y1:y2, x1:x2] += 1
        assert (count_mat == 0).sum() == 0
        if torch.onnx.is_in_onnx_export():
            # cast count_mat to a constant while exporting to ONNX
            count_mat = torch.from_numpy(count_mat.cpu().detach().numpy()).to(device=img.device)
        preds = preds / count_mat
        return preds

    def inference(self, img, img_meta, rescale, size=None):
        """Inference with slide/whole style.

        Args:
            img (Tensor): The input image of shape (N, 3, H, W).
            img_meta (dict): Image info dict where each dict has: 'img_shape',
                'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `depth/datasets/pipelines/formatting.py:Collect`.
            rescale (bool): Whether to rescale back to the original shape.

        Returns:
            Tensor: The output depth map.
        """

        assert self.test_cfg.mode in ["slide", "whole"]
        ori_shape = img_meta[0]["ori_shape"]
        assert all(_["ori_shape"] == ori_shape for _ in img_meta)
        if self.test_cfg.mode == "slide":
            depth_pred = self.slide_inference(img, img_meta, rescale)
        else:
            depth_pred = self.whole_inference(img, img_meta, rescale, size=size)
        output = depth_pred
        flip = img_meta[0]["flip"]
        if flip:
            flip_direction = img_meta[0]["flip_direction"]
            assert flip_direction in ["horizontal", "vertical"]
            if flip_direction == "horizontal":
                output = output.flip(dims=(3,))
            elif flip_direction == "vertical":
                output = output.flip(dims=(2,))

        return output

    def simple_test(self, img, img_meta, rescale=True):
        """Simple test with a single image."""
        depth_pred = self.inference(img, img_meta, rescale)
        if torch.onnx.is_in_onnx_export():
            # our inference backend only supports 4D output
            depth_pred = depth_pred.unsqueeze(0)
            return depth_pred
        depth_pred = depth_pred.cpu().numpy()
        # unravel batch dim
        depth_pred = list(depth_pred)
        return depth_pred

    def aug_test(self, imgs, img_metas, rescale=True):
        """Test with augmentations.

        Only rescale=True is supported.
        """
        # aug_test rescales all imgs back to ori_shape for now
        assert rescale
        # to save memory, we accumulate the augmented depth predictions in place
        depth_pred = self.inference(imgs[0], img_metas[0], rescale)
        for i in range(1, len(imgs)):
            cur_depth_pred = self.inference(imgs[i], img_metas[i], rescale, size=depth_pred.shape[-2:])
            depth_pred += cur_depth_pred
        depth_pred /= len(imgs)
        depth_pred = depth_pred.cpu().numpy()
        # unravel batch dim
        depth_pred = list(depth_pred)
        return depth_pred
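The window grid in slide_inference above places ceil-spaced crops and clamps the last one to the image border, so every pixel is covered at least once (hence the count_mat assertion). A numeric sketch with illustrative test_cfg values:

h_img, h_crop, h_stride = 480, 352, 176   # illustrative crop_size / stride
h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
starts = [min(i * h_stride, h_img - h_crop) for i in range(h_grids)]
print(h_grids, starts)                    # 2 [0, 128] -> windows [0, 352) and [128, 480)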
@@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from .gradientloss import GradientLoss
from .sigloss import SigLoss
@@ -0,0 +1,69 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn

from ...models.builder import LOSSES


@LOSSES.register_module()
class GradientLoss(nn.Module):
    """GradientLoss.

    Adapted from https://www.cs.cornell.edu/projects/megadepth/

    Args:
        valid_mask (bool): Whether to filter invalid gt (gt > 0). Default: True.
        loss_weight (float): Weight of the loss. Default: 1.0.
        max_depth (int): When filtering invalid gt, set a max threshold. Default: None.
    """

    def __init__(self, valid_mask=True, loss_weight=1.0, max_depth=None, loss_name="loss_grad"):
        super(GradientLoss, self).__init__()
        self.valid_mask = valid_mask
        self.loss_weight = loss_weight
        self.max_depth = max_depth
        self.loss_name = loss_name

        self.eps = 0.001  # avoid gradient explosion

    def gradientloss(self, input, target):
        input_downscaled = [input] + [input[:: 2 * i, :: 2 * i] for i in range(1, 4)]
        target_downscaled = [target] + [target[:: 2 * i, :: 2 * i] for i in range(1, 4)]

        gradient_loss = 0
        for input, target in zip(input_downscaled, target_downscaled):
            if self.valid_mask:
                mask = target > 0
                if self.max_depth is not None:
                    mask = torch.logical_and(target > 0, target <= self.max_depth)
                N = torch.sum(mask)
            else:
                mask = torch.ones_like(target)
                N = input.numel()
            input_log = torch.log(input + self.eps)
            target_log = torch.log(target + self.eps)
            log_d_diff = input_log - target_log

            log_d_diff = torch.mul(log_d_diff, mask)

            v_gradient = torch.abs(log_d_diff[0:-2, :] - log_d_diff[2:, :])
            v_mask = torch.mul(mask[0:-2, :], mask[2:, :])
            v_gradient = torch.mul(v_gradient, v_mask)

            h_gradient = torch.abs(log_d_diff[:, 0:-2] - log_d_diff[:, 2:])
            h_mask = torch.mul(mask[:, 0:-2], mask[:, 2:])
            h_gradient = torch.mul(h_gradient, h_mask)

            gradient_loss += (torch.sum(h_gradient) + torch.sum(v_gradient)) / N

        return gradient_loss

    def forward(self, depth_pred, depth_gt):
        """Forward function."""

        gradient_loss = self.loss_weight * self.gradientloss(depth_pred, depth_gt)
        return gradient_loss
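A quick numerical sanity check of the multi-scale gradient loss above: identical prediction and ground truth give zero log-depth differences, hence zero loss at every scale (the values below are illustrative):

import torch

pred = torch.full((8, 8), 2.0)
gt = torch.full((8, 8), 2.0)
print(GradientLoss(valid_mask=True)(pred, gt))  # tensor(0.) -- all log-diff gradients vanish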
@@ -0,0 +1,65 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn

from ...models.builder import LOSSES


@LOSSES.register_module()
class SigLoss(nn.Module):
    """SigLoss.

    This follows `AdaBins <https://arxiv.org/abs/2011.14141>`_.

    Args:
        valid_mask (bool): Whether to filter invalid gt (gt > 0). Default: True.
        loss_weight (float): Weight of the loss. Default: 1.0.
        max_depth (int): When filtering invalid gt, set a max threshold. Default: None.
        warm_up (bool): A simple warm-up stage to help convergence. Default: False.
        warm_iter (int): The number of iterations in the warm-up stage. Default: 100.
    """

    def __init__(
        self, valid_mask=True, loss_weight=1.0, max_depth=None, warm_up=False, warm_iter=100, loss_name="sigloss"
    ):
        super(SigLoss, self).__init__()
        self.valid_mask = valid_mask
        self.loss_weight = loss_weight
        self.max_depth = max_depth
        self.loss_name = loss_name

        self.eps = 0.001  # avoid gradient explosion

        # HACK: a hack implementation for warm-up sigloss
        self.warm_up = warm_up
        self.warm_iter = warm_iter
        self.warm_up_counter = 0

    def sigloss(self, input, target):
        if self.valid_mask:
            valid_mask = target > 0
            if self.max_depth is not None:
                valid_mask = torch.logical_and(target > 0, target <= self.max_depth)
            input = input[valid_mask]
            target = target[valid_mask]

        if self.warm_up:
            if self.warm_up_counter < self.warm_iter:
                g = torch.log(input + self.eps) - torch.log(target + self.eps)
                g = 0.15 * torch.pow(torch.mean(g), 2)
                self.warm_up_counter += 1
                return torch.sqrt(g)

        g = torch.log(input + self.eps) - torch.log(target + self.eps)
        Dg = torch.var(g) + 0.15 * torch.pow(torch.mean(g), 2)
        return torch.sqrt(Dg)

    def forward(self, depth_pred, depth_gt):
        """Forward function."""

        loss_depth = self.loss_weight * self.sigloss(depth_pred, depth_gt)
        return loss_depth
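Written out, the loss above is the scale-invariant log (SILog) error from AdaBins: with g_i = log(d_i + eps) - log(d*_i + eps), it returns sqrt(var(g) + 0.15 * mean(g)^2). An equivalent sketch on flat tensors with illustrative values:

import torch

pred, gt = torch.rand(10) + 0.5, torch.rand(10) + 0.5
g = torch.log(pred + 1e-3) - torch.log(gt + 1e-3)
silog = torch.sqrt(torch.var(g) + 0.15 * torch.mean(g) ** 2)
assert torch.allclose(silog, SigLoss(valid_mask=True)(pred, gt))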
@@ -0,0 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from .wrappers import resize
@@ -0,0 +1,28 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import warnings

import torch.nn.functional as F


def resize(input, size=None, scale_factor=None, mode="nearest", align_corners=None, warning=False):
    if warning:
        if size is not None and align_corners:
            input_h, input_w = tuple(int(x) for x in input.shape[2:])
            output_h, output_w = tuple(int(x) for x in size)
            if output_h > input_h or output_w > input_w:
                if (
                    (output_h > 1 and output_w > 1 and input_h > 1 and input_w > 1)
                    and (output_h - 1) % (input_h - 1)
                    and (output_w - 1) % (input_w - 1)
                ):
                    warnings.warn(
                        f"When align_corners={align_corners}, "
                        "the output would be more aligned if "
                        f"input size {(input_h, input_w)} is `x+1` and "
                        f"out size {(output_h, output_w)} is `nx+1`"
                    )
    return F.interpolate(input, size, scale_factor, mode, align_corners)
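The warning condition above encodes the usual align_corners=True rule of thumb: bilinear upsampling is exactly grid-aligned when (output - 1) is an integer multiple of (input - 1). A tiny check with illustrative sizes:

in_h = 25
print((97 - 1) % (in_h - 1))   # 0  -> aligned, no warning (97 = 4 * 24 + 1)
print((96 - 1) % (in_h - 1))   # 23 -> not aligned, the warning fires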
404
src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/knn.py
Executable file
404
src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/knn.py
Executable file
@ -0,0 +1,404 @@
|
|||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the Apache License, Version 2.0
|
||||||
|
# found in the LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
from functools import partial
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.nn.functional import one_hot, softmax
|
||||||
|
|
||||||
|
import dinov2.distributed as distributed
|
||||||
|
from dinov2.data import SamplerType, make_data_loader, make_dataset
|
||||||
|
from dinov2.data.transforms import make_classification_eval_transform
|
||||||
|
from dinov2.eval.metrics import AccuracyAveraging, build_topk_accuracy_metric
|
||||||
|
from dinov2.eval.setup import get_args_parser as get_setup_args_parser
|
||||||
|
from dinov2.eval.setup import setup_and_build_model
|
||||||
|
from dinov2.eval.utils import ModelWithNormalize, evaluate, extract_features
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger("dinov2")
|
||||||
|
|
||||||
|
|
||||||
|
def get_args_parser(
|
||||||
|
description: Optional[str] = None,
|
||||||
|
parents: Optional[List[argparse.ArgumentParser]] = None,
|
||||||
|
add_help: bool = True,
|
||||||
|
):
|
||||||
|
parents = parents or []
|
||||||
|
setup_args_parser = get_setup_args_parser(parents=parents, add_help=False)
|
||||||
|
parents = [setup_args_parser]
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description=description,
|
||||||
|
parents=parents,
|
||||||
|
add_help=add_help,
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--train-dataset",
|
||||||
|
dest="train_dataset_str",
|
||||||
|
type=str,
|
||||||
|
help="Training dataset",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--val-dataset",
|
||||||
|
dest="val_dataset_str",
|
||||||
|
type=str,
|
||||||
|
help="Validation dataset",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--nb_knn",
|
||||||
|
nargs="+",
|
||||||
|
type=int,
|
||||||
|
help="Number of NN to use. 20 is usually working the best.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--temperature",
|
||||||
|
type=float,
|
||||||
|
help="Temperature used in the voting coefficient",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--gather-on-cpu",
|
||||||
|
action="store_true",
|
||||||
|
help="Whether to gather the train features on cpu, slower"
|
||||||
|
"but useful to avoid OOM for large datasets (e.g. ImageNet22k).",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--batch-size",
|
||||||
|
type=int,
|
||||||
|
help="Batch size.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--n-per-class-list",
|
||||||
|
nargs="+",
|
||||||
|
type=int,
|
||||||
|
help="Number to take per class",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--n-tries",
|
||||||
|
type=int,
|
||||||
|
help="Number of tries",
|
||||||
|
)
|
||||||
|
parser.set_defaults(
|
||||||
|
train_dataset_str="ImageNet:split=TRAIN",
|
||||||
|
val_dataset_str="ImageNet:split=VAL",
|
||||||
|
nb_knn=[10, 20, 100, 200],
|
||||||
|
temperature=0.07,
|
||||||
|
batch_size=256,
|
||||||
|
n_per_class_list=[-1],
|
||||||
|
n_tries=1,
|
||||||
|
)
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
class KnnModule(torch.nn.Module):
|
||||||
|
"""
|
||||||
|
Gets knn of test features from all processes on a chunk of the train features
|
||||||
|
|
||||||
|
Each rank gets a chunk of the train features as well as a chunk of the test features.
|
||||||
|
In `compute_neighbors`, for each rank one after the other, its chunk of test features
|
||||||
|
is sent to all devices, partial knns are computed with each chunk of train features
|
||||||
|
then collated back on the original device.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, train_features, train_labels, nb_knn, T, device, num_classes=1000):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.global_rank = distributed.get_global_rank()
|
||||||
|
self.global_size = distributed.get_global_size()
|
||||||
|
|
||||||
|
self.device = device
|
||||||
|
self.train_features_rank_T = train_features.chunk(self.global_size)[self.global_rank].T.to(self.device)
|
||||||
|
self.candidates = train_labels.chunk(self.global_size)[self.global_rank].view(1, -1).to(self.device)
|
||||||
|
|
||||||
|
self.nb_knn = nb_knn
|
||||||
|
self.max_k = max(self.nb_knn)
|
||||||
|
self.T = T
|
||||||
|
self.num_classes = num_classes
|
||||||
|
|
||||||
|
def _get_knn_sims_and_labels(self, similarity, train_labels):
|
||||||
|
topk_sims, indices = similarity.topk(self.max_k, largest=True, sorted=True)
|
||||||
|
neighbors_labels = torch.gather(train_labels, 1, indices)
|
||||||
|
return topk_sims, neighbors_labels
|
||||||
|
|
||||||
|
def _similarity_for_rank(self, features_rank, source_rank):
|
||||||
|
# Send the features from `source_rank` to all ranks
|
||||||
|
broadcast_shape = torch.tensor(features_rank.shape).to(self.device)
|
||||||
|
torch.distributed.broadcast(broadcast_shape, source_rank)
|
||||||
|
|
||||||
|
broadcasted = features_rank
|
||||||
|
if self.global_rank != source_rank:
|
||||||
|
broadcasted = torch.zeros(*broadcast_shape, dtype=features_rank.dtype, device=self.device)
|
||||||
|
torch.distributed.broadcast(broadcasted, source_rank)
|
||||||
|
|
||||||
|
# Compute the neighbors for `source_rank` among `train_features_rank_T`
|
||||||
|
similarity_rank = torch.mm(broadcasted, self.train_features_rank_T)
|
||||||
|
candidate_labels = self.candidates.expand(len(similarity_rank), -1)
|
||||||
|
return self._get_knn_sims_and_labels(similarity_rank, candidate_labels)
|
||||||
|
|
||||||
|
def _gather_all_knn_for_rank(self, topk_sims, neighbors_labels, target_rank):
|
||||||
|
# Gather all neighbors for `target_rank`
|
||||||
|
topk_sims_rank = retrieved_rank = None
|
||||||
|
if self.global_rank == target_rank:
|
||||||
|
topk_sims_rank = [torch.zeros_like(topk_sims) for _ in range(self.global_size)]
|
||||||
|
retrieved_rank = [torch.zeros_like(neighbors_labels) for _ in range(self.global_size)]
|
||||||
|
|
||||||
|
torch.distributed.gather(topk_sims, topk_sims_rank, dst=target_rank)
|
||||||
|
torch.distributed.gather(neighbors_labels, retrieved_rank, dst=target_rank)
|
||||||
|
|
||||||
|
if self.global_rank == target_rank:
|
||||||
|
# Perform a second top-k on the k * global_size retrieved neighbors
|
||||||
|
topk_sims_rank = torch.cat(topk_sims_rank, dim=1)
|
||||||
|
retrieved_rank = torch.cat(retrieved_rank, dim=1)
|
||||||
|
results = self._get_knn_sims_and_labels(topk_sims_rank, retrieved_rank)
|
||||||
|
return results
|
||||||
|
return None
|
||||||
|
|
||||||
|
def compute_neighbors(self, features_rank):
|
||||||
|
for rank in range(self.global_size):
|
||||||
|
topk_sims, neighbors_labels = self._similarity_for_rank(features_rank, rank)
|
||||||
|
results = self._gather_all_knn_for_rank(topk_sims, neighbors_labels, rank)
|
||||||
|
if results is not None:
|
||||||
|
topk_sims_rank, neighbors_labels_rank = results
|
||||||
|
return topk_sims_rank, neighbors_labels_rank
|
||||||
|
|
||||||
|
def forward(self, features_rank):
|
||||||
|
"""
|
||||||
|
Compute the results on all values of `self.nb_knn` neighbors from the full `self.max_k`
|
||||||
|
"""
|
||||||
|
assert all(k <= self.max_k for k in self.nb_knn)
|
||||||
|
|
||||||
|
topk_sims, neighbors_labels = self.compute_neighbors(features_rank)
|
||||||
|
batch_size = neighbors_labels.shape[0]
|
||||||
|
topk_sims_transform = softmax(topk_sims / self.T, 1)
|
||||||
|
matmul = torch.mul(
|
||||||
|
one_hot(neighbors_labels, num_classes=self.num_classes),
|
||||||
|
topk_sims_transform.view(batch_size, -1, 1),
|
||||||
|
)
|
||||||
|
probas_for_k = {k: torch.sum(matmul[:, :k, :], 1) for k in self.nb_knn}
|
||||||
|
return probas_for_k
|
||||||
|
|
||||||
|
|
||||||
|
class DictKeysModule(torch.nn.Module):
    def __init__(self, keys):
        super().__init__()
        self.keys = keys

    def forward(self, features_dict, targets):
        for k in self.keys:
            features_dict = features_dict[k]
        return {"preds": features_dict, "target": targets}


def create_module_dict(*, module, n_per_class_list, n_tries, nb_knn, train_features, train_labels):
    modules = {}
    mapping = create_class_indices_mapping(train_labels)
    for npc in n_per_class_list:
        if npc < 0:  # Only one try needed when using the full data
            full_module = module(
                train_features=train_features,
                train_labels=train_labels,
                nb_knn=nb_knn,
            )
            modules["full"] = ModuleDictWithForward({"1": full_module})
            continue
        all_tries = {}
        for t in range(n_tries):
            final_indices = filter_train(mapping, npc, seed=t)
            k_list = list(set(nb_knn + [npc]))
            k_list = sorted([el for el in k_list if el <= npc])
            all_tries[str(t)] = module(
                train_features=train_features[final_indices],
                train_labels=train_labels[final_indices],
                nb_knn=k_list,
            )
        modules[f"{npc} per class"] = ModuleDictWithForward(all_tries)

    return ModuleDictWithForward(modules)


def filter_train(mapping, n_per_class, seed):
    torch.manual_seed(seed)
    final_indices = []
    for k in mapping.keys():
        index = torch.randperm(len(mapping[k]))[:n_per_class]
        final_indices.append(mapping[k][index])
    return torch.cat(final_indices).squeeze()


def create_class_indices_mapping(labels):
    unique_labels, inverse = torch.unique(labels, return_inverse=True)
    mapping = {unique_labels[i]: (inverse == i).nonzero() for i in range(len(unique_labels))}
    return mapping


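# Hedged illustration (not part of the original diff): how the two helpers above
# subsample a fixed number of training examples per class. Labels are toy values.
def _demo_filter_train():
    labels = torch.tensor([0, 1, 0, 1, 0, 1])
    mapping = create_class_indices_mapping(labels)  # class -> indices of its examples
    # Keep 2 examples per class, reproducibly via the seed
    kept = filter_train(mapping, n_per_class=2, seed=0)
    return kept  # 4 indices drawn from the 6 training examples, 2 per class

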
class ModuleDictWithForward(torch.nn.ModuleDict):
    def forward(self, *args, **kwargs):
        return {k: module(*args, **kwargs) for k, module in self._modules.items()}


def eval_knn(
    model,
    train_dataset,
    val_dataset,
    accuracy_averaging,
    nb_knn,
    temperature,
    batch_size,
    num_workers,
    gather_on_cpu,
    n_per_class_list=[-1],
    n_tries=1,
):
    model = ModelWithNormalize(model)

    logger.info("Extracting features for train set...")
    train_features, train_labels = extract_features(
        model, train_dataset, batch_size, num_workers, gather_on_cpu=gather_on_cpu
    )
    logger.info(f"Train features created, shape {train_features.shape}.")

    val_dataloader = make_data_loader(
        dataset=val_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        sampler_type=SamplerType.DISTRIBUTED,
        drop_last=False,
        shuffle=False,
        persistent_workers=True,
    )
    num_classes = train_labels.max() + 1
    metric_collection = build_topk_accuracy_metric(accuracy_averaging, num_classes=num_classes)

    device = torch.cuda.current_device()
    partial_module = partial(KnnModule, T=temperature, device=device, num_classes=num_classes)
    knn_module_dict = create_module_dict(
        module=partial_module,
        n_per_class_list=n_per_class_list,
        n_tries=n_tries,
        nb_knn=nb_knn,
        train_features=train_features,
        train_labels=train_labels,
    )
    postprocessors, metrics = {}, {}
    for n_per_class, knn_module in knn_module_dict.items():
        for t, knn_try in knn_module.items():
            postprocessors = {
                **postprocessors,
                **{(n_per_class, t, k): DictKeysModule([n_per_class, t, k]) for k in knn_try.nb_knn},
            }
            metrics = {**metrics, **{(n_per_class, t, k): metric_collection.clone() for k in knn_try.nb_knn}}
    model_with_knn = torch.nn.Sequential(model, knn_module_dict)

    # ============ evaluation ... ============
    logger.info("Start the k-NN classification.")
    _, results_dict = evaluate(model_with_knn, val_dataloader, postprocessors, metrics, device)

    # Averaging the results over the n tries for each value of n_per_class
    for n_per_class, knn_module in knn_module_dict.items():
        first_try = list(knn_module.keys())[0]
        k_list = knn_module[first_try].nb_knn
        for k in k_list:
            keys = results_dict[(n_per_class, first_try, k)].keys()  # keys are e.g. `top-1` and `top-5`
            results_dict[(n_per_class, k)] = {
                key: torch.mean(torch.stack([results_dict[(n_per_class, t, k)][key] for t in knn_module.keys()]))
                for key in keys
            }
            for t in knn_module.keys():
                del results_dict[(n_per_class, t, k)]

    return results_dict


def eval_knn_with_model(
    model,
    output_dir,
    train_dataset_str="ImageNet:split=TRAIN",
    val_dataset_str="ImageNet:split=VAL",
    nb_knn=(10, 20, 100, 200),
    temperature=0.07,
    autocast_dtype=torch.float,
    accuracy_averaging=AccuracyAveraging.MEAN_ACCURACY,
    transform=None,
    gather_on_cpu=False,
    batch_size=256,
    num_workers=5,
    n_per_class_list=[-1],
    n_tries=1,
):
    transform = transform or make_classification_eval_transform()

    train_dataset = make_dataset(
        dataset_str=train_dataset_str,
        transform=transform,
    )
    val_dataset = make_dataset(
        dataset_str=val_dataset_str,
        transform=transform,
    )

    with torch.cuda.amp.autocast(dtype=autocast_dtype):
        results_dict_knn = eval_knn(
            model=model,
            train_dataset=train_dataset,
            val_dataset=val_dataset,
            accuracy_averaging=accuracy_averaging,
            nb_knn=nb_knn,
            temperature=temperature,
            batch_size=batch_size,
            num_workers=num_workers,
            gather_on_cpu=gather_on_cpu,
            n_per_class_list=n_per_class_list,
            n_tries=n_tries,
        )

    results_dict = {}
    if distributed.is_main_process():
        for knn_ in results_dict_knn.keys():
            top1 = results_dict_knn[knn_]["top-1"].item() * 100.0
            top5 = results_dict_knn[knn_]["top-5"].item() * 100.0
            results_dict[f"{knn_} Top 1"] = top1
            results_dict[f"{knn_} Top 5"] = top5
            logger.info(f"{knn_} classifier result: Top1: {top1:.2f} Top5: {top5:.2f}")

        metrics_file_path = os.path.join(output_dir, "results_eval_knn.json")
        with open(metrics_file_path, "a") as f:
            for k, v in results_dict.items():
                f.write(json.dumps({k: v}) + "\n")

    if distributed.is_enabled():
        torch.distributed.barrier()
    return results_dict


def main(args):
    model, autocast_dtype = setup_and_build_model(args)
    eval_knn_with_model(
        model=model,
        output_dir=args.output_dir,
        train_dataset_str=args.train_dataset_str,
        val_dataset_str=args.val_dataset_str,
        nb_knn=args.nb_knn,
        temperature=args.temperature,
        autocast_dtype=autocast_dtype,
        accuracy_averaging=AccuracyAveraging.MEAN_ACCURACY,
        transform=None,
        gather_on_cpu=args.gather_on_cpu,
        batch_size=args.batch_size,
        num_workers=5,
        n_per_class_list=args.n_per_class_list,
        n_tries=args.n_tries,
    )
    return 0


if __name__ == "__main__":
    description = "DINOv2 k-NN evaluation"
    args_parser = get_args_parser(description=description)
    args = args_parser.parse_args()
    sys.exit(main(args))
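

# Hedged usage sketch (not part of the original diff): evaluating a frozen feature
# extractor with the k-NN classifier above. `my_model` is a placeholder; keyword
# arguments mirror the defaults of eval_knn_with_model as defined in this file.
def _demo_eval_knn_usage(my_model, output_dir="./knn_eval"):
    return eval_knn_with_model(
        model=my_model,
        output_dir=output_dir,
        nb_knn=(10, 20, 100, 200),
        temperature=0.07,
        gather_on_cpu=True,  # keep gathered train features on CPU if GPU memory is tight
    )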
@@ -0,0 +1,625 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import argparse
from functools import partial
import json
import logging
import os
import sys
from typing import List, Optional

import numpy as np
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel
from fvcore.common.checkpoint import Checkpointer, PeriodicCheckpointer

from dinov2.data import SamplerType, make_data_loader, make_dataset
from dinov2.data.transforms import make_classification_eval_transform, make_classification_train_transform
import dinov2.distributed as distributed
from dinov2.eval.metrics import MetricType, build_metric
from dinov2.eval.setup import get_args_parser as get_setup_args_parser
from dinov2.eval.setup import setup_and_build_model
from dinov2.eval.utils import ModelWithIntermediateLayers, evaluate
from dinov2.logging import MetricLogger


logger = logging.getLogger("dinov2")


def get_args_parser(
    description: Optional[str] = None,
    parents: Optional[List[argparse.ArgumentParser]] = None,
    add_help: bool = True,
):
    parents = parents or []
    setup_args_parser = get_setup_args_parser(parents=parents, add_help=False)
    parents = [setup_args_parser]
    parser = argparse.ArgumentParser(
        description=description,
        parents=parents,
        add_help=add_help,
    )
    parser.add_argument(
        "--train-dataset",
        dest="train_dataset_str",
        type=str,
        help="Training dataset",
    )
    parser.add_argument(
        "--val-dataset",
        dest="val_dataset_str",
        type=str,
        help="Validation dataset",
    )
    parser.add_argument(
        "--test-datasets",
        dest="test_dataset_strs",
        type=str,
        nargs="+",
        help="Test datasets, none to reuse the validation dataset",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        help="Number of training epochs",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        help="Batch size (per GPU)",
    )
    parser.add_argument(
        "--num-workers",
        type=int,
        help="Number of workers",
    )
    parser.add_argument(
        "--epoch-length",
        type=int,
        help="Length of an epoch in number of iterations",
    )
    parser.add_argument(
        "--save-checkpoint-frequency",
        type=int,
        help="Number of epochs between two named checkpoint saves.",
    )
    parser.add_argument(
        "--eval-period-iterations",
        type=int,
        help="Number of iterations between two evaluations.",
    )
    parser.add_argument(
        "--learning-rates",
        nargs="+",
        type=float,
        help="Learning rates to grid search.",
    )
    parser.add_argument(
        "--no-resume",
        action="store_true",
        help="Whether to avoid resuming from existing checkpoints",
    )
    parser.add_argument(
        "--val-metric-type",
        type=MetricType,
        choices=list(MetricType),
        help="Validation metric",
    )
    parser.add_argument(
        "--test-metric-types",
        type=MetricType,
        choices=list(MetricType),
        nargs="+",
        help="Evaluation metric",
    )
    parser.add_argument(
        "--classifier-fpath",
        type=str,
        help="Path to a file containing pretrained linear classifiers",
    )
    parser.add_argument(
        "--val-class-mapping-fpath",
        type=str,
        help="Path to a file containing a mapping to adjust classifier outputs",
    )
    parser.add_argument(
        "--test-class-mapping-fpaths",
        nargs="+",
        type=str,
        help="Path to a file containing a mapping to adjust classifier outputs",
    )
    parser.set_defaults(
        train_dataset_str="ImageNet:split=TRAIN",
        val_dataset_str="ImageNet:split=VAL",
        test_dataset_strs=None,
        epochs=10,
        batch_size=128,
        num_workers=8,
        epoch_length=1250,
        save_checkpoint_frequency=20,
        eval_period_iterations=1250,
        learning_rates=[1e-5, 2e-5, 5e-5, 1e-4, 2e-4, 5e-4, 1e-3, 2e-3, 5e-3, 1e-2, 2e-2, 5e-2, 0.1],
        val_metric_type=MetricType.MEAN_ACCURACY,
        test_metric_types=None,
        classifier_fpath=None,
        val_class_mapping_fpath=None,
        test_class_mapping_fpaths=[None],
    )
    return parser


def has_ddp_wrapper(m: nn.Module) -> bool:
    return isinstance(m, DistributedDataParallel)


def remove_ddp_wrapper(m: nn.Module) -> nn.Module:
    return m.module if has_ddp_wrapper(m) else m


def _pad_and_collate(batch):
    maxlen = max(len(targets) for image, targets in batch)
    padded_batch = [
        (image, np.pad(targets, (0, maxlen - len(targets)), constant_values=-1)) for image, targets in batch
    ]
    return torch.utils.data.default_collate(padded_batch)


def create_linear_input(x_tokens_list, use_n_blocks, use_avgpool):
    intermediate_output = x_tokens_list[-use_n_blocks:]
    output = torch.cat([class_token for _, class_token in intermediate_output], dim=-1)
    if use_avgpool:
        output = torch.cat(
            (
                output,
                torch.mean(intermediate_output[-1][0], dim=1),  # patch tokens
            ),
            dim=-1,
        )
        output = output.reshape(output.shape[0], -1)
    return output.float()


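# Hedged illustration (not part of the original diff): the structure expected for
# `x_tokens_list` above. Each entry is a (patch_tokens, class_token) pair from one
# of the last transformer blocks; the shapes below are made up for a toy model.
def _demo_create_linear_input():
    B, N, D = 2, 16, 8  # batch, patch tokens, embedding dim (toy values)
    x_tokens_list = [(torch.rand(B, N, D), torch.rand(B, D)) for _ in range(4)]
    out = create_linear_input(x_tokens_list, use_n_blocks=4, use_avgpool=True)
    # 4 concatenated class tokens (4 * D) plus the avg-pooled patch tokens (D)
    assert out.shape == (B, 5 * D)
    return out

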
class LinearClassifier(nn.Module):
    """Linear layer to train on top of frozen features"""

    def __init__(self, out_dim, use_n_blocks, use_avgpool, num_classes=1000):
        super().__init__()
        self.out_dim = out_dim
        self.use_n_blocks = use_n_blocks
        self.use_avgpool = use_avgpool
        self.num_classes = num_classes
        self.linear = nn.Linear(out_dim, num_classes)
        self.linear.weight.data.normal_(mean=0.0, std=0.01)
        self.linear.bias.data.zero_()

    def forward(self, x_tokens_list):
        output = create_linear_input(x_tokens_list, self.use_n_blocks, self.use_avgpool)
        return self.linear(output)


class AllClassifiers(nn.Module):
    def __init__(self, classifiers_dict):
        super().__init__()
        self.classifiers_dict = nn.ModuleDict()
        self.classifiers_dict.update(classifiers_dict)

    def forward(self, inputs):
        return {k: v.forward(inputs) for k, v in self.classifiers_dict.items()}

    def __len__(self):
        return len(self.classifiers_dict)


class LinearPostprocessor(nn.Module):
    def __init__(self, linear_classifier, class_mapping=None):
        super().__init__()
        self.linear_classifier = linear_classifier
        self.register_buffer("class_mapping", None if class_mapping is None else torch.LongTensor(class_mapping))

    def forward(self, samples, targets):
        preds = self.linear_classifier(samples)
        return {
            "preds": preds[:, self.class_mapping] if self.class_mapping is not None else preds,
            "target": targets,
        }


def scale_lr(learning_rates, batch_size):
    return learning_rates * (batch_size * distributed.get_global_size()) / 256.0


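# Hedged arithmetic check (not part of the original diff) of the linear scaling
# rule above: with a per-GPU batch size of 128 on 4 GPUs, a base learning rate of
# 1e-3 becomes 1e-3 * (128 * 4) / 256 = 2e-3. All values are illustrative.
def _demo_scale_lr():
    base_lr, batch_size_per_gpu, world_size = 1e-3, 128, 4
    return base_lr * (batch_size_per_gpu * world_size) / 256.0  # == 2e-3

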
def setup_linear_classifiers(sample_output, n_last_blocks_list, learning_rates, batch_size, num_classes=1000):
    linear_classifiers_dict = nn.ModuleDict()
    optim_param_groups = []
    for n in n_last_blocks_list:
        for avgpool in [False, True]:
            for _lr in learning_rates:
                lr = scale_lr(_lr, batch_size)
                out_dim = create_linear_input(sample_output, use_n_blocks=n, use_avgpool=avgpool).shape[1]
                linear_classifier = LinearClassifier(
                    out_dim, use_n_blocks=n, use_avgpool=avgpool, num_classes=num_classes
                )
                linear_classifier = linear_classifier.cuda()
                linear_classifiers_dict[
                    f"classifier_{n}_blocks_avgpool_{avgpool}_lr_{lr:.5f}".replace(".", "_")
                ] = linear_classifier
                optim_param_groups.append({"params": linear_classifier.parameters(), "lr": lr})

    linear_classifiers = AllClassifiers(linear_classifiers_dict)
    if distributed.is_enabled():
        linear_classifiers = nn.parallel.DistributedDataParallel(linear_classifiers)

    return linear_classifiers, optim_param_groups


@torch.no_grad()
def evaluate_linear_classifiers(
    feature_model,
    linear_classifiers,
    data_loader,
    metric_type,
    metrics_file_path,
    training_num_classes,
    iteration,
    prefixstring="",
    class_mapping=None,
    best_classifier_on_val=None,
):
    logger.info("running validation !")

    num_classes = len(class_mapping) if class_mapping is not None else training_num_classes
    metric = build_metric(metric_type, num_classes=num_classes)
    postprocessors = {k: LinearPostprocessor(v, class_mapping) for k, v in linear_classifiers.classifiers_dict.items()}
    metrics = {k: metric.clone() for k in linear_classifiers.classifiers_dict}

    _, results_dict_temp = evaluate(
        feature_model,
        data_loader,
        postprocessors,
        metrics,
        torch.cuda.current_device(),
    )

    logger.info("")
    results_dict = {}
    max_accuracy = 0
    best_classifier = ""
    for i, (classifier_string, metric) in enumerate(results_dict_temp.items()):
        logger.info(f"{prefixstring} -- Classifier: {classifier_string} * {metric}")
        if (
            best_classifier_on_val is None and metric["top-1"].item() > max_accuracy
        ) or classifier_string == best_classifier_on_val:
            max_accuracy = metric["top-1"].item()
            best_classifier = classifier_string

    results_dict["best_classifier"] = {"name": best_classifier, "accuracy": max_accuracy}

    logger.info(f"best classifier: {results_dict['best_classifier']}")

    if distributed.is_main_process():
        with open(metrics_file_path, "a") as f:
            f.write(f"iter: {iteration}\n")
            for k, v in results_dict.items():
                f.write(json.dumps({k: v}) + "\n")
            f.write("\n")

    return results_dict


def eval_linear(
    *,
    feature_model,
    linear_classifiers,
    train_data_loader,
    val_data_loader,
    metrics_file_path,
    optimizer,
    scheduler,
    output_dir,
    max_iter,
    checkpoint_period,  # In number of iter, creates a new file every period
    running_checkpoint_period,  # Period to update main checkpoint file
    eval_period,
    metric_type,
    training_num_classes,
    resume=True,
    classifier_fpath=None,
    val_class_mapping=None,
):
    checkpointer = Checkpointer(linear_classifiers, output_dir, optimizer=optimizer, scheduler=scheduler)
    start_iter = checkpointer.resume_or_load(classifier_fpath or "", resume=resume).get("iteration", -1) + 1

    periodic_checkpointer = PeriodicCheckpointer(checkpointer, checkpoint_period, max_iter=max_iter)
    iteration = start_iter
    logger.info("Starting training from iteration {}".format(start_iter))
    metric_logger = MetricLogger(delimiter=" ")
    header = "Training"

    for data, labels in metric_logger.log_every(
        train_data_loader,
        10,
        header,
        max_iter,
        start_iter,
    ):
        data = data.cuda(non_blocking=True)
        labels = labels.cuda(non_blocking=True)

        features = feature_model(data)
        outputs = linear_classifiers(features)

        losses = {f"loss_{k}": nn.CrossEntropyLoss()(v, labels) for k, v in outputs.items()}
        loss = sum(losses.values())

        # compute the gradients
        optimizer.zero_grad()
        loss.backward()

        # step
        optimizer.step()
        scheduler.step()

        # log
        if iteration % 10 == 0:
            torch.cuda.synchronize()
            metric_logger.update(loss=loss.item())
            metric_logger.update(lr=optimizer.param_groups[0]["lr"])
            print("lr", optimizer.param_groups[0]["lr"])

        if iteration - start_iter > 5:
            if iteration % running_checkpoint_period == 0:
                torch.cuda.synchronize()
                if distributed.is_main_process():
                    logger.info("Checkpointing running_checkpoint")
                    periodic_checkpointer.save("running_checkpoint_linear_eval", iteration=iteration)
                torch.cuda.synchronize()
        periodic_checkpointer.step(iteration)

        if eval_period > 0 and (iteration + 1) % eval_period == 0 and iteration != max_iter - 1:
            _ = evaluate_linear_classifiers(
                feature_model=feature_model,
                linear_classifiers=remove_ddp_wrapper(linear_classifiers),
                data_loader=val_data_loader,
                metrics_file_path=metrics_file_path,
                prefixstring=f"ITER: {iteration}",
                metric_type=metric_type,
                training_num_classes=training_num_classes,
                iteration=iteration,
                class_mapping=val_class_mapping,
            )
            torch.cuda.synchronize()

        iteration = iteration + 1

    val_results_dict = evaluate_linear_classifiers(
        feature_model=feature_model,
        linear_classifiers=remove_ddp_wrapper(linear_classifiers),
        data_loader=val_data_loader,
        metrics_file_path=metrics_file_path,
        metric_type=metric_type,
        training_num_classes=training_num_classes,
        iteration=iteration,
        class_mapping=val_class_mapping,
    )
    return val_results_dict, feature_model, linear_classifiers, iteration


def make_eval_data_loader(test_dataset_str, batch_size, num_workers, metric_type):
    test_dataset = make_dataset(
        dataset_str=test_dataset_str,
        transform=make_classification_eval_transform(),
    )
    test_data_loader = make_data_loader(
        dataset=test_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        sampler_type=SamplerType.DISTRIBUTED,
        drop_last=False,
        shuffle=False,
        persistent_workers=False,
        collate_fn=_pad_and_collate if metric_type == MetricType.IMAGENET_REAL_ACCURACY else None,
    )
    return test_data_loader


def test_on_datasets(
    feature_model,
    linear_classifiers,
    test_dataset_strs,
    batch_size,
    num_workers,
    test_metric_types,
    metrics_file_path,
    training_num_classes,
    iteration,
    best_classifier_on_val,
    prefixstring="",
    test_class_mappings=[None],
):
    results_dict = {}
    for test_dataset_str, class_mapping, metric_type in zip(test_dataset_strs, test_class_mappings, test_metric_types):
        logger.info(f"Testing on {test_dataset_str}")
        test_data_loader = make_eval_data_loader(test_dataset_str, batch_size, num_workers, metric_type)
        dataset_results_dict = evaluate_linear_classifiers(
            feature_model,
            remove_ddp_wrapper(linear_classifiers),
            test_data_loader,
            metric_type,
            metrics_file_path,
            training_num_classes,
            iteration,
            prefixstring="",
            class_mapping=class_mapping,
            best_classifier_on_val=best_classifier_on_val,
        )
        results_dict[f"{test_dataset_str}_accuracy"] = 100.0 * dataset_results_dict["best_classifier"]["accuracy"]
    return results_dict


def run_eval_linear(
    model,
    output_dir,
    train_dataset_str,
    val_dataset_str,
    batch_size,
    epochs,
    epoch_length,
    num_workers,
    save_checkpoint_frequency,
    eval_period_iterations,
    learning_rates,
    autocast_dtype,
    test_dataset_strs=None,
    resume=True,
    classifier_fpath=None,
    val_class_mapping_fpath=None,
    test_class_mapping_fpaths=[None],
    val_metric_type=MetricType.MEAN_ACCURACY,
    test_metric_types=None,
):
    seed = 0

    if test_dataset_strs is None:
        test_dataset_strs = [val_dataset_str]
    if test_metric_types is None:
        test_metric_types = [val_metric_type] * len(test_dataset_strs)
    else:
        assert len(test_metric_types) == len(test_dataset_strs)
    assert len(test_dataset_strs) == len(test_class_mapping_fpaths)

    train_transform = make_classification_train_transform()
    train_dataset = make_dataset(
        dataset_str=train_dataset_str,
        transform=train_transform,
    )
    training_num_classes = len(torch.unique(torch.Tensor(train_dataset.get_targets().astype(int))))
    sampler_type = SamplerType.SHARDED_INFINITE
    # sampler_type = SamplerType.INFINITE

    n_last_blocks_list = [1, 4]
    n_last_blocks = max(n_last_blocks_list)
    autocast_ctx = partial(torch.cuda.amp.autocast, enabled=True, dtype=autocast_dtype)
    feature_model = ModelWithIntermediateLayers(model, n_last_blocks, autocast_ctx)
    sample_output = feature_model(train_dataset[0][0].unsqueeze(0).cuda())

    linear_classifiers, optim_param_groups = setup_linear_classifiers(
        sample_output,
        n_last_blocks_list,
        learning_rates,
        batch_size,
        training_num_classes,
    )

    optimizer = torch.optim.SGD(optim_param_groups, momentum=0.9, weight_decay=0)
    max_iter = epochs * epoch_length
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, max_iter, eta_min=0)
    checkpointer = Checkpointer(linear_classifiers, output_dir, optimizer=optimizer, scheduler=scheduler)
    start_iter = checkpointer.resume_or_load(classifier_fpath or "", resume=resume).get("iteration", -1) + 1
    train_data_loader = make_data_loader(
        dataset=train_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=True,
        seed=seed,
        sampler_type=sampler_type,
        sampler_advance=start_iter,
        drop_last=True,
        persistent_workers=True,
    )
    val_data_loader = make_eval_data_loader(val_dataset_str, batch_size, num_workers, val_metric_type)

    checkpoint_period = save_checkpoint_frequency * epoch_length

    if val_class_mapping_fpath is not None:
        logger.info(f"Using class mapping from {val_class_mapping_fpath}")
        val_class_mapping = np.load(val_class_mapping_fpath)
    else:
        val_class_mapping = None

    test_class_mappings = []
    for class_mapping_fpath in test_class_mapping_fpaths:
        if class_mapping_fpath is not None and class_mapping_fpath != "None":
            logger.info(f"Using class mapping from {class_mapping_fpath}")
            class_mapping = np.load(class_mapping_fpath)
        else:
            class_mapping = None
        test_class_mappings.append(class_mapping)

    metrics_file_path = os.path.join(output_dir, "results_eval_linear.json")
    val_results_dict, feature_model, linear_classifiers, iteration = eval_linear(
        feature_model=feature_model,
        linear_classifiers=linear_classifiers,
        train_data_loader=train_data_loader,
        val_data_loader=val_data_loader,
        metrics_file_path=metrics_file_path,
        optimizer=optimizer,
        scheduler=scheduler,
        output_dir=output_dir,
        max_iter=max_iter,
        checkpoint_period=checkpoint_period,
        running_checkpoint_period=epoch_length,
        eval_period=eval_period_iterations,
        metric_type=val_metric_type,
        training_num_classes=training_num_classes,
        resume=resume,
        val_class_mapping=val_class_mapping,
        classifier_fpath=classifier_fpath,
    )
    results_dict = {}
    if len(test_dataset_strs) > 1 or test_dataset_strs[0] != val_dataset_str:
        results_dict = test_on_datasets(
            feature_model,
            linear_classifiers,
            test_dataset_strs,
            batch_size,
            0,  # num_workers,
            test_metric_types,
            metrics_file_path,
            training_num_classes,
            iteration,
            val_results_dict["best_classifier"]["name"],
            prefixstring="",
            test_class_mappings=test_class_mappings,
        )
    results_dict["best_classifier"] = val_results_dict["best_classifier"]["name"]
    results_dict[f"{val_dataset_str}_accuracy"] = 100.0 * val_results_dict["best_classifier"]["accuracy"]
    logger.info("Test Results Dict " + str(results_dict))

    return results_dict


def main(args):
    model, autocast_dtype = setup_and_build_model(args)
    run_eval_linear(
        model=model,
        output_dir=args.output_dir,
        train_dataset_str=args.train_dataset_str,
        val_dataset_str=args.val_dataset_str,
        test_dataset_strs=args.test_dataset_strs,
        batch_size=args.batch_size,
        epochs=args.epochs,
        epoch_length=args.epoch_length,
        num_workers=args.num_workers,
        save_checkpoint_frequency=args.save_checkpoint_frequency,
        eval_period_iterations=args.eval_period_iterations,
        learning_rates=args.learning_rates,
        autocast_dtype=autocast_dtype,
        resume=not args.no_resume,
        classifier_fpath=args.classifier_fpath,
        val_metric_type=args.val_metric_type,
        test_metric_types=args.test_metric_types,
        val_class_mapping_fpath=args.val_class_mapping_fpath,
        test_class_mapping_fpaths=args.test_class_mapping_fpaths,
    )
    return 0


if __name__ == "__main__":
    description = "DINOv2 linear evaluation"
    args_parser = get_args_parser(description=description)
    args = args_parser.parse_args()
    sys.exit(main(args))
@@ -0,0 +1,444 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import argparse
import gc
import logging
import sys
import time
from typing import List, Optional

from cuml.linear_model import LogisticRegression
import torch
import torch.backends.cudnn as cudnn
import torch.distributed
from torch import nn
from torch.utils.data import TensorDataset
from torchmetrics import MetricTracker

from dinov2.data import make_dataset
from dinov2.data.transforms import make_classification_eval_transform
from dinov2.distributed import get_global_rank, get_global_size
from dinov2.eval.metrics import MetricType, build_metric
from dinov2.eval.setup import get_args_parser as get_setup_args_parser
from dinov2.eval.setup import setup_and_build_model
from dinov2.eval.utils import evaluate, extract_features
from dinov2.utils.dtype import as_torch_dtype


logger = logging.getLogger("dinov2")

DEFAULT_MAX_ITER = 1_000
C_POWER_RANGE = torch.linspace(-6, 5, 45)
_CPU_DEVICE = torch.device("cpu")


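# Hedged illustration (not part of the original diff): the regularization grid
# implied by C_POWER_RANGE. 10**C_POWER_RANGE yields 45 log-spaced values of C
# spanning roughly 1e-6 to 1e5, which sweep_C_values below distributes over ranks.
def _demo_c_grid():
    all_c = 10 ** C_POWER_RANGE
    return all_c[0].item(), all_c[-1].item()  # approximately (1e-06, 1e+05)

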
def get_args_parser(
    description: Optional[str] = None,
    parents: Optional[List[argparse.ArgumentParser]] = None,
    add_help: bool = True,
):
    parents = parents or []
    setup_args_parser = get_setup_args_parser(parents=parents, add_help=False)
    parents = [setup_args_parser]
    parser = argparse.ArgumentParser(
        description=description,
        parents=parents,
        add_help=add_help,
    )
    parser.add_argument(
        "--train-dataset",
        dest="train_dataset_str",
        type=str,
        help="Training dataset",
    )
    parser.add_argument(
        "--val-dataset",
        dest="val_dataset_str",
        type=str,
        help="Validation dataset",
    )
    parser.add_argument(
        "--finetune-dataset-str",
        dest="finetune_dataset_str",
        type=str,
        help="Fine-tuning dataset",
    )
    parser.add_argument(
        "--finetune-on-val",
        action="store_true",
        help="If there is no finetune dataset, whether to choose the "
        "hyperparameters on the val set instead of 10%% of the train dataset",
    )
    parser.add_argument(
        "--metric-type",
        type=MetricType,
        choices=list(MetricType),
        help="Metric type",
    )
    parser.add_argument(
        "--train-features-device",
        type=str,
        help="Device to gather train features (cpu, cuda, cuda:0, etc.), default: %(default)s",
    )
    parser.add_argument(
        "--train-dtype",
        type=str,
        help="Data type to convert the train features to (default: %(default)s)",
    )
    parser.add_argument(
        "--max-train-iters",
        type=int,
        help="Maximum number of train iterations (default: %(default)s)",
    )
    parser.set_defaults(
        train_dataset_str="ImageNet:split=TRAIN",
        val_dataset_str="ImageNet:split=VAL",
        finetune_dataset_str=None,
        metric_type=MetricType.MEAN_ACCURACY,
        train_features_device="cpu",
        train_dtype="float64",
        max_train_iters=DEFAULT_MAX_ITER,
        finetune_on_val=False,
    )
    return parser


class LogRegModule(nn.Module):
    def __init__(
        self,
        C,
        max_iter=DEFAULT_MAX_ITER,
        dtype=torch.float64,
        device=_CPU_DEVICE,
    ):
        super().__init__()
        self.dtype = dtype
        self.device = device
        self.estimator = LogisticRegression(
            penalty="l2",
            C=C,
            max_iter=max_iter,
            output_type="numpy",
            tol=1e-12,
            linesearch_max_iter=50,
        )

    def forward(self, samples, targets):
        samples_device = samples.device
        samples = samples.to(dtype=self.dtype, device=self.device)
        if self.device == _CPU_DEVICE:
            samples = samples.numpy()
        probas = self.estimator.predict_proba(samples)
        return {"preds": torch.from_numpy(probas).to(samples_device), "target": targets}

    def fit(self, train_features, train_labels):
        train_features = train_features.to(dtype=self.dtype, device=self.device)
        train_labels = train_labels.to(dtype=self.dtype, device=self.device)
        if self.device == _CPU_DEVICE:
            # both cuML and sklearn only work with numpy arrays on CPU
            train_features = train_features.numpy()
            train_labels = train_labels.numpy()
        self.estimator.fit(train_features, train_labels)


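# Hedged usage sketch (not part of the original diff): fitting the cuML-backed
# module above on toy features kept on CPU. Requires a cuML-capable GPU
# environment; all shapes and values below are made up.
def _demo_logreg_fit():
    feats = torch.randn(100, 16, dtype=torch.float64)
    labels = torch.randint(0, 3, (100,))
    module = LogRegModule(C=1.0, max_iter=100)
    module.fit(feats, labels)
    return module(feats, labels)["preds"].shape  # (100, 3) class probabilities

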
def evaluate_model(*, logreg_model, logreg_metric, test_data_loader, device):
    postprocessors = {"metrics": logreg_model}
    metrics = {"metrics": logreg_metric}
    return evaluate(nn.Identity(), test_data_loader, postprocessors, metrics, device)


def train_for_C(*, C, max_iter, train_features, train_labels, dtype=torch.float64, device=_CPU_DEVICE):
    logreg_model = LogRegModule(C, max_iter=max_iter, dtype=dtype, device=device)
    logreg_model.fit(train_features, train_labels)
    return logreg_model


def train_and_evaluate(
    *,
    C,
    max_iter,
    train_features,
    train_labels,
    logreg_metric,
    test_data_loader,
    train_dtype=torch.float64,
    train_features_device,
    eval_device,
):
    logreg_model = train_for_C(
        C=C,
        max_iter=max_iter,
        train_features=train_features,
        train_labels=train_labels,
        dtype=train_dtype,
        device=train_features_device,
    )
    return evaluate_model(
        logreg_model=logreg_model,
        logreg_metric=logreg_metric,
        test_data_loader=test_data_loader,
        device=eval_device,
    )


def sweep_C_values(
    *,
    train_features,
    train_labels,
    test_data_loader,
    metric_type,
    num_classes,
    train_dtype=torch.float64,
    train_features_device=_CPU_DEVICE,
    max_train_iters=DEFAULT_MAX_ITER,
):
    if metric_type == MetricType.PER_CLASS_ACCURACY:
        # If we want to output per-class accuracy, we select the hyperparameters with mean per-class accuracy
        metric_type = MetricType.MEAN_PER_CLASS_ACCURACY
    logreg_metric = build_metric(metric_type, num_classes=num_classes)
    metric_tracker = MetricTracker(logreg_metric, maximize=True)
    ALL_C = 10**C_POWER_RANGE
    logreg_models = {}

    train_features = train_features.to(dtype=train_dtype, device=train_features_device)
    train_labels = train_labels.to(device=train_features_device)

    for i in range(get_global_rank(), len(ALL_C), get_global_size()):
        C = ALL_C[i].item()
        logger.info(
            f"Training for C = {C:.5f}, dtype={train_dtype}, "
            f"features: {train_features.shape}, {train_features.dtype}, "
            f"labels: {train_labels.shape}, {train_labels.dtype}"
        )
        logreg_models[C] = train_for_C(
            C=C,
            max_iter=max_train_iters,
            train_features=train_features,
            train_labels=train_labels,
            dtype=train_dtype,
            device=train_features_device,
        )

    gather_list = [None for _ in range(get_global_size())]
    torch.distributed.all_gather_object(gather_list, logreg_models)

    logreg_models_gathered = {}
    for logreg_dict in gather_list:
        logreg_models_gathered.update(logreg_dict)

    for i in range(len(ALL_C)):
        metric_tracker.increment()
        C = ALL_C[i].item()
        evals = evaluate_model(
            logreg_model=logreg_models_gathered[C],
            logreg_metric=metric_tracker,
            test_data_loader=test_data_loader,
            device=torch.cuda.current_device(),
        )
        logger.info(f"Trained for C = {C:.5f}, accuracies = {evals}")

        best_stats, which_epoch = metric_tracker.best_metric(return_step=True)
        best_stats_100 = {k: 100.0 * v for k, v in best_stats.items()}
        if which_epoch["top-1"] == i:
            best_C = C
    logger.info(f"Sweep best {best_stats_100}, best C = {best_C:.6f}")

    return best_stats, best_C


def eval_log_regression(
    *,
    model,
    train_dataset,
    val_dataset,
    finetune_dataset,
    metric_type,
    batch_size,
    num_workers,
    finetune_on_val=False,
    train_dtype=torch.float64,
    train_features_device=_CPU_DEVICE,
    max_train_iters=DEFAULT_MAX_ITER,
):
    """
    Implements the "standard" process for log regression evaluation:
    The value of C is chosen by training on train_dataset and evaluating on
    finetune_dataset. Then, the final model is trained on a concatenation of
    train_dataset and finetune_dataset, and is evaluated on val_dataset.
    If there is no finetune_dataset, the value of C is the one that yields
    the best results on a random 10% subset of the train dataset.
    """

    start = time.time()

    train_features, train_labels = extract_features(
        model, train_dataset, batch_size, num_workers, gather_on_cpu=(train_features_device == _CPU_DEVICE)
    )
    val_features, val_labels = extract_features(
        model, val_dataset, batch_size, num_workers, gather_on_cpu=(train_features_device == _CPU_DEVICE)
    )
    val_data_loader = torch.utils.data.DataLoader(
        TensorDataset(val_features, val_labels),
        batch_size=batch_size,
        drop_last=False,
        num_workers=0,
        persistent_workers=False,
    )

    if finetune_dataset is None and finetune_on_val:
        logger.info("Choosing hyperparameters on the val dataset")
        finetune_features, finetune_labels = val_features, val_labels
    elif finetune_dataset is None and not finetune_on_val:
        logger.info("Choosing hyperparameters on 10% of the train dataset")
        torch.manual_seed(0)
        indices = torch.randperm(len(train_features), device=train_features.device)
        finetune_index = indices[: len(train_features) // 10]
        train_index = indices[len(train_features) // 10 :]
        finetune_features, finetune_labels = train_features[finetune_index], train_labels[finetune_index]
        train_features, train_labels = train_features[train_index], train_labels[train_index]
    else:
        logger.info("Choosing hyperparameters on the finetune dataset")
        finetune_features, finetune_labels = extract_features(
            model, finetune_dataset, batch_size, num_workers, gather_on_cpu=(train_features_device == _CPU_DEVICE)
        )
    # release the model - free GPU memory
    del model
    gc.collect()
    torch.cuda.empty_cache()
    finetune_data_loader = torch.utils.data.DataLoader(
        TensorDataset(finetune_features, finetune_labels),
        batch_size=batch_size,
        drop_last=False,
    )

    if len(train_labels.shape) > 1:
        num_classes = train_labels.shape[1]
    else:
        num_classes = train_labels.max() + 1

    logger.info("Using cuML for logistic regression")

    best_stats, best_C = sweep_C_values(
        train_features=train_features,
        train_labels=train_labels,
        test_data_loader=finetune_data_loader,
        metric_type=metric_type,
        num_classes=num_classes,
        train_dtype=train_dtype,
        train_features_device=train_features_device,
        max_train_iters=max_train_iters,
    )

    if not finetune_on_val:
        logger.info("Best parameter found, concatenating features")
        train_features = torch.cat((train_features, finetune_features))
        train_labels = torch.cat((train_labels, finetune_labels))

    logger.info("Training final model")
    logreg_metric = build_metric(metric_type, num_classes=num_classes)
    evals = train_and_evaluate(
        C=best_C,
        max_iter=max_train_iters,
        train_features=train_features,
        train_labels=train_labels,
        logreg_metric=logreg_metric.clone(),
        test_data_loader=val_data_loader,
        eval_device=torch.cuda.current_device(),
        train_dtype=train_dtype,
        train_features_device=train_features_device,
    )

    best_stats = evals[1]["metrics"]

    best_stats["best_C"] = best_C

    logger.info(f"Log regression evaluation done in {int(time.time() - start)}s")
    return best_stats


def eval_log_regression_with_model(
    model,
    train_dataset_str="ImageNet:split=TRAIN",
    val_dataset_str="ImageNet:split=VAL",
    finetune_dataset_str=None,
    autocast_dtype=torch.float,
    finetune_on_val=False,
    metric_type=MetricType.MEAN_ACCURACY,
    train_dtype=torch.float64,
    train_features_device=_CPU_DEVICE,
    max_train_iters=DEFAULT_MAX_ITER,
):
    cudnn.benchmark = True

    transform = make_classification_eval_transform(resize_size=224)
    target_transform = None

    train_dataset = make_dataset(dataset_str=train_dataset_str, transform=transform, target_transform=target_transform)
    val_dataset = make_dataset(dataset_str=val_dataset_str, transform=transform, target_transform=target_transform)
    if finetune_dataset_str is not None:
        finetune_dataset = make_dataset(
            dataset_str=finetune_dataset_str, transform=transform, target_transform=target_transform
        )
    else:
        finetune_dataset = None

    with torch.cuda.amp.autocast(dtype=autocast_dtype):
        results_dict_logreg = eval_log_regression(
            model=model,
            train_dataset=train_dataset,
            val_dataset=val_dataset,
            finetune_dataset=finetune_dataset,
            metric_type=metric_type,
            batch_size=256,
            num_workers=0,  # 5,
            finetune_on_val=finetune_on_val,
            train_dtype=train_dtype,
            train_features_device=train_features_device,
            max_train_iters=max_train_iters,
        )

    results_dict = {
        "top-1": results_dict_logreg["top-1"].cpu().numpy() * 100.0,
        "top-5": results_dict_logreg.get("top-5", torch.tensor(0.0)).cpu().numpy() * 100.0,
        "best_C": results_dict_logreg["best_C"],
    }
    logger.info(
        "\n".join(
            [
                "Training of the supervised logistic regression on frozen features completed.\n"
                "Top-1 test accuracy: {acc:.1f}".format(acc=results_dict["top-1"]),
                "Top-5 test accuracy: {acc:.1f}".format(acc=results_dict["top-5"]),
                "obtained for C = {c:.6f}".format(c=results_dict["best_C"]),
            ]
        )
    )

    torch.distributed.barrier()
    return results_dict


def main(args):
    model, autocast_dtype = setup_and_build_model(args)
    eval_log_regression_with_model(
        model=model,
        train_dataset_str=args.train_dataset_str,
        val_dataset_str=args.val_dataset_str,
        finetune_dataset_str=args.finetune_dataset_str,
        autocast_dtype=autocast_dtype,
        finetune_on_val=args.finetune_on_val,
        metric_type=args.metric_type,
        train_dtype=as_torch_dtype(args.train_dtype),
        train_features_device=torch.device(args.train_features_device),
        max_train_iters=args.max_train_iters,
    )
    return 0


if __name__ == "__main__":
    description = "DINOv2 logistic regression evaluation"
    args_parser = get_args_parser(description=description)
    args = args_parser.parse_args()
    sys.exit(main(args))
@@ -0,0 +1,113 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from enum import Enum
import logging
from typing import Any, Dict, Optional

import torch
from torch import Tensor
from torchmetrics import Metric, MetricCollection
from torchmetrics.classification import MulticlassAccuracy
from torchmetrics.utilities.data import dim_zero_cat, select_topk


logger = logging.getLogger("dinov2")


class MetricType(Enum):
    MEAN_ACCURACY = "mean_accuracy"
    MEAN_PER_CLASS_ACCURACY = "mean_per_class_accuracy"
    PER_CLASS_ACCURACY = "per_class_accuracy"
    IMAGENET_REAL_ACCURACY = "imagenet_real_accuracy"

    @property
    def accuracy_averaging(self):
        return getattr(AccuracyAveraging, self.name, None)

    def __str__(self):
        return self.value


class AccuracyAveraging(Enum):
    MEAN_ACCURACY = "micro"
    MEAN_PER_CLASS_ACCURACY = "macro"
    PER_CLASS_ACCURACY = "none"

    def __str__(self):
        return self.value


def build_metric(metric_type: MetricType, *, num_classes: int, ks: Optional[tuple] = None):
    if metric_type.accuracy_averaging is not None:
        return build_topk_accuracy_metric(
            average_type=metric_type.accuracy_averaging,
            num_classes=num_classes,
            ks=(1, 5) if ks is None else ks,
        )
    elif metric_type == MetricType.IMAGENET_REAL_ACCURACY:
        return build_topk_imagenet_real_accuracy_metric(
            num_classes=num_classes,
            ks=(1, 5) if ks is None else ks,
        )

    raise ValueError(f"Unknown metric type {metric_type}")


def build_topk_accuracy_metric(average_type: AccuracyAveraging, num_classes: int, ks: tuple = (1, 5)):
    metrics: Dict[str, Metric] = {
        f"top-{k}": MulticlassAccuracy(top_k=k, num_classes=int(num_classes), average=average_type.value) for k in ks
    }
    return MetricCollection(metrics)


def build_topk_imagenet_real_accuracy_metric(num_classes: int, ks: tuple = (1, 5)):
    metrics: Dict[str, Metric] = {f"top-{k}": ImageNetReaLAccuracy(top_k=k, num_classes=int(num_classes)) for k in ks}
    return MetricCollection(metrics)


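# Hedged usage sketch (not part of the original diff): building and updating the
# top-k metric collection above on toy logits. All values are made up.
def _demo_topk_metric():
    metric = build_topk_accuracy_metric(AccuracyAveraging.MEAN_ACCURACY, num_classes=5)
    preds = torch.randn(8, 5).softmax(dim=-1)
    target = torch.randint(0, 5, (8,))
    metric.update(preds, target)
    return metric.compute()  # {"top-1": ..., "top-5": ...}

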
class ImageNetReaLAccuracy(Metric):
    is_differentiable: bool = False
    higher_is_better: Optional[bool] = None
    full_state_update: bool = False

    def __init__(
        self,
        num_classes: int,
        top_k: int = 1,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        self.num_classes = num_classes
        self.top_k = top_k
        self.add_state("tp", [], dist_reduce_fx="cat")

    def update(self, preds: Tensor, target: Tensor) -> None:  # type: ignore
        # preds [B, D]
        # target [B, A]
        # preds_oh [B, D] with 0 and 1
        # select top K highest probabilities, use one hot representation
        preds_oh = select_topk(preds, self.top_k)
        # target_oh [B, D + 1] with 0 and 1
        target_oh = torch.zeros((preds_oh.shape[0], preds_oh.shape[1] + 1), device=target.device, dtype=torch.int32)
        target = target.long()
        # for undefined targets (-1) use a fake value `num_classes`
        target[target == -1] = self.num_classes
        # fill targets, use one hot representation
        target_oh.scatter_(1, target, 1)
        # target_oh [B, D] (remove the fake target at index `num_classes`)
        target_oh = target_oh[:, :-1]
        # tp [B] with 0 and 1
        tp = (preds_oh * target_oh == 1).sum(dim=1)
        # at least one match between prediction and target
        tp.clip_(max=1)
        # ignore instances where no targets are defined
        mask = target_oh.sum(dim=1) > 0
        tp = tp[mask]
        self.tp.append(tp)  # type: ignore

    def compute(self) -> Tensor:
        tp = dim_zero_cat(self.tp)  # type: ignore
        return tp.float().mean()
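

# Hedged illustration (not part of the original diff): ImageNet-ReaL targets carry
# several valid labels per image (padded with -1), and a prediction counts as
# correct if any of its top-k classes matches any valid label. Toy values below.
def _demo_real_accuracy():
    metric = ImageNetReaLAccuracy(num_classes=4, top_k=1)
    preds = torch.tensor([[0.9, 0.0, 0.1, 0.0]])  # top-1 prediction: class 0
    target = torch.tensor([[0, 2, -1]])  # classes 0 and 2 are both acceptable
    metric.update(preds, target)
    return metric.compute()  # tensor(1.0)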
@@ -0,0 +1,4 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
@@ -0,0 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from .optimizer import DistOptimizerHook
@@ -0,0 +1,40 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the Apache License, Version 2.0
|
||||||
|
# found in the LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
try:
|
||||||
|
import apex
|
||||||
|
except ImportError:
|
||||||
|
print("apex is not installed")
|
||||||
|
|
||||||
|
from mmcv.runner import OptimizerHook, HOOKS
|
||||||
|
|
||||||
|
|
||||||
|
@HOOKS.register_module()
|
||||||
|
class DistOptimizerHook(OptimizerHook):
|
||||||
|
"""Optimizer hook for distributed training."""
|
||||||
|
|
||||||
|
def __init__(self, update_interval=1, grad_clip=None, coalesce=True, bucket_size_mb=-1, use_fp16=False):
|
||||||
|
self.grad_clip = grad_clip
|
||||||
|
self.coalesce = coalesce
|
||||||
|
self.bucket_size_mb = bucket_size_mb
|
||||||
|
self.update_interval = update_interval
|
||||||
|
self.use_fp16 = use_fp16
|
||||||
|
|
||||||
|
def before_run(self, runner):
|
||||||
|
runner.optimizer.zero_grad()
|
||||||
|
|
||||||
|
def after_train_iter(self, runner):
|
||||||
|
runner.outputs["loss"] /= self.update_interval
|
||||||
|
if self.use_fp16:
|
||||||
|
# runner.outputs['loss'].backward()
|
||||||
|
with apex.amp.scale_loss(runner.outputs["loss"], runner.optimizer) as scaled_loss:
|
||||||
|
scaled_loss.backward()
|
||||||
|
else:
|
||||||
|
runner.outputs["loss"].backward()
|
||||||
|
if self.every_n_iters(runner, self.update_interval):
|
||||||
|
if self.grad_clip is not None:
|
||||||
|
self.clip_grads(runner.model.parameters())
|
||||||
|
runner.optimizer.step()
|
||||||
|
runner.optimizer.zero_grad()
|
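The hook implements gradient accumulation: losses are scaled by update_interval and the optimizer steps only every update_interval iterations. A self-contained sketch of the same pattern outside mmcv (the toy model and batches are stand-ins):

import torch
import torch.nn as nn

model = nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loader = [torch.randn(2, 8) for _ in range(8)]  # hypothetical batches

update_interval = 4
optimizer.zero_grad()
for i, batch in enumerate(loader):
    # Scale the loss so the accumulated gradients match one large-batch step.
    loss = model(batch).pow(2).mean() / update_interval
    loss.backward()
    if (i + 1) % update_interval == 0:
        optimizer.step()
        optimizer.zero_grad()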
@@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from .backbones import *  # noqa: F403
from .decode_heads import *  # noqa: F403
@@ -0,0 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from .vision_transformer import DinoVisionTransformer
@@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from mmcv.runner import BaseModule
from mmseg.models.builder import BACKBONES


@BACKBONES.register_module()
class DinoVisionTransformer(BaseModule):
    """Vision Transformer."""

    def __init__(
        self,
        *args,
        **kwargs,
    ):
        super().__init__()
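The class body is an empty stub; its purpose is to register the name, so that a config can reference the backbone by type. A hedged sketch (the dict key is illustrative, not taken from this commit):

backbone = dict(type="DinoVisionTransformer")  # resolved via the BACKBONES registry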
@@ -0,0 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from .linear_head import BNHead
@@ -0,0 +1,90 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn

from mmseg.models.builder import HEADS
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
from mmseg.ops import resize


@HEADS.register_module()
class BNHead(BaseDecodeHead):
    """Just a batchnorm."""

    def __init__(self, resize_factors=None, **kwargs):
        super().__init__(**kwargs)
        assert self.in_channels == self.channels
        self.bn = nn.SyncBatchNorm(self.in_channels)
        self.resize_factors = resize_factors

    def _forward_feature(self, inputs):
        """Forward function for feature maps before classifying each pixel with
        ``self.cls_seg`` fc.

        Args:
            inputs (list[Tensor]): List of multi-level img features.

        Returns:
            feats (Tensor): A tensor of shape (batch_size, self.channels,
                H, W) which is feature map for last layer of decoder head.
        """
        # print("inputs", [i.shape for i in inputs])
        x = self._transform_inputs(inputs)
        # print("x", x.shape)
        feats = self.bn(x)
        # print("feats", feats.shape)
        return feats

    def _transform_inputs(self, inputs):
        """Transform inputs for decoder.

        Args:
            inputs (list[Tensor]): List of multi-level img features.

        Returns:
            Tensor: The transformed inputs
        """
        if self.input_transform == "resize_concat":
            # accept lists (for cls token)
            input_list = []
            for x in inputs:
                if isinstance(x, list):
                    input_list.extend(x)
                else:
                    input_list.append(x)
            inputs = input_list
            # an image descriptor can be a local descriptor with resolution 1x1
            for i, x in enumerate(inputs):
                if len(x.shape) == 2:
                    inputs[i] = x[:, :, None, None]
            # select indices
            inputs = [inputs[i] for i in self.in_index]
            # Resizing shenanigans
            # print("before", *(x.shape for x in inputs))
            if self.resize_factors is not None:
                assert len(self.resize_factors) == len(inputs), (len(self.resize_factors), len(inputs))
                inputs = [
                    resize(input=x, scale_factor=f, mode="bilinear" if f >= 1 else "area")
                    for x, f in zip(inputs, self.resize_factors)
                ]
                # print("after", *(x.shape for x in inputs))
            upsampled_inputs = [
                resize(input=x, size=inputs[0].shape[2:], mode="bilinear", align_corners=self.align_corners)
                for x in inputs
            ]
            inputs = torch.cat(upsampled_inputs, dim=1)
        elif self.input_transform == "multiple_select":
            inputs = [inputs[i] for i in self.in_index]
        else:
            inputs = inputs[self.in_index]

        return inputs

    def forward(self, inputs):
        """Forward function."""
        output = self._forward_feature(inputs)
        output = self.cls_seg(output)
        return output
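A hedged config sketch for wiring BNHead into an mmseg model; every field value below is an illustrative assumption, not a value taken from this commit:

decode_head = dict(
    type="BNHead",
    in_channels=[1024],          # one entry per selected feature level
    in_index=[0],
    input_transform="resize_concat",
    channels=1024,               # must equal the summed in_channels (asserted in __init__)
    dropout_ratio=0,
    num_classes=150,
    align_corners=False,
)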
@@ -0,0 +1,4 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
@@ -0,0 +1,362 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

ADE20K_COLORMAP = [
    (0, 0, 0),
    (120, 120, 120),
    (180, 120, 120),
    (6, 230, 230),
    (80, 50, 50),
    (4, 200, 3),
    (120, 120, 80),
    (140, 140, 140),
    (204, 5, 255),
    (230, 230, 230),
    (4, 250, 7),
    (224, 5, 255),
    (235, 255, 7),
    (150, 5, 61),
    (120, 120, 70),
    (8, 255, 51),
    (255, 6, 82),
    (143, 255, 140),
    (204, 255, 4),
    (255, 51, 7),
    (204, 70, 3),
    (0, 102, 200),
    (61, 230, 250),
    (255, 6, 51),
    (11, 102, 255),
    (255, 7, 71),
    (255, 9, 224),
    (9, 7, 230),
    (220, 220, 220),
    (255, 9, 92),
    (112, 9, 255),
    (8, 255, 214),
    (7, 255, 224),
    (255, 184, 6),
    (10, 255, 71),
    (255, 41, 10),
    (7, 255, 255),
    (224, 255, 8),
    (102, 8, 255),
    (255, 61, 6),
    (255, 194, 7),
    (255, 122, 8),
    (0, 255, 20),
    (255, 8, 41),
    (255, 5, 153),
    (6, 51, 255),
    (235, 12, 255),
    (160, 150, 20),
    (0, 163, 255),
    (140, 140, 140),
    (250, 10, 15),
    (20, 255, 0),
    (31, 255, 0),
    (255, 31, 0),
    (255, 224, 0),
    (153, 255, 0),
    (0, 0, 255),
    (255, 71, 0),
    (0, 235, 255),
    (0, 173, 255),
    (31, 0, 255),
    (11, 200, 200),
    (255, 82, 0),
    (0, 255, 245),
    (0, 61, 255),
    (0, 255, 112),
    (0, 255, 133),
    (255, 0, 0),
    (255, 163, 0),
    (255, 102, 0),
    (194, 255, 0),
    (0, 143, 255),
    (51, 255, 0),
    (0, 82, 255),
    (0, 255, 41),
    (0, 255, 173),
    (10, 0, 255),
    (173, 255, 0),
    (0, 255, 153),
    (255, 92, 0),
    (255, 0, 255),
    (255, 0, 245),
    (255, 0, 102),
    (255, 173, 0),
    (255, 0, 20),
    (255, 184, 184),
    (0, 31, 255),
    (0, 255, 61),
    (0, 71, 255),
    (255, 0, 204),
    (0, 255, 194),
    (0, 255, 82),
    (0, 10, 255),
    (0, 112, 255),
    (51, 0, 255),
    (0, 194, 255),
    (0, 122, 255),
    (0, 255, 163),
    (255, 153, 0),
    (0, 255, 10),
    (255, 112, 0),
    (143, 255, 0),
    (82, 0, 255),
    (163, 255, 0),
    (255, 235, 0),
    (8, 184, 170),
    (133, 0, 255),
    (0, 255, 92),
    (184, 0, 255),
    (255, 0, 31),
    (0, 184, 255),
    (0, 214, 255),
    (255, 0, 112),
    (92, 255, 0),
    (0, 224, 255),
    (112, 224, 255),
    (70, 184, 160),
    (163, 0, 255),
    (153, 0, 255),
    (71, 255, 0),
    (255, 0, 163),
    (255, 204, 0),
    (255, 0, 143),
    (0, 255, 235),
    (133, 255, 0),
    (255, 0, 235),
    (245, 0, 255),
    (255, 0, 122),
    (255, 245, 0),
    (10, 190, 212),
    (214, 255, 0),
    (0, 204, 255),
    (20, 0, 255),
    (255, 255, 0),
    (0, 153, 255),
    (0, 41, 255),
    (0, 255, 204),
    (41, 0, 255),
    (41, 255, 0),
    (173, 0, 255),
    (0, 245, 255),
    (71, 0, 255),
    (122, 0, 255),
    (0, 255, 184),
    (0, 92, 255),
    (184, 255, 0),
    (0, 133, 255),
    (255, 214, 0),
    (25, 194, 194),
    (102, 255, 0),
    (92, 0, 255),
]


ADE20K_CLASS_NAMES = [
    "",
    "wall",
    "building;edifice",
    "sky",
    "floor;flooring",
    "tree",
    "ceiling",
    "road;route",
    "bed",
    "windowpane;window",
    "grass",
    "cabinet",
    "sidewalk;pavement",
    "person;individual;someone;somebody;mortal;soul",
    "earth;ground",
    "door;double;door",
    "table",
    "mountain;mount",
    "plant;flora;plant;life",
    "curtain;drape;drapery;mantle;pall",
    "chair",
    "car;auto;automobile;machine;motorcar",
    "water",
    "painting;picture",
    "sofa;couch;lounge",
    "shelf",
    "house",
    "sea",
    "mirror",
    "rug;carpet;carpeting",
    "field",
    "armchair",
    "seat",
    "fence;fencing",
    "desk",
    "rock;stone",
    "wardrobe;closet;press",
    "lamp",
    "bathtub;bathing;tub;bath;tub",
    "railing;rail",
    "cushion",
    "base;pedestal;stand",
    "box",
    "column;pillar",
    "signboard;sign",
    "chest;of;drawers;chest;bureau;dresser",
    "counter",
    "sand",
    "sink",
    "skyscraper",
    "fireplace;hearth;open;fireplace",
    "refrigerator;icebox",
    "grandstand;covered;stand",
    "path",
    "stairs;steps",
    "runway",
    "case;display;case;showcase;vitrine",
    "pool;table;billiard;table;snooker;table",
    "pillow",
    "screen;door;screen",
    "stairway;staircase",
    "river",
    "bridge;span",
    "bookcase",
    "blind;screen",
    "coffee;table;cocktail;table",
    "toilet;can;commode;crapper;pot;potty;stool;throne",
    "flower",
    "book",
    "hill",
    "bench",
    "countertop",
    "stove;kitchen;stove;range;kitchen;range;cooking;stove",
    "palm;palm;tree",
    "kitchen;island",
    "computer;computing;machine;computing;device;data;processor;electronic;computer;information;processing;system",
    "swivel;chair",
    "boat",
    "bar",
    "arcade;machine",
    "hovel;hut;hutch;shack;shanty",
    "bus;autobus;coach;charabanc;double-decker;jitney;motorbus;motorcoach;omnibus;passenger;vehicle",
    "towel",
    "light;light;source",
    "truck;motortruck",
    "tower",
    "chandelier;pendant;pendent",
    "awning;sunshade;sunblind",
    "streetlight;street;lamp",
    "booth;cubicle;stall;kiosk",
    "television;television;receiver;television;set;tv;tv;set;idiot;box;boob;tube;telly;goggle;box",
    "airplane;aeroplane;plane",
    "dirt;track",
    "apparel;wearing;apparel;dress;clothes",
    "pole",
    "land;ground;soil",
    "bannister;banister;balustrade;balusters;handrail",
    "escalator;moving;staircase;moving;stairway",
    "ottoman;pouf;pouffe;puff;hassock",
    "bottle",
    "buffet;counter;sideboard",
    "poster;posting;placard;notice;bill;card",
    "stage",
    "van",
    "ship",
    "fountain",
    "conveyer;belt;conveyor;belt;conveyer;conveyor;transporter",
    "canopy",
    "washer;automatic;washer;washing;machine",
    "plaything;toy",
    "swimming;pool;swimming;bath;natatorium",
    "stool",
    "barrel;cask",
    "basket;handbasket",
    "waterfall;falls",
    "tent;collapsible;shelter",
    "bag",
    "minibike;motorbike",
    "cradle",
    "oven",
    "ball",
    "food;solid;food",
    "step;stair",
    "tank;storage;tank",
    "trade;name;brand;name;brand;marque",
    "microwave;microwave;oven",
    "pot;flowerpot",
    "animal;animate;being;beast;brute;creature;fauna",
    "bicycle;bike;wheel;cycle",
    "lake",
    "dishwasher;dish;washer;dishwashing;machine",
    "screen;silver;screen;projection;screen",
    "blanket;cover",
    "sculpture",
    "hood;exhaust;hood",
    "sconce",
    "vase",
    "traffic;light;traffic;signal;stoplight",
    "tray",
    "ashcan;trash;can;garbage;can;wastebin;ash;bin;ash-bin;ashbin;dustbin;trash;barrel;trash;bin",
    "fan",
    "pier;wharf;wharfage;dock",
    "crt;screen",
    "plate",
    "monitor;monitoring;device",
    "bulletin;board;notice;board",
    "shower",
    "radiator",
    "glass;drinking;glass",
    "clock",
    "flag",
]


VOC2012_COLORMAP = [
    (0, 0, 0),
    (128, 0, 0),
    (0, 128, 0),
    (128, 128, 0),
    (0, 0, 128),
    (128, 0, 128),
    (0, 128, 128),
    (128, 128, 128),
    (64, 0, 0),
    (192, 0, 0),
    (64, 128, 0),
    (192, 128, 0),
    (64, 0, 128),
    (192, 0, 128),
    (64, 128, 128),
    (192, 128, 128),
    (0, 64, 0),
    (128, 64, 0),
    (0, 192, 0),
    (128, 192, 0),
    (0, 64, 128),
]


VOC2012_CLASS_NAMES = [
    "",
    "aeroplane",
    "bicycle",
    "bird",
    "boat",
    "bottle",
    "bus",
    "car",
    "cat",
    "chair",
    "cow",
    "diningtable",
    "dog",
    "horse",
    "motorbike",
    "person",
    "pottedplant",
    "sheep",
    "sofa",
    "train",
    "tvmonitor",
]
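A small hedged sketch of how such a palette is typically used: paint a map of class indices into an RGB image (the random prediction is a stand-in):

import numpy as np

palette = np.array(ADE20K_COLORMAP, dtype=np.uint8)    # [151, 3]; index 0 = background
seg = np.random.randint(0, len(palette), size=(4, 4))  # hypothetical predicted indices
color = palette[seg]                                   # [4, 4, 3] RGB visualization
print(ADE20K_CLASS_NAMES[int(seg[0, 0])], color[0, 0])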
@@ -0,0 +1,8 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from .core import *  # noqa: F403
from .models import *  # noqa: F403
from .ops import *  # noqa: F403
@@ -0,0 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from mmseg.core.evaluation import *  # noqa: F403
from mmseg.core.seg import *  # noqa: F403

from .anchor import *  # noqa: F403
from .box import *  # noqa: F403
from .utils import *  # noqa: F403
@@ -0,0 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from .point_generator import MlvlPointGenerator  # noqa: F403
@@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import warnings

from mmcv.utils import Registry, build_from_cfg

PRIOR_GENERATORS = Registry("Generator for anchors and points")

ANCHOR_GENERATORS = PRIOR_GENERATORS


def build_prior_generator(cfg, default_args=None):
    return build_from_cfg(cfg, PRIOR_GENERATORS, default_args)


def build_anchor_generator(cfg, default_args=None):
    warnings.warn("``build_anchor_generator`` will be deprecated soon, please use ``build_prior_generator``")
    return build_prior_generator(cfg, default_args=default_args)
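A hedged sketch of the registry round-trip, assuming the point generator defined below has been imported (which registers it under PRIOR_GENERATORS):

cfg = dict(type="MlvlPointGenerator", strides=[8, 16, 32], offset=0.5)
gen = build_prior_generator(cfg)  # same mechanism a detector config would use
print(gen.num_levels)             # 3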
@@ -0,0 +1,205 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import numpy as np
import torch
from torch.nn.modules.utils import _pair

from .builder import PRIOR_GENERATORS


@PRIOR_GENERATORS.register_module()
class MlvlPointGenerator:
    """Standard points generator for multi-level (Mlvl) feature maps in 2D
    points-based detectors.

    Args:
        strides (list[int] | list[tuple[int, int]]): Strides of anchors
            in multiple feature levels in order (w, h).
        offset (float): The offset of points, the value is normalized with
            corresponding stride. Defaults to 0.5.
    """

    def __init__(self, strides, offset=0.5):
        self.strides = [_pair(stride) for stride in strides]
        self.offset = offset

    @property
    def num_levels(self):
        """int: number of feature levels that the generator will be applied"""
        return len(self.strides)

    @property
    def num_base_priors(self):
        """list[int]: The number of priors (points) at a point
        on the feature grid"""
        return [1 for _ in range(len(self.strides))]

    def _meshgrid(self, x, y, row_major=True):
        yy, xx = torch.meshgrid(y, x)
        if row_major:
            # warning .flatten() would cause error in ONNX exporting
            # have to use reshape here
            return xx.reshape(-1), yy.reshape(-1)
        else:
            return yy.reshape(-1), xx.reshape(-1)

    def grid_priors(self, featmap_sizes, dtype=torch.float32, device="cuda", with_stride=False):
        """Generate grid points of multiple feature levels.

        Args:
            featmap_sizes (list[tuple]): List of feature map sizes in
                multiple feature levels, each size arranged as (h, w).
            dtype (:obj:`dtype`): Dtype of priors. Default: torch.float32.
            device (str): The device where the anchors will be put on.
            with_stride (bool): Whether to concatenate the stride to
                the last dimension of points.

        Return:
            list[torch.Tensor]: Points of multiple feature levels.
            The sizes of each tensor should be (N, 2) when with stride is
            ``False``, where N = width * height, width and height
            are the sizes of the corresponding feature level,
            and the last dimension 2 represent (coord_x, coord_y),
            otherwise the shape should be (N, 4),
            and the last dimension 4 represent
            (coord_x, coord_y, stride_w, stride_h).
        """
        assert self.num_levels == len(featmap_sizes)
        multi_level_priors = []
        for i in range(self.num_levels):
            priors = self.single_level_grid_priors(
                featmap_sizes[i], level_idx=i, dtype=dtype, device=device, with_stride=with_stride
            )
            multi_level_priors.append(priors)
        return multi_level_priors

    def single_level_grid_priors(self, featmap_size, level_idx, dtype=torch.float32, device="cuda", with_stride=False):
        """Generate grid Points of a single level.

        Note:
            This function is usually called by method ``self.grid_priors``.

        Args:
            featmap_size (tuple[int]): Size of the feature maps, arranged as
                (h, w).
            level_idx (int): The index of corresponding feature map level.
            dtype (:obj:`dtype`): Dtype of priors. Default: torch.float32.
            device (str, optional): The device the tensor will be put on.
                Defaults to 'cuda'.
            with_stride (bool): Concatenate the stride to the last dimension
                of points.

        Return:
            Tensor: Points of single feature levels.
            The shape of tensor should be (N, 2) when with stride is
            ``False``, where N = width * height, width and height
            are the sizes of the corresponding feature level,
            and the last dimension 2 represent (coord_x, coord_y),
            otherwise the shape should be (N, 4),
            and the last dimension 4 represent
            (coord_x, coord_y, stride_w, stride_h).
        """
        feat_h, feat_w = featmap_size
        stride_w, stride_h = self.strides[level_idx]
        shift_x = (torch.arange(0, feat_w, device=device) + self.offset) * stride_w
        # keep featmap_size as Tensor instead of int, so that we
        # can convert to ONNX correctly
        shift_x = shift_x.to(dtype)

        shift_y = (torch.arange(0, feat_h, device=device) + self.offset) * stride_h
        # keep featmap_size as Tensor instead of int, so that we
        # can convert to ONNX correctly
        shift_y = shift_y.to(dtype)
        shift_xx, shift_yy = self._meshgrid(shift_x, shift_y)
        if not with_stride:
            shifts = torch.stack([shift_xx, shift_yy], dim=-1)
        else:
            # use `shape[0]` instead of `len(shift_xx)` for ONNX export
            stride_w = shift_xx.new_full((shift_xx.shape[0],), stride_w).to(dtype)
            stride_h = shift_xx.new_full((shift_yy.shape[0],), stride_h).to(dtype)
            shifts = torch.stack([shift_xx, shift_yy, stride_w, stride_h], dim=-1)
        all_points = shifts.to(device)
        return all_points

    def valid_flags(self, featmap_sizes, pad_shape, device="cuda"):
        """Generate valid flags of points of multiple feature levels.

        Args:
            featmap_sizes (list(tuple)): List of feature map sizes in
                multiple feature levels, each size arranged as (h, w).
            pad_shape (tuple(int)): The padded shape of the image,
                arranged as (h, w).
            device (str): The device where the anchors will be put on.

        Return:
            list(torch.Tensor): Valid flags of points of multiple levels.
        """
        assert self.num_levels == len(featmap_sizes)
        multi_level_flags = []
        for i in range(self.num_levels):
            point_stride = self.strides[i]
            feat_h, feat_w = featmap_sizes[i]
            h, w = pad_shape[:2]
            valid_feat_h = min(int(np.ceil(h / point_stride[1])), feat_h)
            valid_feat_w = min(int(np.ceil(w / point_stride[0])), feat_w)
            flags = self.single_level_valid_flags((feat_h, feat_w), (valid_feat_h, valid_feat_w), device=device)
            multi_level_flags.append(flags)
        return multi_level_flags

    def single_level_valid_flags(self, featmap_size, valid_size, device="cuda"):
        """Generate the valid flags of points of a single feature map.

        Args:
            featmap_size (tuple[int]): The size of feature maps, arranged as
                (h, w).
            valid_size (tuple[int]): The valid size of the feature maps.
                The size arranged as (h, w).
            device (str, optional): The device where the flags will be put on.
                Defaults to 'cuda'.

        Returns:
            torch.Tensor: The valid flags of each points in a single level \
                feature map.
        """
        feat_h, feat_w = featmap_size
        valid_h, valid_w = valid_size
        assert valid_h <= feat_h and valid_w <= feat_w
        valid_x = torch.zeros(feat_w, dtype=torch.bool, device=device)
        valid_y = torch.zeros(feat_h, dtype=torch.bool, device=device)
        valid_x[:valid_w] = 1
        valid_y[:valid_h] = 1
        valid_xx, valid_yy = self._meshgrid(valid_x, valid_y)
        valid = valid_xx & valid_yy
        return valid

    def sparse_priors(self, prior_idxs, featmap_size, level_idx, dtype=torch.float32, device="cuda"):
        """Generate sparse points according to the ``prior_idxs``.

        Args:
            prior_idxs (Tensor): The index of corresponding anchors
                in the feature map.
            featmap_size (tuple[int]): feature map size arranged as (w, h).
            level_idx (int): The level index of corresponding feature
                map.
            dtype (obj:`torch.dtype`): Data type of points. Defaults to
                ``torch.float32``.
            device (obj:`torch.device`): The device where the points is
                located.
        Returns:
            Tensor: Anchor with shape (N, 2), N should be equal to
            the length of ``prior_idxs``. And last dimension
            2 represent (coord_x, coord_y).
        """
        height, width = featmap_size
        x = (prior_idxs % width + self.offset) * self.strides[level_idx][0]
        y = ((prior_idxs // width) % height + self.offset) * self.strides[level_idx][1]
        prioris = torch.stack([x, y], 1).to(dtype)
        prioris = prioris.to(device)
        return prioris
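A minimal usage sketch, run on CPU: points land at stride-scaled cell centers because of the default offset of 0.5.

gen = MlvlPointGenerator(strides=[8, 16], offset=0.5)
priors = gen.grid_priors([(4, 4), (2, 2)], device="cpu")
print(priors[0].shape, priors[1].shape)  # torch.Size([16, 2]) torch.Size([4, 2])
print(priors[0][0])                      # tensor([4., 4.]) -- (0.5 * 8, 0.5 * 8)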
@@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from .builder import *  # noqa: F403
from .samplers import MaskPseudoSampler  # noqa: F403
@@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from mmcv.utils import Registry, build_from_cfg

BBOX_SAMPLERS = Registry("bbox_sampler")
BBOX_CODERS = Registry("bbox_coder")


def build_sampler(cfg, **default_args):
    """Builder of box sampler."""
    return build_from_cfg(cfg, BBOX_SAMPLERS, default_args)


def build_bbox_coder(cfg, **default_args):
    """Builder of box coder."""
    return build_from_cfg(cfg, BBOX_CODERS, default_args)
@@ -0,0 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from .mask_pseudo_sampler import MaskPseudoSampler  # noqa: F403
@@ -0,0 +1,92 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from abc import ABCMeta, abstractmethod

import torch

from .sampling_result import SamplingResult


class BaseSampler(metaclass=ABCMeta):
    """Base class of samplers."""

    def __init__(self, num, pos_fraction, neg_pos_ub=-1, add_gt_as_proposals=True, **kwargs):
        self.num = num
        self.pos_fraction = pos_fraction
        self.neg_pos_ub = neg_pos_ub
        self.add_gt_as_proposals = add_gt_as_proposals
        self.pos_sampler = self
        self.neg_sampler = self

    @abstractmethod
    def _sample_pos(self, assign_result, num_expected, **kwargs):
        """Sample positive samples."""
        pass

    @abstractmethod
    def _sample_neg(self, assign_result, num_expected, **kwargs):
        """Sample negative samples."""
        pass

    def sample(self, assign_result, bboxes, gt_bboxes, gt_labels=None, **kwargs):
        """Sample positive and negative bboxes.

        This is a simple implementation of bbox sampling given candidates,
        assigning results and ground truth bboxes.

        Args:
            assign_result (:obj:`AssignResult`): Bbox assigning results.
            bboxes (Tensor): Boxes to be sampled from.
            gt_bboxes (Tensor): Ground truth bboxes.
            gt_labels (Tensor, optional): Class labels of ground truth bboxes.

        Returns:
            :obj:`SamplingResult`: Sampling result.

        Example:
            >>> from mmdet.core.bbox import RandomSampler
            >>> from mmdet.core.bbox import AssignResult
            >>> from mmdet.core.bbox.demodata import ensure_rng, random_boxes
            >>> rng = ensure_rng(None)
            >>> assign_result = AssignResult.random(rng=rng)
            >>> bboxes = random_boxes(assign_result.num_preds, rng=rng)
            >>> gt_bboxes = random_boxes(assign_result.num_gts, rng=rng)
            >>> gt_labels = None
            >>> self = RandomSampler(num=32, pos_fraction=0.5, neg_pos_ub=-1,
            >>>                      add_gt_as_proposals=False)
            >>> self = self.sample(assign_result, bboxes, gt_bboxes, gt_labels)
        """
        if len(bboxes.shape) < 2:
            bboxes = bboxes[None, :]

        bboxes = bboxes[:, :4]

        gt_flags = bboxes.new_zeros((bboxes.shape[0],), dtype=torch.uint8)
        if self.add_gt_as_proposals and len(gt_bboxes) > 0:
            if gt_labels is None:
                raise ValueError("gt_labels must be given when add_gt_as_proposals is True")
            bboxes = torch.cat([gt_bboxes, bboxes], dim=0)
            assign_result.add_gt_(gt_labels)
            gt_ones = bboxes.new_ones(gt_bboxes.shape[0], dtype=torch.uint8)
            gt_flags = torch.cat([gt_ones, gt_flags])

        num_expected_pos = int(self.num * self.pos_fraction)
        pos_inds = self.pos_sampler._sample_pos(assign_result, num_expected_pos, bboxes=bboxes, **kwargs)
        # We found that sampled indices have duplicated items occasionally.
        # (may be a bug of PyTorch)
        pos_inds = pos_inds.unique()
        num_sampled_pos = pos_inds.numel()
        num_expected_neg = self.num - num_sampled_pos
        if self.neg_pos_ub >= 0:
            _pos = max(1, num_sampled_pos)
            neg_upper_bound = int(self.neg_pos_ub * _pos)
            if num_expected_neg > neg_upper_bound:
                num_expected_neg = neg_upper_bound
        neg_inds = self.neg_sampler._sample_neg(assign_result, num_expected_neg, bboxes=bboxes, **kwargs)
        neg_inds = neg_inds.unique()

        sampling_result = SamplingResult(pos_inds, neg_inds, bboxes, gt_bboxes, assign_result, gt_flags)
        return sampling_result
@@ -0,0 +1,45 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
# https://github.com/ZwwWayne/K-Net/blob/main/knet/det/mask_pseudo_sampler.py

import torch

from ..builder import BBOX_SAMPLERS
from .base_sampler import BaseSampler
from .mask_sampling_result import MaskSamplingResult


@BBOX_SAMPLERS.register_module()
class MaskPseudoSampler(BaseSampler):
    """A pseudo sampler that does not actually do any sampling."""

    def __init__(self, **kwargs):
        pass

    def _sample_pos(self, **kwargs):
        """Sample positive samples."""
        raise NotImplementedError

    def _sample_neg(self, **kwargs):
        """Sample negative samples."""
        raise NotImplementedError

    def sample(self, assign_result, masks, gt_masks, **kwargs):
        """Directly returns the positive and negative indices of samples.

        Args:
            assign_result (:obj:`AssignResult`): Assigned results
            masks (torch.Tensor): Predicted masks
            gt_masks (torch.Tensor): Ground truth masks
        Returns:
            :obj:`SamplingResult`: sampler results
        """
        pos_inds = torch.nonzero(assign_result.gt_inds > 0, as_tuple=False).squeeze(-1).unique()
        neg_inds = torch.nonzero(assign_result.gt_inds == 0, as_tuple=False).squeeze(-1).unique()
        gt_flags = masks.new_zeros(masks.shape[0], dtype=torch.uint8)
        sampling_result = MaskSamplingResult(pos_inds, neg_inds, masks, gt_masks, assign_result, gt_flags)
        return sampling_result
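A hedged sketch of the pseudo-sampling behavior: every prediction is kept, and the indices are only partitioned by the assignment (gt_inds > 0 means positive, == 0 means negative). The FakeAssign class is a hypothetical stand-in for mmdet's AssignResult:

import torch

class FakeAssign:  # hypothetical stand-in for AssignResult
    gt_inds = torch.tensor([0, 2, 1, 0])  # 1-based gt index per prediction, 0 = background
    labels = None

masks = torch.zeros(4, 8, 8)
gt_masks = torch.zeros(2, 8, 8)
res = MaskPseudoSampler().sample(FakeAssign(), masks, gt_masks)
print(res.pos_inds.tolist(), res.neg_inds.tolist())  # [1, 2] [0, 3]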
@@ -0,0 +1,63 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
# https://github.com/ZwwWayne/K-Net/blob/main/knet/det/mask_pseudo_sampler.py

import torch

from .sampling_result import SamplingResult


class MaskSamplingResult(SamplingResult):
    """Mask sampling result."""

    def __init__(self, pos_inds, neg_inds, masks, gt_masks, assign_result, gt_flags):
        self.pos_inds = pos_inds
        self.neg_inds = neg_inds
        self.pos_masks = masks[pos_inds]
        self.neg_masks = masks[neg_inds]
        self.pos_is_gt = gt_flags[pos_inds]

        self.num_gts = gt_masks.shape[0]
        self.pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1

        if gt_masks.numel() == 0:
            # hack for index error case
            assert self.pos_assigned_gt_inds.numel() == 0
            self.pos_gt_masks = torch.empty_like(gt_masks)
        else:
            self.pos_gt_masks = gt_masks[self.pos_assigned_gt_inds, :]

        if assign_result.labels is not None:
            self.pos_gt_labels = assign_result.labels[pos_inds]
        else:
            self.pos_gt_labels = None

    @property
    def masks(self):
        """torch.Tensor: concatenated positive and negative masks"""
        return torch.cat([self.pos_masks, self.neg_masks])

    def __nice__(self):
        data = self.info.copy()
        data["pos_masks"] = data.pop("pos_masks").shape
        data["neg_masks"] = data.pop("neg_masks").shape
        parts = [f"'{k}': {v!r}" for k, v in sorted(data.items())]
        body = "    " + ",\n    ".join(parts)
        return "{\n" + body + "\n}"

    @property
    def info(self):
        """Returns a dictionary of info about the object."""
        return {
            "pos_inds": self.pos_inds,
            "neg_inds": self.neg_inds,
            "pos_masks": self.pos_masks,
            "neg_masks": self.neg_masks,
            "pos_is_gt": self.pos_is_gt,
            "num_gts": self.num_gts,
            "pos_assigned_gt_inds": self.pos_assigned_gt_inds,
        }
@@ -0,0 +1,152 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import torch


class SamplingResult:
    """Bbox sampling result.

    Example:
        >>> # xdoctest: +IGNORE_WANT
        >>> from mmdet.core.bbox.samplers.sampling_result import *  # NOQA
        >>> self = SamplingResult.random(rng=10)
        >>> print(f'self = {self}')
        self = <SamplingResult({
            'neg_bboxes': torch.Size([12, 4]),
            'neg_inds': tensor([ 0,  1,  2,  4,  5,  6,  7,  8,  9, 10, 11, 12]),
            'num_gts': 4,
            'pos_assigned_gt_inds': tensor([], dtype=torch.int64),
            'pos_bboxes': torch.Size([0, 4]),
            'pos_inds': tensor([], dtype=torch.int64),
            'pos_is_gt': tensor([], dtype=torch.uint8)
        })>
    """

    def __init__(self, pos_inds, neg_inds, bboxes, gt_bboxes, assign_result, gt_flags):
        self.pos_inds = pos_inds
        self.neg_inds = neg_inds
        self.pos_bboxes = bboxes[pos_inds]
        self.neg_bboxes = bboxes[neg_inds]
        self.pos_is_gt = gt_flags[pos_inds]

        self.num_gts = gt_bboxes.shape[0]
        self.pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1

        if gt_bboxes.numel() == 0:
            # hack for index error case
            assert self.pos_assigned_gt_inds.numel() == 0
            self.pos_gt_bboxes = torch.empty_like(gt_bboxes).view(-1, 4)
        else:
            if len(gt_bboxes.shape) < 2:
                gt_bboxes = gt_bboxes.view(-1, 4)

            self.pos_gt_bboxes = gt_bboxes[self.pos_assigned_gt_inds.long(), :]

        if assign_result.labels is not None:
            self.pos_gt_labels = assign_result.labels[pos_inds]
        else:
            self.pos_gt_labels = None

    @property
    def bboxes(self):
        """torch.Tensor: concatenated positive and negative boxes"""
        return torch.cat([self.pos_bboxes, self.neg_bboxes])

    def to(self, device):
        """Change the device of the data inplace.

        Example:
            >>> self = SamplingResult.random()
            >>> print(f'self = {self.to(None)}')
            >>> # xdoctest: +REQUIRES(--gpu)
            >>> print(f'self = {self.to(0)}')
        """
        _dict = self.__dict__
        for key, value in _dict.items():
            if isinstance(value, torch.Tensor):
                _dict[key] = value.to(device)
        return self

    def __nice__(self):
        data = self.info.copy()
        data["pos_bboxes"] = data.pop("pos_bboxes").shape
        data["neg_bboxes"] = data.pop("neg_bboxes").shape
        parts = [f"'{k}': {v!r}" for k, v in sorted(data.items())]
        body = "    " + ",\n    ".join(parts)
        return "{\n" + body + "\n}"

    @property
    def info(self):
        """Returns a dictionary of info about the object."""
        return {
            "pos_inds": self.pos_inds,
            "neg_inds": self.neg_inds,
            "pos_bboxes": self.pos_bboxes,
            "neg_bboxes": self.neg_bboxes,
            "pos_is_gt": self.pos_is_gt,
            "num_gts": self.num_gts,
            "pos_assigned_gt_inds": self.pos_assigned_gt_inds,
        }

    @classmethod
    def random(cls, rng=None, **kwargs):
        """
        Args:
            rng (None | int | numpy.random.RandomState): seed or state.
            kwargs (keyword arguments):
                - num_preds: number of predicted boxes
                - num_gts: number of true boxes
                - p_ignore (float): probability of a predicted box assigned to \
                    an ignored truth.
                - p_assigned (float): probability of a predicted box not being \
                    assigned.
                - p_use_label (float | bool): with labels or not.

        Returns:
            :obj:`SamplingResult`: Randomly generated sampling result.

        Example:
            >>> from mmdet.core.bbox.samplers.sampling_result import *  # NOQA
            >>> self = SamplingResult.random()
            >>> print(self.__dict__)
        """
        from mmdet.core.bbox import demodata
        from mmdet.core.bbox.assigners.assign_result import AssignResult
        from mmdet.core.bbox.samplers.random_sampler import RandomSampler

        rng = demodata.ensure_rng(rng)

        # make probabilistic?
        num = 32
        pos_fraction = 0.5
        neg_pos_ub = -1

        assign_result = AssignResult.random(rng=rng, **kwargs)

        # Note we could just compute an assignment
        bboxes = demodata.random_boxes(assign_result.num_preds, rng=rng)
        gt_bboxes = demodata.random_boxes(assign_result.num_gts, rng=rng)

        if rng.rand() > 0.2:
            # sometimes algorithms squeeze their data, be robust to that
            gt_bboxes = gt_bboxes.squeeze()
            bboxes = bboxes.squeeze()

        if assign_result.labels is None:
            gt_labels = None
        else:
            gt_labels = None  # both branches currently yield None, mirroring upstream mmdet

        if gt_labels is None:
            add_gt_as_proposals = False
        else:
            add_gt_as_proposals = True  # make probabilistic?

        sampler = RandomSampler(
            num, pos_fraction, neg_pos_ub=neg_pos_ub, add_gt_as_proposals=add_gt_as_proposals, rng=rng
        )
        self = sampler.sample(assign_result, bboxes, gt_bboxes, gt_labels)
        return self
@@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from .dist_utils import reduce_mean
from .misc import add_prefix, multi_apply
@@ -0,0 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import torch.distributed as dist


def reduce_mean(tensor):
    """Obtain the mean of tensor on different GPUs."""
    if not (dist.is_available() and dist.is_initialized()):
        return tensor
    tensor = tensor.clone()
    dist.all_reduce(tensor.div_(dist.get_world_size()), op=dist.ReduceOp.SUM)
    return tensor
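A quick sketch of the two code paths: without an initialized process group the call is a no-op; under distributed training the (cloned) tensor is averaged across all ranks.

import torch
print(reduce_mean(torch.tensor(4.0)))  # tensor(4.) in a single-process run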
@@ -0,0 +1,47 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from functools import partial


def multi_apply(func, *args, **kwargs):
    """Apply function to a list of arguments.

    Note:
        This function applies ``func`` to multiple inputs and maps
        the multiple outputs of ``func`` into different lists.
        Each list contains the same type of outputs corresponding
        to different inputs.

    Args:
        func (Function): A function that will be applied to a list of
            arguments

    Returns:
        tuple(list): A tuple containing multiple lists, each list contains \
            a kind of returned results by the function
    """
    pfunc = partial(func, **kwargs) if kwargs else func
    map_results = map(pfunc, *args)
    return tuple(map(list, zip(*map_results)))


def add_prefix(inputs, prefix):
    """Add prefix for dict.

    Args:
        inputs (dict): The input dict with str keys.
        prefix (str): The prefix to add.

    Returns:
        dict: The dict with keys updated with ``prefix``.
    """
    outputs = dict()
    for name, value in inputs.items():
        outputs[f"{prefix}.{name}"] = value

    return outputs
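A minimal usage sketch for both helpers: multi_apply calls the function once per input and regroups the tuple outputs into per-field lists.

def square_and_cube(x):
    return x * x, x * x * x

squares, cubes = multi_apply(square_and_cube, [1, 2, 3])
print(squares, cubes)                     # [1, 4, 9] [1, 8, 27]
print(add_prefix({"loss": 0.5}, "aux"))   # {'aux.loss': 0.5}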
@@ -0,0 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from .backbones import *  # noqa: F403
from .builder import MASK_ASSIGNERS, MATCH_COST, TRANSFORMER, build_assigner, build_match_cost
from .decode_heads import *  # noqa: F403
from .losses import *  # noqa: F403
from .plugins import *  # noqa: F403
from .segmentors import *  # noqa: F403
@@ -0,0 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from .vit_adapter import ViTAdapter
@@ -0,0 +1,442 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from functools import partial

import torch
import torch.nn as nn
import torch.utils.checkpoint as cp

from ...ops.modules import MSDeformAttn
from .drop_path import DropPath


def get_reference_points(spatial_shapes, device):
    reference_points_list = []
    for lvl, (H_, W_) in enumerate(spatial_shapes):
        ref_y, ref_x = torch.meshgrid(
            torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
            torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device),
        )
        ref_y = ref_y.reshape(-1)[None] / H_
        ref_x = ref_x.reshape(-1)[None] / W_
        ref = torch.stack((ref_x, ref_y), -1)
        reference_points_list.append(ref)
    reference_points = torch.cat(reference_points_list, 1)
    reference_points = reference_points[:, :, None]
    return reference_points


def deform_inputs(x, patch_size):
    bs, c, h, w = x.shape
    spatial_shapes = torch.as_tensor(
        [(h // 8, w // 8), (h // 16, w // 16), (h // 32, w // 32)], dtype=torch.long, device=x.device
    )
    level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
    reference_points = get_reference_points([(h // patch_size, w // patch_size)], x.device)
    deform_inputs1 = [reference_points, spatial_shapes, level_start_index]

    spatial_shapes = torch.as_tensor([(h // patch_size, w // patch_size)], dtype=torch.long, device=x.device)
    level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
    reference_points = get_reference_points([(h // 8, w // 8), (h // 16, w // 16), (h // 32, w // 32)], x.device)
    deform_inputs2 = [reference_points, spatial_shapes, level_start_index]

    return deform_inputs1, deform_inputs2
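A small check of the reference-point convention used above: for a single 2x2 level, the points are the pixel-center coordinates normalized to [0, 1].

pts = get_reference_points([(2, 2)], device="cpu")
print(pts.shape)     # torch.Size([1, 4, 1, 2])
print(pts[0, :, 0])  # tensor([[0.2500, 0.2500], [0.7500, 0.2500], [0.2500, 0.7500], [0.7500, 0.7500]])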

class ConvFFN(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.dwconv = DWConv(hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x, H, W):
        x = self.fc1(x)
        x = self.dwconv(x, H, W)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class DWConv(nn.Module):
    def __init__(self, dim=768):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)

    def forward(self, x, H, W):
        B, N, C = x.shape
        n = N // 21
        x1 = x[:, 0 : 16 * n, :].transpose(1, 2).view(B, C, H * 2, W * 2).contiguous()
        x2 = x[:, 16 * n : 20 * n, :].transpose(1, 2).view(B, C, H, W).contiguous()
        x3 = x[:, 20 * n :, :].transpose(1, 2).view(B, C, H // 2, W // 2).contiguous()
        x1 = self.dwconv(x1).flatten(2).transpose(1, 2)
        x2 = self.dwconv(x2).flatten(2).transpose(1, 2)
        x3 = self.dwconv(x3).flatten(2).transpose(1, 2)
        x = torch.cat([x1, x2, x3], dim=1)
        return x
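The 21n split works because the three pyramid scales (2H x 2W, H x W, H/2 x W/2) hold 16n, 4n, and n tokens respectively, where n is the coarsest-scale token count. A quick shape check with H = W = 4:

import torch

dw = DWConv(dim=8)
H = W = 4
n = (H // 2) * (W // 2)         # 4 tokens at the coarsest scale
x = torch.randn(1, 21 * n, 8)   # 64 + 16 + 4 = 84 tokens across the three scales
print(dw(x, H, W).shape)        # torch.Size([1, 84, 8])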
|
||||||
|
|
||||||
|
class Extractor(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim,
|
||||||
|
num_heads=6,
|
||||||
|
n_points=4,
|
||||||
|
n_levels=1,
|
||||||
|
deform_ratio=1.0,
|
||||||
|
with_cffn=True,
|
||||||
|
cffn_ratio=0.25,
|
||||||
|
drop=0.0,
|
||||||
|
drop_path=0.0,
|
||||||
|
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
||||||
|
with_cp=False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.query_norm = norm_layer(dim)
|
||||||
|
self.feat_norm = norm_layer(dim)
|
||||||
|
self.attn = MSDeformAttn(
|
||||||
|
d_model=dim, n_levels=n_levels, n_heads=num_heads, n_points=n_points, ratio=deform_ratio
|
||||||
|
)
|
||||||
|
self.with_cffn = with_cffn
|
||||||
|
self.with_cp = with_cp
|
||||||
|
if with_cffn:
|
||||||
|
self.ffn = ConvFFN(in_features=dim, hidden_features=int(dim * cffn_ratio), drop=drop)
|
||||||
|
self.ffn_norm = norm_layer(dim)
|
||||||
|
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
||||||
|
|
||||||
|
def forward(self, query, reference_points, feat, spatial_shapes, level_start_index, H, W):
|
||||||
|
def _inner_forward(query, feat):
|
||||||
|
|
||||||
|
attn = self.attn(
|
||||||
|
self.query_norm(query), reference_points, self.feat_norm(feat), spatial_shapes, level_start_index, None
|
||||||
|
)
|
||||||
|
query = query + attn
|
||||||
|
|
||||||
|
if self.with_cffn:
|
||||||
|
query = query + self.drop_path(self.ffn(self.ffn_norm(query), H, W))
|
||||||
|
return query
|
||||||
|
|
||||||
|
if self.with_cp and query.requires_grad:
|
||||||
|
query = cp.checkpoint(_inner_forward, query, feat)
|
||||||
|
else:
|
||||||
|
query = _inner_forward(query, feat)
|
||||||
|
|
||||||
|
return query
|
||||||
|
|
||||||
|
|
||||||
|
class Injector(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim,
|
||||||
|
num_heads=6,
|
||||||
|
n_points=4,
|
||||||
|
n_levels=1,
|
||||||
|
deform_ratio=1.0,
|
||||||
|
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
||||||
|
init_values=0.0,
|
||||||
|
with_cp=False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.with_cp = with_cp
|
||||||
|
self.query_norm = norm_layer(dim)
|
||||||
|
self.feat_norm = norm_layer(dim)
|
||||||
|
self.attn = MSDeformAttn(
|
||||||
|
d_model=dim, n_levels=n_levels, n_heads=num_heads, n_points=n_points, ratio=deform_ratio
|
||||||
|
)
|
||||||
|
self.gamma = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
|
||||||
|
|
||||||
|
def forward(self, query, reference_points, feat, spatial_shapes, level_start_index):
|
||||||
|
def _inner_forward(query, feat):
|
||||||
|
|
||||||
|
attn = self.attn(
|
||||||
|
self.query_norm(query), reference_points, self.feat_norm(feat), spatial_shapes, level_start_index, None
|
||||||
|
)
|
||||||
|
return query + self.gamma * attn
|
||||||
|
|
||||||
|
if self.with_cp and query.requires_grad:
|
||||||
|
query = cp.checkpoint(_inner_forward, query, feat)
|
||||||
|
else:
|
||||||
|
query = _inner_forward(query, feat)
|
||||||
|
|
||||||
|
return query
|
||||||
|
|
||||||
|
|
||||||
|
class InteractionBlock(nn.Module):
    def __init__(
        self,
        dim,
        num_heads=6,
        n_points=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        drop=0.0,
        drop_path=0.0,
        with_cffn=True,
        cffn_ratio=0.25,
        init_values=0.0,
        deform_ratio=1.0,
        extra_extractor=False,
        with_cp=False,
    ):
        super().__init__()

        self.injector = Injector(
            dim=dim,
            n_levels=3,
            num_heads=num_heads,
            init_values=init_values,
            n_points=n_points,
            norm_layer=norm_layer,
            deform_ratio=deform_ratio,
            with_cp=with_cp,
        )
        self.extractor = Extractor(
            dim=dim,
            n_levels=1,
            num_heads=num_heads,
            n_points=n_points,
            norm_layer=norm_layer,
            deform_ratio=deform_ratio,
            with_cffn=with_cffn,
            cffn_ratio=cffn_ratio,
            drop=drop,
            drop_path=drop_path,
            with_cp=with_cp,
        )
        if extra_extractor:
            self.extra_extractors = nn.Sequential(
                *[
                    Extractor(
                        dim=dim,
                        num_heads=num_heads,
                        n_points=n_points,
                        norm_layer=norm_layer,
                        with_cffn=with_cffn,
                        cffn_ratio=cffn_ratio,
                        deform_ratio=deform_ratio,
                        drop=drop,
                        drop_path=drop_path,
                        with_cp=with_cp,
                    )
                    for _ in range(2)
                ]
            )
        else:
            self.extra_extractors = None

    def forward(self, x, c, blocks, deform_inputs1, deform_inputs2, H_c, W_c, H_toks, W_toks):
        x = self.injector(
            query=x,
            reference_points=deform_inputs1[0],
            feat=c,
            spatial_shapes=deform_inputs1[1],
            level_start_index=deform_inputs1[2],
        )
        for idx, blk in enumerate(blocks):
            x = blk(x, H_toks, W_toks)
        c = self.extractor(
            query=c,
            reference_points=deform_inputs2[0],
            feat=x,
            spatial_shapes=deform_inputs2[1],
            level_start_index=deform_inputs2[2],
            H=H_c,
            W=W_c,
        )
        if self.extra_extractors is not None:
            for extractor in self.extra_extractors:
                c = extractor(
                    query=c,
                    reference_points=deform_inputs2[0],
                    feat=x,
                    spatial_shapes=deform_inputs2[1],
                    level_start_index=deform_inputs2[2],
                    H=H_c,
                    W=W_c,
                )
        return x, c

class InteractionBlockWithCls(nn.Module):
    def __init__(
        self,
        dim,
        num_heads=6,
        n_points=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        drop=0.0,
        drop_path=0.0,
        with_cffn=True,
        cffn_ratio=0.25,
        init_values=0.0,
        deform_ratio=1.0,
        extra_extractor=False,
        with_cp=False,
    ):
        super().__init__()

        self.injector = Injector(
            dim=dim,
            n_levels=3,
            num_heads=num_heads,
            init_values=init_values,
            n_points=n_points,
            norm_layer=norm_layer,
            deform_ratio=deform_ratio,
            with_cp=with_cp,
        )
        self.extractor = Extractor(
            dim=dim,
            n_levels=1,
            num_heads=num_heads,
            n_points=n_points,
            norm_layer=norm_layer,
            deform_ratio=deform_ratio,
            with_cffn=with_cffn,
            cffn_ratio=cffn_ratio,
            drop=drop,
            drop_path=drop_path,
            with_cp=with_cp,
        )
        if extra_extractor:
            self.extra_extractors = nn.Sequential(
                *[
                    Extractor(
                        dim=dim,
                        num_heads=num_heads,
                        n_points=n_points,
                        norm_layer=norm_layer,
                        with_cffn=with_cffn,
                        cffn_ratio=cffn_ratio,
                        deform_ratio=deform_ratio,
                        drop=drop,
                        drop_path=drop_path,
                        with_cp=with_cp,
                    )
                    for _ in range(2)
                ]
            )
        else:
            self.extra_extractors = None

    def forward(self, x, c, cls, blocks, deform_inputs1, deform_inputs2, H_c, W_c, H_toks, W_toks):
        x = self.injector(
            query=x,
            reference_points=deform_inputs1[0],
            feat=c,
            spatial_shapes=deform_inputs1[1],
            level_start_index=deform_inputs1[2],
        )
        x = torch.cat((cls, x), dim=1)
        for idx, blk in enumerate(blocks):
            x = blk(x, H_toks, W_toks)
        cls, x = x[:, :1], x[:, 1:]
        c = self.extractor(
            query=c,
            reference_points=deform_inputs2[0],
            feat=x,
            spatial_shapes=deform_inputs2[1],
            level_start_index=deform_inputs2[2],
            H=H_c,
            W=W_c,
        )
        if self.extra_extractors is not None:
            for extractor in self.extra_extractors:
                c = extractor(
                    query=c,
                    reference_points=deform_inputs2[0],
                    feat=x,
                    spatial_shapes=deform_inputs2[1],
                    level_start_index=deform_inputs2[2],
                    H=H_c,
                    W=W_c,
                )
        return x, c, cls

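Editor's note, not part of the commit: unlike InteractionBlock, this variant threads the ViT CLS token through the transformer blocks while keeping it out of the deformable injector/extractor, which only operate on spatial tokens. The prepend/split pattern in isolation:

import torch

B, N, C = 2, 196, 384
cls = torch.zeros(B, 1, C)        # class token
x = torch.randn(B, N, C)          # spatial tokens
x = torch.cat((cls, x), dim=1)    # (B, 1 + N, C): the ViT blocks see both
cls, x = x[:, :1], x[:, 1:]       # split again before deformable attention
assert cls.shape == (B, 1, C) and x.shape == (B, N, C)
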
class SpatialPriorModule(nn.Module):
    def __init__(self, inplanes=64, embed_dim=384, with_cp=False):
        super().__init__()
        self.with_cp = with_cp

        self.stem = nn.Sequential(
            *[
                nn.Conv2d(3, inplanes, kernel_size=3, stride=2, padding=1, bias=False),
                nn.SyncBatchNorm(inplanes),
                nn.ReLU(inplace=True),
                nn.Conv2d(inplanes, inplanes, kernel_size=3, stride=1, padding=1, bias=False),
                nn.SyncBatchNorm(inplanes),
                nn.ReLU(inplace=True),
                nn.Conv2d(inplanes, inplanes, kernel_size=3, stride=1, padding=1, bias=False),
                nn.SyncBatchNorm(inplanes),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            ]
        )
        self.conv2 = nn.Sequential(
            *[
                nn.Conv2d(inplanes, 2 * inplanes, kernel_size=3, stride=2, padding=1, bias=False),
                nn.SyncBatchNorm(2 * inplanes),
                nn.ReLU(inplace=True),
            ]
        )
        self.conv3 = nn.Sequential(
            *[
                nn.Conv2d(2 * inplanes, 4 * inplanes, kernel_size=3, stride=2, padding=1, bias=False),
                nn.SyncBatchNorm(4 * inplanes),
                nn.ReLU(inplace=True),
            ]
        )
        self.conv4 = nn.Sequential(
            *[
                nn.Conv2d(4 * inplanes, 4 * inplanes, kernel_size=3, stride=2, padding=1, bias=False),
                nn.SyncBatchNorm(4 * inplanes),
                nn.ReLU(inplace=True),
            ]
        )
        self.fc1 = nn.Conv2d(inplanes, embed_dim, kernel_size=1, stride=1, padding=0, bias=True)
        self.fc2 = nn.Conv2d(2 * inplanes, embed_dim, kernel_size=1, stride=1, padding=0, bias=True)
        self.fc3 = nn.Conv2d(4 * inplanes, embed_dim, kernel_size=1, stride=1, padding=0, bias=True)
        self.fc4 = nn.Conv2d(4 * inplanes, embed_dim, kernel_size=1, stride=1, padding=0, bias=True)

    def forward(self, x):
        def _inner_forward(x):
            c1 = self.stem(x)
            c2 = self.conv2(c1)
            c3 = self.conv3(c2)
            c4 = self.conv4(c3)
            c1 = self.fc1(c1)
            c2 = self.fc2(c2)
            c3 = self.fc3(c3)
            c4 = self.fc4(c4)

            bs, dim, _, _ = c1.shape
            # c1 = c1.view(bs, dim, -1).transpose(1, 2)  # 4s
            c2 = c2.view(bs, dim, -1).transpose(1, 2)  # 8s
            c3 = c3.view(bs, dim, -1).transpose(1, 2)  # 16s
            c4 = c4.view(bs, dim, -1).transpose(1, 2)  # 32s

            return c1, c2, c3, c4

        if self.with_cp and x.requires_grad:
            outs = cp.checkpoint(_inner_forward, x)
        else:
            outs = _inner_forward(x)
        return outs
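Editor's note, not part of the commit: the stem downsamples by 4 (one stride-2 conv plus a stride-2 max-pool) and each following stage halves the resolution again, so the pyramid sits at strides 4/8/16/32 before the 1x1 fc projections map every level to embed_dim channels. Shape walkthrough, assuming a 512x512 input and inplanes=64:

# stem:  512 -> 256 (conv s2) -> 256 -> 256 -> 128 (maxpool s2)  => c1: (B,  64, 128, 128)  stride 4
# conv2: 128 -> 64                                               => c2: (B, 128,  64,  64)  stride 8
# conv3:  64 -> 32                                               => c3: (B, 256,  32,  32)  stride 16
# conv4:  32 -> 16                                               => c4: (B, 256,  16,  16)  stride 32
# fc1..fc4 then project each level to embed_dim channels; c2..c4 are flattened to token sequences.
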
@ -0,0 +1,32 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py

from torch import nn


def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
    if keep_prob > 0.0:
        random_tensor.div_(keep_prob)
    return x * random_tensor


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob: float = 0.0):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)
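Editor's note, not part of the commit: dividing the kept samples by keep_prob makes the layer unbiased in expectation, E[out] = keep_prob * (x / keep_prob) = x, so inference (training=False) can skip the mask entirely. A quick numerical check:

import torch

torch.manual_seed(0)
x = torch.ones(10000, 8)
keep_prob = 0.9
mask = torch.empty(10000, 1).bernoulli_(keep_prob).div_(keep_prob)
print((x * mask).mean().item())  # ~1.0: expectation matches the input
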
@ -0,0 +1,552 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

"""Vision Transformer (ViT) in PyTorch.

A PyTorch implementation of Vision Transformers as described in:

'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale'
    - https://arxiv.org/abs/2010.11929

`How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers`
    - https://arxiv.org/abs/2106.10270

The official jax code is released and available at https://github.com/google-research/vision_transformer

DeiT model defs and weights from https://github.com/facebookresearch/deit,
paper `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877

Acknowledgments:
* The paper authors for releasing code and weights, thanks!
* I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out
  for some einops/einsum fun
* Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT
* Bert reference code checks against Huggingface Transformers and Tensorflow Bert

Hacked together by / Copyright 2021 Ross Wightman
"""
import logging
import math
from functools import partial
from itertools import repeat
from typing import Callable, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from mmcv.runner import BaseModule, load_checkpoint
from mmseg.ops import resize
from mmseg.utils import get_root_logger
from torch import Tensor

from .drop_path import DropPath


def to_2tuple(x):
    return tuple(repeat(x, 2))


class Mlp(nn.Module):
    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
        self.drop = nn.Dropout(drop)

    def forward(self, x: Tensor) -> Tensor:
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class SwiGLUFFN(nn.Module):
    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Optional[Callable[..., nn.Module]] = None,  # unused; kept for interface parity with Mlp
        drop: float = 0.0,  # unused; kept for interface parity with Mlp
    ) -> None:
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        swiglu_hidden_features = int(2 * hidden_features / 3)
        align_as = 8
        swiglu_hidden_features = (swiglu_hidden_features + align_as - 1) // align_as * align_as
        self.w1 = nn.Linear(in_features, swiglu_hidden_features)
        self.w2 = nn.Linear(in_features, swiglu_hidden_features)
        self.w3 = nn.Linear(swiglu_hidden_features, out_features)

    def forward(self, x: Tensor) -> Tensor:
        x1 = self.w1(x)
        x2 = self.w2(x)
        hidden = F.silu(x1) * x2
        return self.w3(hidden)

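Editor's note, not part of the commit: the SwiGLU hidden width is set to 2/3 of the requested hidden_features, so the three weight matrices (w1, w2, w3) carry roughly the parameter count of a two-matrix Mlp, and is then rounded up to a multiple of 8 for hardware-friendly shapes:

def swiglu_hidden(hidden_features, align_as=8):
    h = int(2 * hidden_features / 3)
    return (h + align_as - 1) // align_as * align_as

print(swiglu_hidden(1536))  # 1024: already a multiple of 8
print(swiglu_hidden(1000))  # 672: 666 rounded up to the next multiple of 8
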
class PatchEmbed(nn.Module):
    """2D Image to Patch Embedding."""

    def __init__(
        self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True, bias=True
    ):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.flatten = flatten

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        x = self.proj(x)
        _, _, H, W = x.shape
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC
        x = self.norm(x)
        return x, H, W

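Editor's note, not part of the commit: with stride equal to kernel size the projection tiles the image into non-overlapping patches, and the returned H, W are the patch-grid dimensions:

import torch

embed = PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768)
tokens, H, W = embed(torch.randn(1, 3, 224, 224))
print(tokens.shape, H, W)  # torch.Size([1, 196, 768]) 14 14
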
class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, H, W):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class MemEffAttention(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: Tensor, H, W) -> Tensor:
        from xformers.ops import memory_efficient_attention, unbind

        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)

        q, k, v = unbind(qkv, 2)

        x = memory_efficient_attention(q, k, v)
        x = x.reshape([B, N, C])

        x = self.proj(x)
        x = self.proj_drop(x)
        return x


def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size
    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows


def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image
    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x

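Editor's note, not part of the commit: window_partition and window_reverse are exact inverses whenever H and W are multiples of the window size:

import torch

x = torch.randn(2, 28, 28, 96)                 # (B, H, W, C), 28 = 2 * 14
windows = window_partition(x, window_size=14)  # (2*2*2, 14, 14, 96)
assert torch.equal(window_reverse(windows, 14, 28, 28), x)
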
class WindowedAttention(nn.Module):
    def __init__(
        self, dim, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0, window_size=14, pad_mode="constant"
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.window_size = window_size
        self.pad_mode = pad_mode

    def forward(self, x, H, W):
        B, N, C = x.shape
        N_ = self.window_size * self.window_size
        H_ = math.ceil(H / self.window_size) * self.window_size
        W_ = math.ceil(W / self.window_size) * self.window_size

        qkv = self.qkv(x)  # [B, N, 3C]
        qkv = qkv.transpose(1, 2).reshape(B, C * 3, H, W)  # [B, 3C, H, W]
        qkv = F.pad(qkv, [0, W_ - W, 0, H_ - H], mode=self.pad_mode)

        qkv = F.unfold(
            qkv, kernel_size=(self.window_size, self.window_size), stride=(self.window_size, self.window_size)
        )
        B, C_kw_kw, L = qkv.shape  # L - the num of windows
        qkv = qkv.reshape(B, C * 3, N_, L).permute(0, 3, 2, 1)  # [B, L, N_, 3C]
        qkv = qkv.reshape(B, L, N_, 3, self.num_heads, C // self.num_heads).permute(3, 0, 1, 4, 2, 5)
        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)

        # q,k,v [B, L, num_head, N_, C/num_head]
        attn = (q @ k.transpose(-2, -1)) * self.scale  # [B, L, num_head, N_, N_]
        # if self.mask:
        #     attn = attn * mask
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)  # [B, L, num_head, N_, N_]
        # attn @ v = [B, L, num_head, N_, C/num_head]
        x = (attn @ v).permute(0, 2, 4, 3, 1).reshape(B, C_kw_kw // 3, L)

        x = F.fold(
            x,
            output_size=(H_, W_),
            kernel_size=(self.window_size, self.window_size),
            stride=(self.window_size, self.window_size),
        )  # [B, C, H_, W_]
        x = x[:, :, :H, :W].reshape(B, C, N).transpose(-1, -2)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


# class WindowedAttention(nn.Module):
#     def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., window_size=14, pad_mode="constant"):
#         super().__init__()
#         self.num_heads = num_heads
#         head_dim = dim // num_heads
#         self.scale = head_dim ** -0.5
#
#         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
#         self.attn_drop = nn.Dropout(attn_drop)
#         self.proj = nn.Linear(dim, dim)
#         self.proj_drop = nn.Dropout(proj_drop)
#         self.window_size = window_size
#         self.pad_mode = pad_mode
#
#     def forward(self, x, H, W):
#         B, N, C = x.shape
#
#         N_ = self.window_size * self.window_size
#         H_ = math.ceil(H / self.window_size) * self.window_size
#         W_ = math.ceil(W / self.window_size) * self.window_size
#         x = x.view(B, H, W, C)
#         x = F.pad(x, [0, 0, 0, W_ - W, 0, H_- H], mode=self.pad_mode)
#
#         x = window_partition(x, window_size=self.window_size)  # nW*B, window_size, window_size, C
#         x = x.view(-1, N_, C)
#
#         qkv = self.qkv(x).view(-1, N_, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
#         q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)
#         attn = (q @ k.transpose(-2, -1)) * self.scale  # [B, L, num_head, N_, N_]
#         attn = attn.softmax(dim=-1)
#         attn = self.attn_drop(attn)  # [B, L, num_head, N_, N_]
#         x = (attn @ v).transpose(1, 2).reshape(-1, self.window_size, self.window_size, C)
#
#         x = window_reverse(x, self.window_size, H_, W_)
#         x = x[:, :H, :W, :].reshape(B, N, C).contiguous()
#         x = self.proj(x)
#         x = self.proj_drop(x)
#         return x

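Editor's note, not part of the commit: the active WindowedAttention above rearranges tokens with F.unfold/F.fold instead of the view-based partitioning kept in the commented-out variant; since stride equals kernel size the windows never overlap, so F.fold writes each window back without overlap-add double counting. A shape walkthrough, assuming B=2, C=96, H=W=28, window_size=14 (so L=4 windows):

# qkv projection:        (2, 784, 288) -> transpose/reshape -> (2, 288, 28, 28)
# F.unfold(k=14, s=14):  (2, 288*196, 4)                        L = 4 windows
# reshape/permute:       (3, 2, 4, num_heads, 196, head_dim) before attention
# F.fold(k=14, s=14):    (2, 96*196, 4) -> (2, 96, 28, 28), then crop back to (H, W)
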
class Block(nn.Module):
    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        windowed=False,
        window_size=14,
        pad_mode="constant",
        layer_scale=False,
        with_cp=False,
        ffn_layer=Mlp,
        memeff=False,
    ):
        super().__init__()
        self.with_cp = with_cp
        self.norm1 = norm_layer(dim)
        if windowed:
            self.attn = WindowedAttention(
                dim,
                num_heads=num_heads,
                qkv_bias=qkv_bias,
                attn_drop=attn_drop,
                proj_drop=drop,
                window_size=window_size,
                pad_mode=pad_mode,
            )
        elif memeff:
            self.attn = MemEffAttention(
                dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop
            )
        else:
            self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
        self.layer_scale = layer_scale
        if layer_scale:
            self.gamma1 = nn.Parameter(torch.ones((dim)), requires_grad=True)
            self.gamma2 = nn.Parameter(torch.ones((dim)), requires_grad=True)

    def forward(self, x, H, W):
        def _inner_forward(x):
            if self.layer_scale:
                x = x + self.drop_path(self.gamma1 * self.attn(self.norm1(x), H, W))
                x = x + self.drop_path(self.gamma2 * self.mlp(self.norm2(x)))
            else:
                x = x + self.drop_path(self.attn(self.norm1(x), H, W))
                x = x + self.drop_path(self.mlp(self.norm2(x)))
            return x

        if self.with_cp and x.requires_grad:
            x = cp.checkpoint(_inner_forward, x)
        else:
            x = _inner_forward(x)

        return x

class TIMMVisionTransformer(BaseModule):
    """Vision Transformer.

    A PyTorch impl of: `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
        - https://arxiv.org/abs/2010.11929

    Includes distillation token & head support for `DeiT: Data-efficient Image Transformers`
        - https://arxiv.org/abs/2012.12877
    """

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        num_classes=1000,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        layer_scale=True,
        embed_layer=PatchEmbed,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        act_layer=nn.GELU,
        window_attn=False,
        window_size=14,
        pretrained=None,
        with_cp=False,
        pre_norm=False,
        ffn_type="mlp",
        memeff=False,
    ):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            num_classes (int): number of classes for classification head
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            drop_rate (float): dropout rate
            attn_drop_rate (float): attention dropout rate
            drop_path_rate (float): stochastic depth rate
            embed_layer (nn.Module): patch embedding layer
            norm_layer (nn.Module): normalization layer
            pretrained (str): pretrained path
        """
        super().__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.num_tokens = 1
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU
        self.norm_layer = norm_layer
        self.act_layer = act_layer
        self.pretrain_size = img_size
        self.drop_path_rate = drop_path_rate
        self.drop_rate = drop_rate
        self.patch_size = patch_size

        window_attn = [window_attn] * depth if not isinstance(window_attn, list) else window_attn
        window_size = [window_size] * depth if not isinstance(window_size, list) else window_size
        logging.info("window attention: %s", window_attn)
        logging.info("window size: %s", window_size)
        logging.info("layer scale: %s", layer_scale)

        self.patch_embed = embed_layer(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, bias=not pre_norm
        )
        num_patches = self.patch_embed.num_patches

        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        ffn_types = {"mlp": Mlp, "swiglu": SwiGLUFFN}

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.Sequential(
            *[
                Block(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[i],
                    norm_layer=norm_layer,
                    act_layer=act_layer,
                    windowed=window_attn[i],
                    window_size=window_size[i],
                    layer_scale=layer_scale,
                    with_cp=with_cp,
                    ffn_layer=ffn_types[ffn_type],
                    memeff=memeff,
                )
                for i in range(depth)
            ]
        )

        self.norm = norm_layer(embed_dim)  # was commented out upstream; forward_features below relies on it
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # For CLIP
        if pre_norm:
            norm_pre = norm_layer(embed_dim)
            self.norm_pre = norm_pre
        else:
            self.norm_pre = nn.Identity()
        self.init_weights(pretrained)

    def init_weights(self, pretrained=None):
        if isinstance(pretrained, str):
            logger = get_root_logger()
            load_checkpoint(self, pretrained, map_location="cpu", strict=False, logger=logger)

    def forward_features(self, x):
        x, H, W = self.patch_embed(x)
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_token, x), dim=1)
        x = self.pos_drop(x + self.pos_embed)

        # For CLIP
        x = self.norm_pre(x)

        for blk in self.blocks:
            x = blk(x, H, W)
        x = self.norm(x)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        return x

    @staticmethod
    def resize_pos_embed(pos_embed, input_shape, pos_shape, mode):
        """Resize pos_embed weights.

        Resize pos_embed using bicubic interpolate method.
        Args:
            pos_embed (torch.Tensor): Position embedding weights.
            input_shape (tuple): Tuple for (downsampled input image height,
                downsampled input image width).
            pos_shape (tuple): The resolution of downsampled origin training
                image.
            mode (str): Algorithm used for upsampling:
                ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` |
                ``'trilinear'``. Default: ``'nearest'``
        Return:
            torch.Tensor: The resized pos_embed of shape [B, L_new, C]
        """
        assert pos_embed.ndim == 3, "shape of pos_embed must be [B, L, C]"
        pos_h, pos_w = pos_shape
        # keep dim for easy deployment
        cls_token_weight = pos_embed[:, 0:1]
        pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w) :]
        pos_embed_weight = pos_embed_weight.reshape(1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2)
        pos_embed_weight = resize(pos_embed_weight, size=input_shape, align_corners=False, mode=mode)
        pos_embed_weight = torch.flatten(pos_embed_weight, 2).transpose(1, 2)
        pos_embed = torch.cat((cls_token_weight, pos_embed_weight), dim=1)
        return pos_embed
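Editor's note, not part of the commit: resize_pos_embed leaves the class-token embedding untouched and bicubically resizes only the grid portion, e.g. going from a 14x14 pretraining grid to a 32x32 grid:

import torch

pos_embed = torch.randn(1, 1 + 14 * 14, 768)  # cls token + 14x14 grid
resized = TIMMVisionTransformer.resize_pos_embed(pos_embed, (32, 32), (14, 14), "bicubic")
print(resized.shape)  # torch.Size([1, 1025, 768]) == (1, 1 + 32*32, 768)
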
@ -0,0 +1,217 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmseg.models.builder import BACKBONES
from torch.nn.init import normal_

from ...ops.modules import MSDeformAttn
from .adapter_modules import InteractionBlock, InteractionBlockWithCls, SpatialPriorModule, deform_inputs
from .vit import TIMMVisionTransformer


@BACKBONES.register_module()
class ViTAdapter(TIMMVisionTransformer):
    def __init__(
        self,
        pretrain_size=224,
        num_heads=12,
        conv_inplane=64,
        n_points=4,
        deform_num_heads=6,
        init_values=0.0,
        interaction_indexes=None,
        with_cffn=True,
        cffn_ratio=0.25,
        deform_ratio=1.0,
        add_vit_feature=True,
        pretrained=None,
        use_extra_extractor=True,
        freeze_vit=False,
        use_cls=True,
        with_cp=False,
        *args,
        **kwargs
    ):
        super().__init__(*args, num_heads=num_heads, pretrained=pretrained, with_cp=with_cp, **kwargs)
        if freeze_vit:
            for param in self.parameters():
                param.requires_grad = False

        # self.num_classes = 80
        self.use_cls = use_cls
        if not self.use_cls:
            self.cls_token = None
        self.num_block = len(self.blocks)
        self.pretrain_size = (pretrain_size, pretrain_size)
        self.interaction_indexes = interaction_indexes
        self.add_vit_feature = add_vit_feature
        embed_dim = self.embed_dim

        block_fn = InteractionBlockWithCls if use_cls else InteractionBlock

        self.level_embed = nn.Parameter(torch.zeros(3, embed_dim))
        self.spm = SpatialPriorModule(inplanes=conv_inplane, embed_dim=embed_dim, with_cp=False)
        self.interactions = nn.Sequential(
            *[
                block_fn(
                    dim=embed_dim,
                    num_heads=deform_num_heads,
                    n_points=n_points,
                    init_values=init_values,
                    drop_path=self.drop_path_rate,
                    norm_layer=self.norm_layer,
                    with_cffn=with_cffn,
                    cffn_ratio=cffn_ratio,
                    deform_ratio=deform_ratio,
                    extra_extractor=((i == len(interaction_indexes) - 1) and use_extra_extractor),
                    with_cp=with_cp,
                )
                for i in range(len(interaction_indexes))
            ]
        )
        self.up = nn.ConvTranspose2d(embed_dim, embed_dim, 2, 2)
        self.norm1 = nn.SyncBatchNorm(embed_dim)
        self.norm2 = nn.SyncBatchNorm(embed_dim)
        self.norm3 = nn.SyncBatchNorm(embed_dim)
        self.norm4 = nn.SyncBatchNorm(embed_dim)

        self.up.apply(self._init_weights)
        self.spm.apply(self._init_weights)
        self.interactions.apply(self._init_weights)
        self.apply(self._init_deform_weights)
        normal_(self.level_embed)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            torch.nn.init.trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm) or isinstance(m, nn.BatchNorm2d):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def _get_pos_embed(self, pos_embed, H, W):
        pos_embed = pos_embed.reshape(
            1, self.pretrain_size[0] // self.patch_size, self.pretrain_size[1] // self.patch_size, -1
        ).permute(0, 3, 1, 2)
        pos_embed = (
            F.interpolate(pos_embed, size=(H, W), mode="bicubic", align_corners=False)
            .reshape(1, -1, H * W)
            .permute(0, 2, 1)
        )
        return pos_embed

    def _init_deform_weights(self, m):
        if isinstance(m, MSDeformAttn):
            m._reset_parameters()

    def _add_level_embed(self, c2, c3, c4):
        c2 = c2 + self.level_embed[0]
        c3 = c3 + self.level_embed[1]
        c4 = c4 + self.level_embed[2]
        return c2, c3, c4

    def forward(self, x):
        deform_inputs1, deform_inputs2 = deform_inputs(x, self.patch_size)

        # SPM forward
        c1, c2, c3, c4 = self.spm(x)
        c2, c3, c4 = self._add_level_embed(c2, c3, c4)
        c = torch.cat([c2, c3, c4], dim=1)

        # Patch Embedding forward
        H_c, W_c = x.shape[2] // 16, x.shape[3] // 16
        x, H_toks, W_toks = self.patch_embed(x)
        # print("H_toks, W_toks =", H_toks, W_toks)
        bs, n, dim = x.shape
        pos_embed = self._get_pos_embed(self.pos_embed[:, 1:], H_toks, W_toks)
        if self.use_cls:
            cls_token = self.cls_token.expand(x.shape[0], -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
            x = torch.cat((cls_token, x), dim=1)
            pos_embed = torch.cat((self.pos_embed[:, :1], pos_embed), dim=1)
        x = self.pos_drop(x + pos_embed)
        # For CLIP
        x = self.norm_pre(x)

        # Interaction
        if self.use_cls:
            cls, x = x[:, :1], x[:, 1:]
        outs = list()
        for i, layer in enumerate(self.interactions):
            indexes = self.interaction_indexes[i]
            if self.use_cls:
                x, c, cls = layer(
                    x,
                    c,
                    cls,
                    self.blocks[indexes[0] : indexes[-1] + 1],
                    deform_inputs1,
                    deform_inputs2,
                    H_c,
                    W_c,
                    H_toks,
                    W_toks,
                )
            else:
                x, c = layer(
                    x,
                    c,
                    self.blocks[indexes[0] : indexes[-1] + 1],
                    deform_inputs1,
                    deform_inputs2,
                    H_c,
                    W_c,
                    H_toks,
                    W_toks,
                )
            outs.append(x.transpose(1, 2).view(bs, dim, H_toks, W_toks).contiguous())

        # Split & Reshape
        c2 = c[:, 0 : c2.size(1), :]
        c3 = c[:, c2.size(1) : c2.size(1) + c3.size(1), :]
        c4 = c[:, c2.size(1) + c3.size(1) :, :]

        c2 = c2.transpose(1, 2).view(bs, dim, H_c * 2, W_c * 2).contiguous()
        c3 = c3.transpose(1, 2).view(bs, dim, H_c, W_c).contiguous()
        c4 = c4.transpose(1, 2).view(bs, dim, H_c // 2, W_c // 2).contiguous()
        c1 = self.up(c2) + c1

        if self.add_vit_feature:
            x1, x2, x3, x4 = outs

            x1 = F.interpolate(x1, size=(4 * H_c, 4 * W_c), mode="bilinear", align_corners=False)
            x2 = F.interpolate(x2, size=(2 * H_c, 2 * W_c), mode="bilinear", align_corners=False)
            x3 = F.interpolate(x3, size=(1 * H_c, 1 * W_c), mode="bilinear", align_corners=False)
            x4 = F.interpolate(x4, size=(H_c // 2, W_c // 2), mode="bilinear", align_corners=False)
            # print(c1.shape, c2.shape, c3.shape, c4.shape, x1.shape, x2.shape, x3.shape, x4.shape, H_c, H_toks)
            c1, c2, c3, c4 = c1 + x1, c2 + x2, c3 + x3, c4 + x4

        # Final Norm
        f1 = self.norm1(c1)
        f2 = self.norm2(c2)
        f3 = self.norm3(c3)
        f4 = self.norm4(c4)
        return [f1, f2, f3, f4]
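Editor's sketch, not part of the commit and with illustrative values only: as a registered mmseg backbone, ViTAdapter is normally built from a config dict, where interaction_indexes partitions the ViT blocks into the four interaction stages:

backbone = dict(
    type="ViTAdapter",
    img_size=512,
    patch_size=16,
    embed_dim=768,
    depth=12,
    num_heads=12,
    pretrain_size=224,
    conv_inplane=64,
    n_points=4,
    deform_num_heads=6,
    interaction_indexes=[[0, 2], [3, 5], [6, 8], [9, 11]],  # 4 stages over 12 blocks
    with_cffn=True,
    cffn_ratio=0.25,
    deform_ratio=0.5,
    add_vit_feature=True,
)
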
@ -0,0 +1,25 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from mmcv.utils import Registry

TRANSFORMER = Registry("Transformer")
MASK_ASSIGNERS = Registry("mask_assigner")
MATCH_COST = Registry("match_cost")


def build_match_cost(cfg):
    """Build Match Cost."""
    return MATCH_COST.build(cfg)


def build_assigner(cfg):
    """Build Assigner."""
    return MASK_ASSIGNERS.build(cfg)


def build_transformer(cfg):
    """Build Transformer."""
    return TRANSFORMER.build(cfg)
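Editor's sketch, not part of the commit; MyAssigner is a hypothetical class used only to show the registry round trip: classes opt in with a decorator and are later instantiated from a config dict whose type key selects the class:

from mmcv.utils import Registry

MASK_ASSIGNERS = Registry("mask_assigner")

@MASK_ASSIGNERS.register_module()
class MyAssigner:  # hypothetical example class
    def __init__(self, topk=1):
        self.topk = topk

assigner = MASK_ASSIGNERS.build(dict(type="MyAssigner", topk=5))
print(assigner.topk)  # 5
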
@ -0,0 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from .mask2former_head import Mask2FormerHead
@ -0,0 +1,544 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import Conv2d, build_plugin_layer, caffe2_xavier_init
from mmcv.cnn.bricks.transformer import build_positional_encoding, build_transformer_layer_sequence
from mmcv.ops import point_sample
from mmcv.runner import ModuleList, force_fp32
from mmseg.models.builder import HEADS, build_loss
from mmseg.models.decode_heads.decode_head import BaseDecodeHead

from ...core import build_sampler, multi_apply, reduce_mean
from ..builder import build_assigner
from ..utils import get_uncertain_point_coords_with_randomness


@HEADS.register_module()
class Mask2FormerHead(BaseDecodeHead):
    """Implements the Mask2Former head.

    See `Masked-attention Mask Transformer for Universal Image
    Segmentation <https://arxiv.org/pdf/2112.01527>`_ for details.

    Args:
        in_channels (list[int]): Number of channels in the input feature map.
        feat_channels (int): Number of channels for features.
        out_channels (int): Number of channels for output.
        num_things_classes (int): Number of things.
        num_stuff_classes (int): Number of stuff.
        num_queries (int): Number of queries in the Transformer decoder.
        pixel_decoder (:obj:`mmcv.ConfigDict` | dict): Config for pixel
            decoder. Defaults to None.
        enforce_decoder_input_project (bool, optional): Whether to add
            a layer to change the embed_dim of the transformer encoder in
            the pixel decoder to the embed_dim of the transformer decoder.
            Defaults to False.
        transformer_decoder (:obj:`mmcv.ConfigDict` | dict): Config for
            transformer decoder. Defaults to None.
        positional_encoding (:obj:`mmcv.ConfigDict` | dict): Config for
            transformer decoder position encoding. Defaults to None.
        loss_cls (:obj:`mmcv.ConfigDict` | dict): Config of the classification
            loss. Defaults to None.
        loss_mask (:obj:`mmcv.ConfigDict` | dict): Config of the mask loss.
            Defaults to None.
        loss_dice (:obj:`mmcv.ConfigDict` | dict): Config of the dice loss.
            Defaults to None.
        train_cfg (:obj:`mmcv.ConfigDict` | dict): Training config of
            Mask2Former head.
        test_cfg (:obj:`mmcv.ConfigDict` | dict): Testing config of
            Mask2Former head.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(
        self,
        in_channels,
        feat_channels,
        out_channels,
        num_things_classes=80,
        num_stuff_classes=53,
        num_queries=100,
        num_transformer_feat_level=3,
        pixel_decoder=None,
        enforce_decoder_input_project=False,
        transformer_decoder=None,
        positional_encoding=None,
        loss_cls=None,
        loss_mask=None,
        loss_dice=None,
        train_cfg=None,
        test_cfg=None,
        init_cfg=None,
        **kwargs,
    ):
        super(Mask2FormerHead, self).__init__(
            in_channels=in_channels,
            channels=feat_channels,
            num_classes=(num_things_classes + num_stuff_classes),
            init_cfg=init_cfg,
            input_transform="multiple_select",
            **kwargs,
        )
        self.num_things_classes = num_things_classes
        self.num_stuff_classes = num_stuff_classes
        self.num_classes = self.num_things_classes + self.num_stuff_classes
        self.num_queries = num_queries
        self.num_transformer_feat_level = num_transformer_feat_level
        self.num_heads = transformer_decoder.transformerlayers.attn_cfgs.num_heads
        self.num_transformer_decoder_layers = transformer_decoder.num_layers
        assert pixel_decoder.encoder.transformerlayers.attn_cfgs.num_levels == num_transformer_feat_level
        pixel_decoder_ = copy.deepcopy(pixel_decoder)
        pixel_decoder_.update(in_channels=in_channels, feat_channels=feat_channels, out_channels=out_channels)
        self.pixel_decoder = build_plugin_layer(pixel_decoder_)[1]
        self.transformer_decoder = build_transformer_layer_sequence(transformer_decoder)
        self.decoder_embed_dims = self.transformer_decoder.embed_dims

        self.decoder_input_projs = ModuleList()
        # from low resolution to high resolution
        for _ in range(num_transformer_feat_level):
            if self.decoder_embed_dims != feat_channels or enforce_decoder_input_project:
                self.decoder_input_projs.append(Conv2d(feat_channels, self.decoder_embed_dims, kernel_size=1))
            else:
                self.decoder_input_projs.append(nn.Identity())
        self.decoder_positional_encoding = build_positional_encoding(positional_encoding)
        self.query_embed = nn.Embedding(self.num_queries, feat_channels)
        self.query_feat = nn.Embedding(self.num_queries, feat_channels)
        # from low resolution to high resolution
        self.level_embed = nn.Embedding(self.num_transformer_feat_level, feat_channels)

        self.cls_embed = nn.Linear(feat_channels, self.num_classes + 1)
        self.mask_embed = nn.Sequential(
            nn.Linear(feat_channels, feat_channels),
            nn.ReLU(inplace=True),
            nn.Linear(feat_channels, feat_channels),
            nn.ReLU(inplace=True),
            nn.Linear(feat_channels, out_channels),
        )
        self.conv_seg = None  # fix a bug here (conv_seg is not used)

        self.test_cfg = test_cfg
        self.train_cfg = train_cfg
        if train_cfg:
            self.assigner = build_assigner(self.train_cfg.assigner)
            self.sampler = build_sampler(self.train_cfg.sampler, context=self)
            self.num_points = self.train_cfg.get("num_points", 12544)
            self.oversample_ratio = self.train_cfg.get("oversample_ratio", 3.0)
            self.importance_sample_ratio = self.train_cfg.get("importance_sample_ratio", 0.75)

        self.class_weight = loss_cls.class_weight
        self.loss_cls = build_loss(loss_cls)
        self.loss_mask = build_loss(loss_mask)
        self.loss_dice = build_loss(loss_dice)

    def init_weights(self):
        for m in self.decoder_input_projs:
            if isinstance(m, Conv2d):
                caffe2_xavier_init(m, bias=0)

        self.pixel_decoder.init_weights()

        for p in self.transformer_decoder.parameters():
            if p.dim() > 1:
                nn.init.xavier_normal_(p)

    def get_targets(self, cls_scores_list, mask_preds_list, gt_labels_list, gt_masks_list, img_metas):
        """Compute classification and mask targets for all images for a decoder
        layer.

        Args:
            cls_scores_list (list[Tensor]): Mask score logits from a single
                decoder layer for all images. Each with shape [num_queries,
                cls_out_channels].
            mask_preds_list (list[Tensor]): Mask logits from a single decoder
                layer for all images. Each with shape [num_queries, h, w].
            gt_labels_list (list[Tensor]): Ground truth class indices for all
                images. Each with shape (n, ), where n is the sum of the number
                of stuff types and the number of instances in an image.
            gt_masks_list (list[Tensor]): Ground truth mask for each image,
                each with shape (n, h, w).
            img_metas (list[dict]): List of image meta information.

        Returns:
            tuple[list[Tensor]]: a tuple containing the following targets.

            - labels_list (list[Tensor]): Labels of all images.
                Each with shape [num_queries, ].
            - label_weights_list (list[Tensor]): Label weights of all
                images. Each with shape [num_queries, ].
            - mask_targets_list (list[Tensor]): Mask targets of all images.
                Each with shape [num_queries, h, w].
            - mask_weights_list (list[Tensor]): Mask weights of all images.
                Each with shape [num_queries, ].
            - num_total_pos (int): Number of positive samples in all
                images.
            - num_total_neg (int): Number of negative samples in all
                images.
        """
        (
            labels_list,
            label_weights_list,
            mask_targets_list,
            mask_weights_list,
            pos_inds_list,
            neg_inds_list,
        ) = multi_apply(
            self._get_target_single, cls_scores_list, mask_preds_list, gt_labels_list, gt_masks_list, img_metas
        )

        num_total_pos = sum(inds.numel() for inds in pos_inds_list)
        num_total_neg = sum(inds.numel() for inds in neg_inds_list)
        return (labels_list, label_weights_list, mask_targets_list, mask_weights_list, num_total_pos, num_total_neg)

    def _get_target_single(self, cls_score, mask_pred, gt_labels, gt_masks, img_metas):
        """Compute classification and mask targets for one image.

        Args:
            cls_score (Tensor): Mask score logits from a single decoder layer
                for one image. Shape (num_queries, cls_out_channels).
            mask_pred (Tensor): Mask logits for a single decoder layer for one
                image. Shape (num_queries, h, w).
            gt_labels (Tensor): Ground truth class indices for one image with
                shape (num_gts, ).
            gt_masks (Tensor): Ground truth mask for each image, each with
                shape (num_gts, h, w).
            img_metas (dict): Image information.

        Returns:
            tuple[Tensor]: A tuple containing the following for one image.

            - labels (Tensor): Labels of each image. \
                shape (num_queries, ).
            - label_weights (Tensor): Label weights of each image. \
                shape (num_queries, ).
            - mask_targets (Tensor): Mask targets of each image. \
                shape (num_queries, h, w).
            - mask_weights (Tensor): Mask weights of each image. \
                shape (num_queries, ).
            - pos_inds (Tensor): Sampled positive indices for each \
                image.
            - neg_inds (Tensor): Sampled negative indices for each \
                image.
        """
        # sample points
        num_queries = cls_score.shape[0]
        num_gts = gt_labels.shape[0]

        point_coords = torch.rand((1, self.num_points, 2), device=cls_score.device)
        # shape (num_queries, num_points)
        mask_points_pred = point_sample(mask_pred.unsqueeze(1), point_coords.repeat(num_queries, 1, 1)).squeeze(1)
        # shape (num_gts, num_points)
        gt_points_masks = point_sample(gt_masks.unsqueeze(1).float(), point_coords.repeat(num_gts, 1, 1)).squeeze(1)

        # assign and sample
        assign_result = self.assigner.assign(cls_score, mask_points_pred, gt_labels, gt_points_masks, img_metas)
        sampling_result = self.sampler.sample(assign_result, mask_pred, gt_masks)
        pos_inds = sampling_result.pos_inds
        neg_inds = sampling_result.neg_inds

        # label target
        labels = gt_labels.new_full((self.num_queries,), self.num_classes, dtype=torch.long)
        labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds]
        label_weights = gt_labels.new_ones((self.num_queries,))

        # mask target
        mask_targets = gt_masks[sampling_result.pos_assigned_gt_inds]
        mask_weights = mask_pred.new_zeros((self.num_queries,))
        mask_weights[pos_inds] = 1.0

        return (labels, label_weights, mask_targets, mask_weights, pos_inds, neg_inds)

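Editor's note, not part of the commit: matching on self.num_points randomly sampled points instead of full masks keeps the assignment cost independent of mask resolution (a PointRend-style trick). mmcv's point_sample takes coordinates in [0, 1] and behaves essentially like grid_sample after rescaling to [-1, 1], as in this sketch:

import torch
import torch.nn.functional as F

masks = torch.rand(3, 1, 64, 64)     # (num_masks, 1, h, w)
pts = torch.rand(3, 128, 2)          # 128 sample points per mask, in [0, 1]
grid = 2.0 * pts.unsqueeze(2) - 1.0  # (3, 128, 1, 2), rescaled to [-1, 1]
vals = F.grid_sample(masks, grid, align_corners=False).squeeze(3).squeeze(1)
print(vals.shape)                    # torch.Size([3, 128])
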
||||||
|
def loss_single(self, cls_scores, mask_preds, gt_labels_list, gt_masks_list, img_metas):
|
||||||
|
"""Loss function for outputs from a single decoder layer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cls_scores (Tensor): Mask score logits from a single decoder layer
|
||||||
|
for all images. Shape (batch_size, num_queries,
|
||||||
|
cls_out_channels). Note `cls_out_channels` should includes
|
||||||
|
background.
|
||||||
|
mask_preds (Tensor): Mask logits for a pixel decoder for all
|
||||||
|
images. Shape (batch_size, num_queries, h, w).
|
||||||
|
gt_labels_list (list[Tensor]): Ground truth class indices for each
|
||||||
|
image, each with shape (num_gts, ).
|
||||||
|
gt_masks_list (list[Tensor]): Ground truth mask for each image,
|
||||||
|
each with shape (num_gts, h, w).
|
||||||
|
img_metas (list[dict]): List of image meta information.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple[Tensor]: Loss components for outputs from a single \
|
||||||
|
decoder layer.
|
||||||
|
"""
|
||||||
|
num_imgs = cls_scores.size(0)
|
||||||
|
cls_scores_list = [cls_scores[i] for i in range(num_imgs)]
|
||||||
|
mask_preds_list = [mask_preds[i] for i in range(num_imgs)]
|
||||||
|
(
|
||||||
|
labels_list,
|
||||||
|
label_weights_list,
|
||||||
|
mask_targets_list,
|
||||||
|
mask_weights_list,
|
||||||
|
num_total_pos,
|
||||||
|
num_total_neg,
|
||||||
|
) = self.get_targets(cls_scores_list, mask_preds_list, gt_labels_list, gt_masks_list, img_metas)
|
||||||
|
# shape (batch_size, num_queries)
|
||||||
|
labels = torch.stack(labels_list, dim=0)
|
||||||
|
# shape (batch_size, num_queries)
|
||||||
|
label_weights = torch.stack(label_weights_list, dim=0)
|
||||||
|
# shape (num_total_gts, h, w)
|
||||||
|
mask_targets = torch.cat(mask_targets_list, dim=0)
|
||||||
|
# shape (batch_size, num_queries)
|
||||||
|
mask_weights = torch.stack(mask_weights_list, dim=0)
|
||||||
|
|
||||||
|
# classfication loss
|
||||||
|
# shape (batch_size * num_queries, )
|
||||||
|
cls_scores = cls_scores.flatten(0, 1)
|
||||||
|
labels = labels.flatten(0, 1)
|
||||||
|
label_weights = label_weights.flatten(0, 1)
|
||||||
|
|
||||||
|
class_weight = cls_scores.new_tensor(self.class_weight)
|
||||||
|
loss_cls = self.loss_cls(cls_scores, labels, label_weights, avg_factor=class_weight[labels].sum())
|
||||||
|
|
||||||
|
num_total_masks = reduce_mean(cls_scores.new_tensor([num_total_pos]))
|
||||||
|
num_total_masks = max(num_total_masks, 1)
|
||||||
|
|
||||||
|
# extract positive ones
|
||||||
|
# shape (batch_size, num_queries, h, w) -> (num_total_gts, h, w)
|
||||||
|
mask_preds = mask_preds[mask_weights > 0]
|
||||||
|
|
||||||
|
if mask_targets.shape[0] == 0:
|
||||||
|
# zero match
|
||||||
|
loss_dice = mask_preds.sum()
|
||||||
|
loss_mask = mask_preds.sum()
|
||||||
|
return loss_cls, loss_mask, loss_dice
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
points_coords = get_uncertain_point_coords_with_randomness(
|
||||||
|
mask_preds.unsqueeze(1), None, self.num_points, self.oversample_ratio, self.importance_sample_ratio
|
||||||
|
)
|
||||||
|
# shape (num_total_gts, h, w) -> (num_total_gts, num_points)
|
||||||
|
mask_point_targets = point_sample(mask_targets.unsqueeze(1).float(), points_coords).squeeze(1)
|
||||||
|
# shape (num_queries, h, w) -> (num_queries, num_points)
|
||||||
|
mask_point_preds = point_sample(mask_preds.unsqueeze(1), points_coords).squeeze(1)
|
||||||
|
|
||||||
|
# dice loss
|
||||||
|
loss_dice = self.loss_dice(mask_point_preds, mask_point_targets, avg_factor=num_total_masks)
|
||||||
|
|
||||||
|
# mask loss
|
||||||
|
# shape (num_queries, num_points) -> (num_queries * num_points, )
|
||||||
|
mask_point_preds = mask_point_preds.reshape(-1, 1)
|
||||||
|
# shape (num_total_gts, num_points) -> (num_total_gts * num_points, )
|
||||||
|
mask_point_targets = mask_point_targets.reshape(-1)
|
||||||
|
loss_mask = self.loss_mask(mask_point_preds, mask_point_targets, avg_factor=num_total_masks * self.num_points)
|
||||||
|
|
||||||
|
return loss_cls, loss_mask, loss_dice

    @force_fp32(apply_to=("all_cls_scores", "all_mask_preds"))
    def loss(self, all_cls_scores, all_mask_preds, gt_labels_list, gt_masks_list, img_metas):
        """Loss function.

        Args:
            all_cls_scores (Tensor): Classification scores for all decoder
                layers with shape [num_decoder, batch_size, num_queries,
                cls_out_channels].
            all_mask_preds (Tensor): Mask scores for all decoder layers with
                shape [num_decoder, batch_size, num_queries, h, w].
            gt_labels_list (list[Tensor]): Ground truth class indices for each
                image with shape (n, ). n is the sum of the number of stuff
                types and the number of instances in an image.
            gt_masks_list (list[Tensor]): Ground truth masks for each image
                with shape (n, h, w).
            img_metas (list[dict]): List of image meta information.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        num_dec_layers = len(all_cls_scores)
        all_gt_labels_list = [gt_labels_list for _ in range(num_dec_layers)]
        all_gt_masks_list = [gt_masks_list for _ in range(num_dec_layers)]
        img_metas_list = [img_metas for _ in range(num_dec_layers)]
        losses_cls, losses_mask, losses_dice = multi_apply(
            self.loss_single, all_cls_scores, all_mask_preds, all_gt_labels_list, all_gt_masks_list, img_metas_list
        )

        loss_dict = dict()
        # loss from the last decoder layer
        loss_dict["loss_cls"] = losses_cls[-1]
        loss_dict["loss_mask"] = losses_mask[-1]
        loss_dict["loss_dice"] = losses_dice[-1]
        # loss from other decoder layers
        num_dec_layer = 0
        for loss_cls_i, loss_mask_i, loss_dice_i in zip(losses_cls[:-1], losses_mask[:-1], losses_dice[:-1]):
            loss_dict[f"d{num_dec_layer}.loss_cls"] = loss_cls_i
            loss_dict[f"d{num_dec_layer}.loss_mask"] = loss_mask_i
            loss_dict[f"d{num_dec_layer}.loss_dice"] = loss_dice_i
            num_dec_layer += 1
        return loss_dict
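
    # Sketch of the returned dictionary (assuming, hypothetically, 10 decoder
    # predictions, i.e. 9 auxiliary layers): the last layer keeps plain keys,
    # earlier layers are prefixed with their index, and the training framework
    # typically sums every entry whose key contains "loss":
    #
    #     {'loss_cls': ..., 'loss_mask': ..., 'loss_dice': ...,
    #      'd0.loss_cls': ..., 'd0.loss_mask': ..., 'd0.loss_dice': ...,
    #      ...,
    #      'd8.loss_cls': ..., 'd8.loss_mask': ..., 'd8.loss_dice': ...}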

    def forward_head(self, decoder_out, mask_feature, attn_mask_target_size):
        """Forward for the head part, which is called after every decoder layer.

        Args:
            decoder_out (Tensor): in shape (num_queries, batch_size, c).
            mask_feature (Tensor): in shape (batch_size, c, h, w).
            attn_mask_target_size (tuple[int, int]): target attention
                mask size.

        Returns:
            tuple: A tuple containing three elements.

            - cls_pred (Tensor): Classification scores in shape \
                (batch_size, num_queries, cls_out_channels). \
                Note `cls_out_channels` should include background.
            - mask_pred (Tensor): Mask scores in shape \
                (batch_size, num_queries, h, w).
            - attn_mask (Tensor): Attention mask in shape \
                (batch_size * num_heads, num_queries, h*w).
        """
        decoder_out = self.transformer_decoder.post_norm(decoder_out)
        decoder_out = decoder_out.transpose(0, 1)
        # shape (batch_size, num_queries, c)
        cls_pred = self.cls_embed(decoder_out)
        # shape (batch_size, num_queries, c)
        mask_embed = self.mask_embed(decoder_out)
        # shape (batch_size, num_queries, h, w)
        mask_pred = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_feature)
        attn_mask = F.interpolate(mask_pred, attn_mask_target_size, mode="bilinear", align_corners=False)
        # shape (batch_size, num_queries, h, w) ->
        # (batch_size * num_heads, num_queries, h*w)
        attn_mask = attn_mask.flatten(2).unsqueeze(1).repeat((1, self.num_heads, 1, 1)).flatten(0, 1)
        attn_mask = attn_mask.sigmoid() < 0.5
        attn_mask = attn_mask.detach()

        return cls_pred, mask_pred, attn_mask
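
    # The einsum above is a per-query dot product over channels; a hedged,
    # equivalent formulation without einsum (hypothetical sizes):
    #
    #     >>> import torch
    #     >>> b, q, c, h, w = 2, 100, 256, 32, 32
    #     >>> mask_embed, mask_feature = torch.randn(b, q, c), torch.randn(b, c, h, w)
    #     >>> ref = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_feature)
    #     >>> alt = (mask_embed @ mask_feature.flatten(2)).view(b, q, h, w)
    #     >>> torch.allclose(ref, alt, atol=1e-5)
    #     True
    #
    # `attn_mask` entries are True where a query's predicted mask probability is
    # below 0.5, i.e. True marks positions the cross-attention should ignore.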

    def forward(self, feats, img_metas):
        """Forward function.

        Args:
            feats (list[Tensor]): Multi-scale features from the
                upstream network, each is a 4D-tensor.
            img_metas (list[dict]): List of image information.

        Returns:
            tuple: A tuple containing two elements.

            - cls_pred_list (list[Tensor]): Classification logits \
                for each decoder layer. Each is a 3D-tensor with shape \
                (batch_size, num_queries, cls_out_channels). \
                Note `cls_out_channels` should include background.
            - mask_pred_list (list[Tensor]): Mask logits for each \
                decoder layer. Each with shape (batch_size, num_queries, \
                h, w).
        """
        batch_size = len(img_metas)
        mask_features, multi_scale_memorys = self.pixel_decoder(feats)
        # multi_scale_memorys (from low resolution to high resolution)
        decoder_inputs = []
        decoder_positional_encodings = []
        for i in range(self.num_transformer_feat_level):
            decoder_input = self.decoder_input_projs[i](multi_scale_memorys[i])
            # shape (batch_size, c, h, w) -> (h*w, batch_size, c)
            decoder_input = decoder_input.flatten(2).permute(2, 0, 1)
            level_embed = self.level_embed.weight[i].view(1, 1, -1)
            decoder_input = decoder_input + level_embed
            # padding mask of shape (batch_size, h, w); all False, i.e. no padding
            mask = decoder_input.new_zeros((batch_size,) + multi_scale_memorys[i].shape[-2:], dtype=torch.bool)
            decoder_positional_encoding = self.decoder_positional_encoding(mask)
            decoder_positional_encoding = decoder_positional_encoding.flatten(2).permute(2, 0, 1)
            decoder_inputs.append(decoder_input)
            decoder_positional_encodings.append(decoder_positional_encoding)
        # shape (num_queries, c) -> (num_queries, batch_size, c)
        query_feat = self.query_feat.weight.unsqueeze(1).repeat((1, batch_size, 1))
        query_embed = self.query_embed.weight.unsqueeze(1).repeat((1, batch_size, 1))

        cls_pred_list = []
        mask_pred_list = []
        cls_pred, mask_pred, attn_mask = self.forward_head(query_feat, mask_features, multi_scale_memorys[0].shape[-2:])
        cls_pred_list.append(cls_pred)
        mask_pred_list.append(mask_pred)

        for i in range(self.num_transformer_decoder_layers):
            level_idx = i % self.num_transformer_feat_level
            # if a mask is all True (all background), then set it all False.
            attn_mask[torch.where(attn_mask.sum(-1) == attn_mask.shape[-1])] = False

            # cross_attn + self_attn
            layer = self.transformer_decoder.layers[i]
            attn_masks = [attn_mask, None]
            query_feat = layer(
                query=query_feat,
                key=decoder_inputs[level_idx],
                value=decoder_inputs[level_idx],
                query_pos=query_embed,
                key_pos=decoder_positional_encodings[level_idx],
                attn_masks=attn_masks,
                query_key_padding_mask=None,
                # here we do not apply masking on padded region
                key_padding_mask=None,
            )
            cls_pred, mask_pred, attn_mask = self.forward_head(
                query_feat, mask_features, multi_scale_memorys[(i + 1) % self.num_transformer_feat_level].shape[-2:]
            )

            cls_pred_list.append(cls_pred)
            mask_pred_list.append(mask_pred)

        return cls_pred_list, mask_pred_list
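
    # Hedged usage sketch (`head`, `feats`, `img_metas` are hypothetical
    # placeholders): both returned lists hold num_transformer_decoder_layers + 1
    # entries -- one prediction from the initial queries, then one after each
    # decoder layer:
    #
    #     >>> cls_pred_list, mask_pred_list = head(feats, img_metas)          # doctest: +SKIP
    #     >>> len(cls_pred_list) == head.num_transformer_decoder_layers + 1   # doctest: +SKIP
    #     True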

    def forward_train(self, x, img_metas, gt_semantic_seg, gt_labels, gt_masks):
        """Forward function for training mode.

        Args:
            x (list[Tensor]): Multi-level features from the upstream network,
                each is a 4D-tensor.
            img_metas (list[dict]): List of image information.
            gt_semantic_seg (list[Tensor]): Each element is the ground truth
                of semantic segmentation with the shape (N, H, W).
            gt_labels (list[Tensor]): Each element is the ground truth labels
                of each box, shape (num_gts,).
            gt_masks (list[BitmapMasks]): Each element is the masks of the
                instances of an image, shape (num_gts, h, w).

        Returns:
            losses (dict[str, Tensor]): A dictionary of loss components.
        """

        # forward
        all_cls_scores, all_mask_preds = self(x, img_metas)

        # loss
        losses = self.loss(all_cls_scores, all_mask_preds, gt_labels, gt_masks, img_metas)

        return losses

    def forward_test(self, inputs, img_metas, test_cfg):
        """Test segmentation without test-time augmentation.

        Only the output of the last decoder layer is used.

        Args:
            inputs (list[Tensor]): Multi-level features from the
                upstream network, each is a 4D-tensor.
            img_metas (list[dict]): List of image information.
            test_cfg (dict): Testing config.

        Returns:
            seg_mask (Tensor): Predicted semantic segmentation logits.
        """
        all_cls_scores, all_mask_preds = self(inputs, img_metas)
        cls_score, mask_pred = all_cls_scores[-1], all_mask_preds[-1]
        ori_h, ori_w, _ = img_metas[0]["ori_shape"]

        # semantic inference
        cls_score = F.softmax(cls_score, dim=-1)[..., :-1]
        mask_pred = mask_pred.sigmoid()
        seg_mask = torch.einsum("bqc,bqhw->bchw", cls_score, mask_pred)
        return seg_mask
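
    # The semantic map is a probability-weighted sum over queries: the last
    # (background / "no object") class column is dropped, then each pixel's
    # class score is sum_q p(class | q) * p(pixel in mask_q). A small hedged
    # shape check with hypothetical sizes:
    #
    #     >>> import torch
    #     >>> cls_score = torch.rand(1, 100, 80)   # softmaxed, background dropped
    #     >>> mask_pred = torch.rand(1, 100, 64, 64)
    #     >>> torch.einsum("bqc,bqhw->bchw", cls_score, mask_pred).shape
    #     torch.Size([1, 80, 64, 64])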
@@ -0,0 +1,8 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from .cross_entropy_loss import CrossEntropyLoss, binary_cross_entropy, cross_entropy, mask_cross_entropy
from .dice_loss import DiceLoss
from .match_costs import ClassificationCost, CrossEntropyLossCost, DiceCost
@@ -0,0 +1,279 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmseg.models.builder import LOSSES
from mmseg.models.losses.utils import get_class_weight, weight_reduce_loss


def cross_entropy(
    pred,
    label,
    weight=None,
    class_weight=None,
    reduction="mean",
    avg_factor=None,
    ignore_index=-100,
    avg_non_ignore=False,
):
    """cross_entropy. The wrapper function for :func:`F.cross_entropy`

    Args:
        pred (torch.Tensor): The prediction with shape (N, C), where C is the
            number of classes.
        label (torch.Tensor): The learning label of the prediction.
        weight (torch.Tensor, optional): Sample-wise loss weight.
            Default: None.
        class_weight (list[float], optional): The weight for each class.
            Default: None.
        reduction (str, optional): The method used to reduce the loss.
            Options are 'none', 'mean' and 'sum'. Default: 'mean'.
        avg_factor (int, optional): Average factor that is used to average
            the loss. Default: None.
        ignore_index (int): Specifies a target value that is ignored and
            does not contribute to the input gradients. When
            ``avg_non_ignore`` is ``True`` and the ``reduction`` is
            ``'mean'``, the loss is averaged over non-ignored targets.
            Default: -100.
        avg_non_ignore (bool): Whether the loss is only averaged over
            non-ignored targets. Default: False.
            `New in version 0.23.0.`
    """

    # class_weight is a manual rescaling weight given to each class.
    # If given, it has to be a Tensor of size C.
    # Compute element-wise losses first.
    loss = F.cross_entropy(pred, label, weight=class_weight, reduction="none", ignore_index=ignore_index)

    # apply weights and do the reduction
    # average loss over non-ignored elements
    # pytorch's official cross_entropy averages the loss over non-ignored elements
    # refer to https://github.com/pytorch/pytorch/blob/56b43f4fec1f76953f15a627694d4bba34588969/torch/nn/functional.py#L2660 # noqa
    if (avg_factor is None) and avg_non_ignore and reduction == "mean":
        avg_factor = label.numel() - (label == ignore_index).sum().item()
    if weight is not None:
        weight = weight.float()
    loss = weight_reduce_loss(loss, weight=weight, reduction=reduction, avg_factor=avg_factor)

    return loss
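
# A hedged numeric sketch of ``avg_non_ignore``: with one ignored target out of
# three, the mean is taken over 2 elements instead of 3:
#
#     >>> import torch
#     >>> pred = torch.tensor([[2.0, 0.0], [2.0, 0.0], [2.0, 0.0]])
#     >>> label = torch.tensor([0, 1, -100])
#     >>> cross_entropy(pred, label, avg_non_ignore=True)    # sum / 2
#     tensor(1.1269)
#     >>> cross_entropy(pred, label, avg_non_ignore=False)   # sum / 3
#     tensor(0.7513)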


def _expand_onehot_labels(labels, label_weights, target_shape, ignore_index):
    """Expand onehot labels to match the size of prediction."""
    bin_labels = labels.new_zeros(target_shape)
    valid_mask = (labels >= 0) & (labels != ignore_index)
    inds = torch.nonzero(valid_mask, as_tuple=True)

    if inds[0].numel() > 0:
        if labels.dim() == 3:
            bin_labels[inds[0], labels[valid_mask], inds[1], inds[2]] = 1
        else:
            bin_labels[inds[0], labels[valid_mask]] = 1

    valid_mask = valid_mask.unsqueeze(1).expand(target_shape).float()

    if label_weights is None:
        bin_label_weights = valid_mask
    else:
        bin_label_weights = label_weights.unsqueeze(1).expand(target_shape)
        bin_label_weights = bin_label_weights * valid_mask

    return bin_labels, bin_label_weights, valid_mask
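
# Hedged shape sketch: a (N,) class-index label expands to an (N, C) one-hot
# target, with ignored entries zeroed in both labels and weights:
#
#     >>> import torch
#     >>> labels = torch.tensor([0, 2, -100])
#     >>> bin_labels, bin_w, valid = _expand_onehot_labels(labels, None, (3, 4), -100)
#     >>> bin_labels
#     tensor([[1, 0, 0, 0],
#             [0, 0, 1, 0],
#             [0, 0, 0, 0]])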


def binary_cross_entropy(
    pred,
    label,
    weight=None,
    reduction="mean",
    avg_factor=None,
    class_weight=None,
    ignore_index=-100,
    avg_non_ignore=False,
    **kwargs,
):
    """Calculate the binary CrossEntropy loss.

    Args:
        pred (torch.Tensor): The prediction with shape (N, 1).
        label (torch.Tensor): The learning label of the prediction.
            Note: In bce loss, label < 0 is invalid.
        weight (torch.Tensor, optional): Sample-wise loss weight.
        reduction (str, optional): The method used to reduce the loss.
            Options are "none", "mean" and "sum".
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
        class_weight (list[float], optional): The weight for each class.
        ignore_index (int): The label index to be ignored. Default: -100.
        avg_non_ignore (bool): Whether the loss is only averaged over
            non-ignored targets. Default: False.
            `New in version 0.23.0.`

    Returns:
        torch.Tensor: The calculated loss
    """
    if pred.size(1) == 1:
        # For binary class segmentation, the shape of pred is
        # [N, 1, H, W] and that of label is [N, H, W].
        assert label.max() <= 1, "For pred with shape [N, 1, H, W], its label must have at most 2 classes"
        # squeeze only the channel dim, so a batch dim of 1 is preserved
        pred = pred.squeeze(1)
    if pred.dim() != label.dim():
        assert (pred.dim() == 2 and label.dim() == 1) or (pred.dim() == 4 and label.dim() == 3), (
            "Only pred shape [N, C], label shape [N] or pred shape [N, C, H, W], label shape [N, H, W] are supported"
        )
        # `weight` returned from `_expand_onehot_labels`
        # has been treated for valid (non-ignore) pixels
        label, weight, valid_mask = _expand_onehot_labels(label, weight, pred.shape, ignore_index)
    else:
        # should mask out the ignored elements
        valid_mask = ((label >= 0) & (label != ignore_index)).float()
        if weight is not None:
            weight = weight * valid_mask
        else:
            weight = valid_mask
    # average loss over non-ignored and valid elements
    if reduction == "mean" and avg_factor is None and avg_non_ignore:
        avg_factor = valid_mask.sum().item()

    loss = F.binary_cross_entropy_with_logits(pred, label.float(), pos_weight=class_weight, reduction="none")
    # do the reduction for the weighted loss
    loss = weight_reduce_loss(loss, weight, reduction=reduction, avg_factor=avg_factor)

    return loss
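
# Hedged usage sketch: for pred shape [N, C] with an [N] index label, the label
# is expanded to one-hot internally, so each class channel gets its own BCE
# term; entries with the ignore index contribute nothing:
#
#     >>> import torch
#     >>> pred = torch.randn(4, 3)             # logits for 3 binary channels
#     >>> label = torch.tensor([0, 1, 2, -100])
#     >>> binary_cross_entropy(pred, label, avg_non_ignore=True).shape
#     torch.Size([])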


def mask_cross_entropy(
    pred, target, label, reduction="mean", avg_factor=None, class_weight=None, ignore_index=None, **kwargs
):
    """Calculate the CrossEntropy loss for masks.

    Args:
        pred (torch.Tensor): The prediction with shape (N, C), C is the number
            of classes.
        target (torch.Tensor): The learning label of the prediction.
        label (torch.Tensor): ``label`` indicates the class label of the mask's
            corresponding object. This will be used to select the mask channel
            of the class which the object belongs to when the mask prediction
            is not class-agnostic.
        reduction (str, optional): The method used to reduce the loss.
            Options are "none", "mean" and "sum".
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
        class_weight (list[float], optional): The weight for each class.
        ignore_index (None): Placeholder, to be consistent with other losses.
            Default: None.

    Returns:
        torch.Tensor: The calculated loss
    """
    assert ignore_index is None, "BCE loss does not support ignore_index"
    assert reduction == "mean" and avg_factor is None
    num_rois = pred.size()[0]
    inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device)
    pred_slice = pred[inds, label].squeeze(1)
    return F.binary_cross_entropy_with_logits(pred_slice, target, weight=class_weight, reduction="mean")[None]
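
# Hedged sketch of the channel selection: each ROI i keeps only the mask
# channel of its class label[i] before the binary cross entropy is applied:
#
#     >>> import torch
#     >>> pred = torch.randn(2, 4, 7, 7)   # (num_rois, num_classes, h, w)
#     >>> label = torch.tensor([1, 3])
#     >>> inds = torch.arange(2)
#     >>> pred[inds, label].shape          # one channel per ROI
#     torch.Size([2, 7, 7])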


@LOSSES.register_module(force=True)
class CrossEntropyLoss(nn.Module):
    """CrossEntropyLoss.

    Args:
        use_sigmoid (bool, optional): Whether the prediction uses sigmoid
            or softmax. Defaults to False.
        use_mask (bool, optional): Whether to use mask cross entropy loss.
            Defaults to False.
        reduction (str, optional): The method used to reduce the loss.
            Options are "none", "mean" and "sum". Defaults to 'mean'.
        class_weight (list[float] | str, optional): Weight of each class. If in
            str format, read them from a file. Defaults to None.
        loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
        loss_name (str, optional): Name of the loss item. If you want this loss
            item to be included into the backward graph, `loss_` must be the
            prefix of the name. Defaults to 'loss_ce'.
        avg_non_ignore (bool): Whether the loss is only averaged over
            non-ignored targets. Default: False.
            `New in version 0.23.0.`
    """

    def __init__(
        self,
        use_sigmoid=False,
        use_mask=False,
        reduction="mean",
        class_weight=None,
        loss_weight=1.0,
        loss_name="loss_ce",
        avg_non_ignore=False,
    ):
        super(CrossEntropyLoss, self).__init__()
        assert (use_sigmoid is False) or (use_mask is False)
        self.use_sigmoid = use_sigmoid
        self.use_mask = use_mask
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.class_weight = get_class_weight(class_weight)
        self.avg_non_ignore = avg_non_ignore
        if not self.avg_non_ignore and self.reduction == "mean":
            warnings.warn(
                "Default ``avg_non_ignore`` is False. If you would like to "
                "ignore certain labels and average the loss over non-ignored "
                "labels, which matches PyTorch's official cross_entropy, "
                "set ``avg_non_ignore=True``."
            )

        if self.use_sigmoid:
            self.cls_criterion = binary_cross_entropy
        elif self.use_mask:
            self.cls_criterion = mask_cross_entropy
        else:
            self.cls_criterion = cross_entropy
        self._loss_name = loss_name

    def extra_repr(self):
        """Extra repr."""
        s = f"avg_non_ignore={self.avg_non_ignore}"
        return s

    def forward(
        self, cls_score, label, weight=None, avg_factor=None, reduction_override=None, ignore_index=-100, **kwargs
    ):
        """Forward function."""
        assert reduction_override in (None, "none", "mean", "sum")
        reduction = reduction_override if reduction_override else self.reduction
        if self.class_weight is not None:
            class_weight = cls_score.new_tensor(self.class_weight)
        else:
            class_weight = None
        # Note: for BCE loss, label < 0 is invalid.
        loss_cls = self.loss_weight * self.cls_criterion(
            cls_score,
            label,
            weight,
            class_weight=class_weight,
            reduction=reduction,
            avg_factor=avg_factor,
            avg_non_ignore=self.avg_non_ignore,
            ignore_index=ignore_index,
            **kwargs,
        )
        return loss_cls

    @property
    def loss_name(self):
        """Loss Name.

        This function must be implemented and will return the name of this
        loss function. This name will be used to combine different loss items
        by simple sum operation. In addition, if you want this loss item to be
        included into the backward graph, `loss_` must be the prefix of the
        name.

        Returns:
            str: The name of this loss item.
        """
        return self._loss_name
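
# Hedged usage sketch: the registered module can be built directly (or via an
# mmseg config dict); `use_sigmoid=True` routes to binary_cross_entropy, while
# the default routes to the plain cross_entropy wrapper:
#
#     >>> import torch
#     >>> criterion = CrossEntropyLoss(loss_weight=2.0, avg_non_ignore=True)
#     >>> cls_score = torch.randn(8, 5)
#     >>> label = torch.randint(0, 5, (8,))
#     >>> loss = criterion(cls_score, label)   # scalar, scaled by loss_weight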
@@ -0,0 +1,153 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
from mmseg.models.builder import LOSSES
from mmseg.models.losses.utils import weight_reduce_loss


def dice_loss(pred, target, weight=None, eps=1e-3, reduction="mean", avg_factor=None):
    """Calculate dice loss, which is proposed in
    `V-Net: Fully Convolutional Neural Networks for Volumetric
    Medical Image Segmentation <https://arxiv.org/abs/1606.04797>`_.

    Args:
        pred (torch.Tensor): The prediction, has a shape (n, *).
        target (torch.Tensor): The learning label of the prediction,
            shape (n, *), same shape as pred.
        weight (torch.Tensor, optional): The weight of loss for each
            prediction, has a shape (n,). Defaults to None.
        eps (float): Avoid dividing by zero. Default: 1e-3.
        reduction (str, optional): The method used to reduce the loss into
            a scalar. Defaults to 'mean'.
            Options are "none", "mean" and "sum".
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
    """

    input = pred.flatten(1)
    target = target.flatten(1).float()

    a = torch.sum(input * target, 1)
    b = torch.sum(input * input, 1) + eps
    c = torch.sum(target * target, 1) + eps
    d = (2 * a) / (b + c)
    loss = 1 - d
    if weight is not None:
        assert weight.ndim == loss.ndim
        assert len(weight) == len(pred)
    loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
    return loss
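
# Hedged worked example of the V-Net form: for a flattened prediction p and
# target t, loss = 1 - 2*sum(p*t) / (sum(p*p) + sum(t*t) + 2*eps):
#
#     >>> import torch
#     >>> pred = torch.tensor([[1.0, 0.0, 1.0, 0.0]])
#     >>> target = torch.tensor([[1.0, 1.0, 0.0, 0.0]])
#     >>> dice_loss(pred, target, eps=0)   # 1 - 2*1 / (2 + 2)
#     tensor(0.5000)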


def naive_dice_loss(pred, target, weight=None, eps=1e-3, reduction="mean", avg_factor=None):
    """Calculate the naive dice loss, in which the coefficients in the
    denominator are raised to the first power instead of the second power.

    Args:
        pred (torch.Tensor): The prediction, has a shape (n, *).
        target (torch.Tensor): The learning label of the prediction,
            shape (n, *), same shape as pred.
        weight (torch.Tensor, optional): The weight of loss for each
            prediction, has a shape (n,). Defaults to None.
        eps (float): Avoid dividing by zero. Default: 1e-3.
        reduction (str, optional): The method used to reduce the loss into
            a scalar. Defaults to 'mean'.
            Options are "none", "mean" and "sum".
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
    """
    input = pred.flatten(1)
    target = target.flatten(1).float()

    a = torch.sum(input * target, 1)
    b = torch.sum(input, 1)
    c = torch.sum(target, 1)
    d = (2 * a + eps) / (b + c + eps)
    loss = 1 - d
    if weight is not None:
        assert weight.ndim == loss.ndim
        assert len(weight) == len(pred)
    loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
    return loss
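
# Hedged comparison with the V-Net form: the naive variant uses first powers
# in the denominator, so the two agree for hard binary inputs but differ for
# soft predictions:
#
#     >>> import torch
#     >>> pred = torch.tensor([[0.5, 0.5]])
#     >>> target = torch.tensor([[1.0, 0.0]])
#     >>> naive_dice_loss(pred, target, eps=0)   # 1 - 2*0.5 / (1 + 1)
#     tensor(0.5000)
#     >>> dice_loss(pred, target, eps=0)         # 1 - 2*0.5 / (0.5 + 1)
#     tensor(0.3333)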


@LOSSES.register_module(force=True)
class DiceLoss(nn.Module):
    def __init__(self, use_sigmoid=True, activate=True, reduction="mean", naive_dice=False, loss_weight=1.0, eps=1e-3):
        """Dice Loss. Two forms of dice loss are supported:

        - the one proposed in `V-Net: Fully Convolutional Neural
          Networks for Volumetric Medical Image Segmentation
          <https://arxiv.org/abs/1606.04797>`_.
        - the dice loss in which the power of the number in the
          denominator is the first power instead of the second
          power.

        Args:
            use_sigmoid (bool, optional): Whether the prediction uses sigmoid
                or softmax. Defaults to True.
            activate (bool): Whether to activate the predictions inside the
                loss, i.e. apply sigmoid to the inputs. If False, the inside
                sigmoid operation is skipped and the inputs are assumed to be
                probabilities already. Defaults to True.
            reduction (str, optional): The method used
                to reduce the loss. Options are "none",
                "mean" and "sum". Defaults to 'mean'.
            naive_dice (bool, optional): If False, use the dice
                loss defined in the V-Net paper; otherwise, use the
                naive dice loss in which the power of the number in the
                denominator is the first power instead of the second
                power. Defaults to False.
            loss_weight (float, optional): Weight of loss. Defaults to 1.0.
            eps (float): Avoid dividing by zero. Defaults to 1e-3.
        """

        super(DiceLoss, self).__init__()
        self.use_sigmoid = use_sigmoid
        self.reduction = reduction
        self.naive_dice = naive_dice
        self.loss_weight = loss_weight
        self.eps = eps
        self.activate = activate

    def forward(self, pred, target, weight=None, reduction_override=None, avg_factor=None):
        """Forward function.

        Args:
            pred (torch.Tensor): The prediction, has a shape (n, *).
            target (torch.Tensor): The label of the prediction,
                shape (n, *), same shape as pred.
            weight (torch.Tensor, optional): The weight of loss for each
                prediction, has a shape (n,). Defaults to None.
            avg_factor (int, optional): Average factor that is used to average
                the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used to
                override the original reduction method of the loss.
                Options are "none", "mean" and "sum".

        Returns:
            torch.Tensor: The calculated loss
        """

        assert reduction_override in (None, "none", "mean", "sum")
        reduction = reduction_override if reduction_override else self.reduction

        if self.activate:
            if self.use_sigmoid:
                pred = pred.sigmoid()
            else:
                raise NotImplementedError

        if self.naive_dice:
            loss = self.loss_weight * naive_dice_loss(
                pred, target, weight, eps=self.eps, reduction=reduction, avg_factor=avg_factor
            )
        else:
            loss = self.loss_weight * dice_loss(
                pred, target, weight, eps=self.eps, reduction=reduction, avg_factor=avg_factor
            )

        return loss
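
# Hedged usage sketch: with activate=True and use_sigmoid=True the module
# applies sigmoid to the logits itself before computing the dice term:
#
#     >>> import torch
#     >>> criterion = DiceLoss(use_sigmoid=True, activate=True, naive_dice=True)
#     >>> mask_logits = torch.randn(6, 28 * 28)   # (n, *) flattened masks
#     >>> mask_targets = (torch.rand(6, 28 * 28) > 0.5).float()
#     >>> loss = criterion(mask_logits, mask_targets)   # scalar mean over n masks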
Some files were not shown because too many files have changed in this diff.