-
Notifications
You must be signed in to change notification settings - Fork 19
/
Copy pathtrain.py
354 lines (319 loc) · 15.8 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
"""
Training code for LightningDiT together with VA-VAE.
It involves advanced training methods, sampling methods,
architecture design methods, and computation methods. We achieve
state-of-the-art FID 1.35 on ImageNet 256x256.
by Maple (Jingfeng Yao) from HUST-VL
"""
import torch
import torch.distributed as dist
import torch.backends.cuda
import torch.backends.cudnn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import math
import yaml
import json
import numpy as np
import logging
import os
import argparse
from time import time
from glob import glob
from copy import deepcopy
from collections import OrderedDict
from PIL import Image
from tqdm import tqdm
from diffusers.models import AutoencoderKL
from models.lightningdit import LightningDiT_models
from transport import create_transport, Sampler
from accelerate import Accelerator
from datasets.img_latent_dataset import ImgLatentDataset
def do_train(train_config, accelerator):
    """
    Train a LightningDiT on pre-computed VAE latents.

    Builds the model, its EMA copy, the transport (loss) object, the AdamW
    optimizer and the latent data loaders from ``train_config``. Optionally
    initializes weights from ``train.weight_init`` or resumes from the
    highest-step checkpoint in the experiment's checkpoint directory, then
    trains until ``train.max_steps`` optimizer steps have been taken.

    Args:
        train_config: Parsed YAML config with 'train', 'model', 'vae', 'data',
            'transport' and 'optimizer' sections.
            ``train_config['train']['resume']`` is filled in (default False)
            as a side effect.
        accelerator: An ``accelerate.Accelerator``. The body also calls
            ``torch.distributed`` collectives (``all_reduce``, ``barrier``)
            directly, so an initialized process group is assumed.

    Returns:
        The ``accelerator`` that was passed in.

    Raises:
        ValueError: if ``data.image_size`` is not divisible by the VAE
            downsample ratio.
    """
    device = accelerator.device
    train_cfg = train_config['train']
    model_cfg = train_config['model']
    data_cfg = train_config['data']
    transport_cfg = train_config['transport']
    optim_cfg = train_config['optimizer']

    # --- Experiment paths -------------------------------------------------
    # Resolve the experiment name on EVERY rank before any rank creates
    # directories, so the auto-generated index (a count of the existing
    # entries in output_dir) is identical everywhere. Previously only the main
    # process computed it and checkpoint_dir was then overwritten with
    # ".../None/checkpoints" when exp_name was unset.
    exp_name = train_cfg.get('exp_name')
    if exp_name is None:
        experiment_index = len(glob(f"{train_cfg['output_dir']}/*"))
        model_string_name = model_cfg['model_type'].replace("/", "-")
        exp_name = f'{experiment_index:03d}-{model_string_name}'
    accelerator.wait_for_everyone()  # all ranks have globbed before anything is created
    experiment_dir = f"{train_cfg['output_dir']}/{exp_name}"
    checkpoint_dir = f"{experiment_dir}/checkpoints"  # Stores saved model checkpoints

    # `logger` and `writer` exist only on the main process; every use below is
    # guarded by `accelerator.is_main_process`.
    if accelerator.is_main_process:
        os.makedirs(checkpoint_dir, exist_ok=True)  # also creates output_dir and experiment_dir
        logger = create_logger(experiment_dir)
        logger.info(f"Experiment directory created at {experiment_dir}")
        tensorboard_dir_log = f"tensorboard_logs/{exp_name}"
        os.makedirs(tensorboard_dir_log, exist_ok=True)
        writer = SummaryWriter(log_dir=tensorboard_dir_log)
        # add configs to tensorboard
        writer.add_text('training configs', json.dumps(train_config, indent=4), global_step=0)

    rank = accelerator.local_process_index

    # --- Model ------------------------------------------------------------
    downsample_ratio = train_config['vae'].get('downsample_ratio', 16)
    image_size = data_cfg['image_size']
    if image_size % downsample_ratio != 0:
        raise ValueError(
            f"Image size must be divisible by {downsample_ratio} (the VAE downsample ratio)."
        )
    latent_size = image_size // downsample_ratio
    model = LightningDiT_models[model_cfg['model_type']](
        input_size=latent_size,
        num_classes=data_cfg['num_classes'],
        use_qknorm=model_cfg['use_qknorm'],
        use_swiglu=model_cfg.get('use_swiglu', False),
        use_rope=model_cfg.get('use_rope', False),
        use_rmsnorm=model_cfg.get('use_rmsnorm', False),
        wo_shift=model_cfg.get('wo_shift', False),
        in_channels=model_cfg.get('in_chans', 4),
        use_checkpoint=model_cfg.get('use_checkpoint', False),
    )
    ema = deepcopy(model).to(device)  # Create an EMA of the model for use after training

    # Optional pretrained initialization (applied to both model and EMA).
    if 'weight_init' in train_cfg:
        checkpoint = torch.load(train_cfg['weight_init'], map_location=lambda storage, loc: storage)
        # remove the prefix 'module.' left by (D)DP wrappers from the keys
        checkpoint['model'] = {k.replace('module.', ''): v for k, v in checkpoint['model'].items()}
        model = load_weights_with_shape_check(model, checkpoint, rank=rank)
        ema = load_weights_with_shape_check(ema, checkpoint, rank=rank)
        if accelerator.is_main_process:
            logger.info(f"Loaded pretrained model from {train_cfg['weight_init']}")
    requires_grad(ema, False)
    # NOTE(review): `model` is already a DDP wrapper when it is later passed to
    # `accelerator.prepare`; depending on the accelerate version this may wrap
    # it a second time. The `model.module` accesses below and the 'module.'
    # prefix handling in update_ema / weight loading depend on that layering.
    model = DDP(model.to(device), device_ids=[rank])

    # --- Transport (loss) and optimizer -------------------------------------
    use_cosine_loss = transport_cfg.get('use_cosine_loss', False)
    use_lognorm = transport_cfg.get('use_lognorm', False)
    transport = create_transport(
        transport_cfg['path_type'],
        transport_cfg['prediction'],
        transport_cfg['loss_weight'],
        transport_cfg['train_eps'],
        transport_cfg['sample_eps'],
        use_cosine_loss=use_cosine_loss,
        use_lognorm=use_lognorm,
    )  # default: velocity;
    if accelerator.is_main_process:
        logger.info(f"LightningDiT Parameters: {sum(p.numel() for p in model.parameters()) / 1e6:.2f}M")
        logger.info(f"Optimizer: AdamW, lr={optim_cfg['lr']}, beta2={optim_cfg['beta2']}")
        # These keys are optional, so log the resolved values rather than
        # indexing the config (which raised KeyError when they were absent).
        logger.info(f'Use lognorm sampling: {use_lognorm}')
        logger.info(f'Use cosine loss: {use_cosine_loss}')
    opt = torch.optim.AdamW(model.parameters(), lr=optim_cfg['lr'], weight_decay=0,
                            betas=(0.9, optim_cfg['beta2']))

    # --- Data -------------------------------------------------------------
    latent_norm = data_cfg.get('latent_norm', False)
    latent_multiplier = data_cfg.get('latent_multiplier', 0.18215)
    dataset = ImgLatentDataset(
        data_dir=data_cfg['data_path'],
        latent_norm=latent_norm,
        latent_multiplier=latent_multiplier,
    )
    batch_size_per_gpu = int(np.round(train_cfg['global_batch_size'] / accelerator.num_processes))
    global_batch_size = batch_size_per_gpu * accelerator.num_processes
    loader = DataLoader(
        dataset,
        batch_size=batch_size_per_gpu,
        shuffle=True,
        num_workers=data_cfg['num_workers'],
        pin_memory=True,
        drop_last=True,
    )
    if accelerator.is_main_process:
        logger.info(f"Dataset contains {len(dataset):,} images {data_cfg['data_path']}")
        logger.info(f"Batch size {batch_size_per_gpu} per gpu, with {global_batch_size} global batch size")

    valid_loader = None
    if 'valid_path' in data_cfg:
        valid_dataset = ImgLatentDataset(
            data_dir=data_cfg['valid_path'],
            latent_norm=latent_norm,
            latent_multiplier=latent_multiplier,
        )
        valid_loader = DataLoader(
            valid_dataset,
            batch_size=batch_size_per_gpu,
            shuffle=True,
            num_workers=data_cfg['num_workers'],
            pin_memory=True,
            drop_last=True,
        )
        if accelerator.is_main_process:
            logger.info(f"Validation Dataset contains {len(valid_dataset):,} images {data_cfg['valid_path']}")

    # --- Prepare models for training --------------------------------------
    update_ema(ema, model.module, decay=0)  # Ensure EMA is initialized with synced weights
    model.train()  # important! This enables embedding dropout for classifier-free guidance
    ema.eval()  # EMA model should always be in eval mode

    # --- Resume -------------------------------------------------------------
    train_cfg['resume'] = train_cfg.get('resume', False)
    # Always bound: previously left undefined (NameError in the loop) when
    # resume was requested but no checkpoint existed.
    train_steps = 0
    if train_cfg['resume']:
        # Checkpoints are saved as "<step:07d>.pt"; choose the highest step.
        # (Sorting by file size, as before, is arbitrary because every
        # checkpoint has essentially the same size.)
        step_by_path = {}
        for path in glob(f"{checkpoint_dir}/*.pt"):
            stem = os.path.splitext(os.path.basename(path))[0]
            if stem.isdigit():
                step_by_path[path] = int(stem)
        if step_by_path:
            latest_checkpoint = max(step_by_path, key=step_by_path.get)
            checkpoint = torch.load(latest_checkpoint, map_location=lambda storage, loc: storage)
            model.load_state_dict(checkpoint['model'])
            # NOTE: the optimizer state is saved in checkpoints but not
            # restored here (the original load call was commented out), so
            # AdamW moment estimates restart from zero on resume.
            ema.load_state_dict(checkpoint['ema'])
            train_steps = step_by_path[latest_checkpoint]
            if accelerator.is_main_process:
                logger.info(f"Resuming training from checkpoint: {latest_checkpoint}")
        elif accelerator.is_main_process:
            logger.info("No checkpoint found. Starting training from scratch.")

    model, opt, loader = accelerator.prepare(model, opt, loader)

    # --- Training loop ------------------------------------------------------
    # Variables for monitoring/logging purposes:
    log_steps = 0
    running_loss = 0
    start_time = time()
    max_steps = train_cfg['max_steps']
    log_every = train_cfg['log_every']
    ckpt_every = train_cfg['ckpt_every']
    max_grad_norm = optim_cfg.get('max_grad_norm')

    use_checkpoint = train_cfg.get('use_checkpoint', True)
    if accelerator.is_main_process:
        logger.info(f"Using checkpointing: {use_checkpoint}")

    while train_steps < max_steps:
        for x, y in loader:
            if accelerator.mixed_precision == 'no':
                x = x.to(device, dtype=torch.float32)
            else:
                x = x.to(device)
            y = y.to(device)
            loss_dict = transport.training_losses(model, x, dict(y=y))
            mse_loss = loss_dict["loss"].mean()
            if 'cos_loss' in loss_dict:
                loss = loss_dict["cos_loss"].mean() + mse_loss
            else:
                loss = mse_loss

            opt.zero_grad()
            accelerator.backward(loss)
            if max_grad_norm is not None and accelerator.sync_gradients:
                accelerator.clip_grad_norm_(model.parameters(), max_grad_norm)
            opt.step()
            update_ema(ema, model.module)

            # Only the "loss" term is tracked (the cosine term is excluded).
            running_loss += mse_loss.item()
            log_steps += 1
            train_steps += 1

            if train_steps % log_every == 0:
                # Measure training speed:
                torch.cuda.synchronize()
                end_time = time()
                steps_per_sec = log_steps / (end_time - start_time)
                # Reduce loss history over all processes:
                avg_loss = torch.tensor(running_loss / log_steps, device=device)
                dist.all_reduce(avg_loss, op=dist.ReduceOp.SUM)
                avg_loss = avg_loss.item() / dist.get_world_size()
                if accelerator.is_main_process:
                    logger.info(f"(step={train_steps:07d}) Train Loss: {avg_loss:.4f}, Train Steps/Sec: {steps_per_sec:.2f}")
                    writer.add_scalar('Loss/train', avg_loss, train_steps)
                # Reset monitoring variables:
                running_loss = 0
                log_steps = 0
                start_time = time()

            # Save checkpoint:
            if train_steps % ckpt_every == 0 and train_steps > 0:
                if accelerator.is_main_process:
                    checkpoint = {
                        "model": model.module.state_dict(),
                        "ema": ema.state_dict(),
                        "opt": opt.state_dict(),
                        "config": train_config,
                    }
                    checkpoint_path = f"{checkpoint_dir}/{train_steps:07d}.pt"
                    torch.save(checkpoint, checkpoint_path)
                    logger.info(f"Saved checkpoint to {checkpoint_path}")
                dist.barrier()

                # Evaluate on validation set
                if valid_loader is not None:
                    if accelerator.is_main_process:
                        logger.info(f"Start evaluating at step {train_steps}")
                    # NOTE(review): `evaluate` is neither defined nor imported
                    # in this file, so this call raises NameError when
                    # data.valid_path is configured; it must be provided.
                    val_loss = evaluate(model, valid_loader, device, transport, (0.0, 1.0))
                    dist.all_reduce(val_loss, op=dist.ReduceOp.SUM)
                    val_loss = val_loss.item() / dist.get_world_size()
                    if accelerator.is_main_process:
                        logger.info(f"Validation Loss: {val_loss:.4f}")
                        writer.add_scalar('Loss/validation', val_loss, train_steps)
                    model.train()

            if train_steps >= max_steps:
                break

    if accelerator.is_main_process:
        logger.info("Done!")
        writer.close()  # flush pending TensorBoard events
    return accelerator
def load_weights_with_shape_check(model, checkpoint, rank=0, max_embed_channels=16):
    """
    Copy weights from ``checkpoint['model']`` into ``model`` wherever shapes agree.

    Entries whose names are missing from the model are skipped, as are
    entries with a different shape. The one exception is
    ``x_embedder.proj.weight``: if only its input-channel dimension (dim 1)
    differs, the leading input channels shared by checkpoint and model are
    copied (at most ``max_embed_channels``). The remaining channels of the
    model's tensor are zero-filled. Previously 16 channels were hard-coded,
    which crashed whenever either side had fewer than 16 input channels.

    Args:
        model: ``torch.nn.Module`` to load into; modified in place.
        checkpoint: dict whose ``'model'`` entry maps parameter names to tensors.
        rank: only rank 0 prints skip messages.
        max_embed_channels: upper bound on the number of patch-embedding input
            channels copied in the partial-load case (default 16, the
            previous hard-coded value).

    Returns:
        The same ``model`` object.
    """
    model_state_dict = model.state_dict()
    for name, param in checkpoint['model'].items():
        if name not in model_state_dict:
            if rank == 0:
                print(f"Parameter '{name}' not found in model, skipping.")
            continue

        target = model_state_dict[name]
        if param.shape == target.shape:
            target.copy_(param)
        elif (
            name == 'x_embedder.proj.weight'
            and param.dim() == target.dim() >= 2
            and param.shape[0] == target.shape[0]
            and param.shape[2:] == target.shape[2:]
        ):
            # Patch-embedding weight with a different number of input
            # channels: keep the overlapping leading channels, zero the rest.
            n_shared = min(max_embed_channels, param.shape[1], target.shape[1])
            weight = torch.zeros_like(target)
            weight[:, :n_shared] = param[:, :n_shared]
            model_state_dict[name] = weight
        elif rank == 0:
            print(f"Skipping loading parameter '{name}' due to shape mismatch: "
                  f"checkpoint shape {param.shape}, model shape {target.shape}")

    # load state dict
    model.load_state_dict(model_state_dict, strict=False)
    return model
@torch.no_grad()
def update_ema(ema_model, model, decay=0.9999):
    """
    Step the EMA model towards the current model, in place.

    For every parameter: ``ema = decay * ema + (1 - decay) * param``.
    ``decay=0`` therefore copies ``model``'s weights into ``ema_model``.

    Parameter names of ``model`` may carry one or more leading ``"module."``
    prefixes (from DDP-style wrappers). Those are stripped before matching
    against ``ema_model``'s names. Only *leading* prefixes are removed; the
    previous ``str.replace`` also mangled names that merely contain
    ``"module."`` (e.g. ``"submodule.weight"`` -> ``"subweight"``).
    """
    prefix = "module."
    ema_params = dict(ema_model.named_parameters())
    for name, param in model.named_parameters():
        while name.startswith(prefix):
            name = name[len(prefix):]
        # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed
        ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay)
def requires_grad(model, flag=True):
    """
    Set ``requires_grad`` to ``flag`` on every parameter of ``model``.

    Calling it with ``flag=False`` freezes the module in place;
    ``flag=True`` unfreezes it.
    """
    params = model.parameters()
    for tensor in params:
        setattr(tensor, "requires_grad", flag)
def load_config(config_path):
    """
    Load a YAML configuration file and return the parsed object.

    The file is decoded as UTF-8 explicitly. Previously the platform's locale
    encoding was used, which can fail on configs with non-ASCII content.
    ``yaml.safe_load`` is used, so only plain YAML types are constructed.

    Args:
        config_path: path to the YAML file.

    Returns:
        The parsed YAML document (a dict for the configs used by do_train).
    """
    with open(config_path, "r", encoding="utf-8") as file:
        return yaml.safe_load(file)
def create_logger(logging_dir):
    """
    Create a logger that writes to ``{logging_dir}/log.txt`` and the console.

    On rank 0 the root logger is configured via ``logging.basicConfig``. It
    gets a ``StreamHandler``, which writes to stderr, and a ``FileHandler``.
    Note that ``basicConfig`` does nothing if the root logger already has
    handlers. On other ranks a logger with a ``NullHandler`` is returned; the
    handler is added only once even if this is called repeatedly.

    When ``torch.distributed`` is unavailable or not initialized (a
    single-process run), the caller is treated as rank 0 instead of raising.

    Args:
        logging_dir: existing directory that receives ``log.txt``.

    Returns:
        The ``logging.Logger`` named after this module.
    """
    is_distributed = dist.is_available() and dist.is_initialized()
    rank = dist.get_rank() if is_distributed else 0
    logger = logging.getLogger(__name__)
    if rank == 0:  # real logger
        logging.basicConfig(
            level=logging.INFO,
            format='[\033[34m%(asctime)s\033[0m] %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S',
            handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")]
        )
    elif not any(isinstance(h, logging.NullHandler) for h in logger.handlers):
        # dummy logger (does nothing)
        logger.addHandler(logging.NullHandler())
    return logger
if __name__ == "__main__":
    # Command line: the only option is the path to the YAML training config.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('--config', type=str, default='configs/debug.yaml')
    cli_args = arg_parser.parse_args()

    # The Accelerator is constructed before the config is read, then both are
    # handed to the training entry point.
    accelerator = Accelerator()
    do_train(load_config(cli_args.config), accelerator)