# marigold_dc.py (forked from prs-eth/Marigold-DC)
from typing import Callable, cast
import diffusers
import torch
from diffusers import MarigoldDepthPipeline
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.schedulers import DDIMScheduler, LCMScheduler
from torch.optim import SGD, Adadelta, Adagrad, Adam, Optimizer
from transformers import CLIPTextModel, CLIPTokenizer
import utils

diffusers.utils.logging.disable_progress_bar()

MARIGOLD_CKPT_ORIGINAL = "prs-eth/marigold-v1-0"
MARIGOLD_CKPT_LCM = "prs-eth/marigold-lcm-v1-0"
VAE_CKPT_LIGHT = "madebyollin/taesd"
SUPPORTED_LOSS_FUNCS = ["l1", "l2", "edge", "smooth"]
EPSILON = 1e-7


def get_projection_fn(projection: str) -> Callable[[torch.Tensor], torch.Tensor]:
"""
Returns the appropriate logarithmic function based on the specified projection method.
This function is used to transform depth values into logarithmic space, which can
improve accuracy for scenes with large depth ranges by giving more precision to
closer objects and less precision to distant objects.
Args:
projection (str): The projection method to use. Supported values are:
- "log": Natural logarithm (base e)
- "log10": Base-10 logarithm
- "linear": Identity function (no logarithmic transformation)
Returns:
Callable[[torch.Tensor], torch.Tensor]: A function that applies the specified
logarithmic transformation to a tensor of depth values.
Raises:
ValueError: If an unsupported projection method is provided.
"""
if projection == "log":
return torch.log
elif projection == "log10":
return torch.log10
elif projection == "linear":
return lambda x: x
raise ValueError(f"Unknown projection method: {projection}")
def compute_loss(
dense: torch.Tensor,
sparse: torch.Tensor,
mask: torch.Tensor,
loss_funcs: list[str],
image: torch.Tensor | None = None,
) -> torch.Tensor:
"""
Computes a combined loss between dense depth predictions and sparse depth measurements.
Args:
dense: Predicted dense depth map tensor of shape [N, 1, H, W].
sparse: Sparse depth measurements tensor of shape [N, 1, H, W], with zeros at unmeasured points.
mask: Binary mask tensor of shape [N, 1, H, W] indicating valid sparse depth measurements.
loss_funcs: List of loss function names to apply. Supported values are:
- "l1": L1 loss between dense and sparse at measured points
- "l2": L2 loss between dense and sparse at measured points
- "edge": Edge-aware loss that compares gradients with image gradients
- "smooth": Smoothness loss that penalizes depth discontinuities
image: Optional RGB or grayscale image tensor of shape [N, C, H, W] where C is 1 or 3.
Required when using "edge" or "smooth" loss functions.
Returns:
A tensor of shape [N] containing the total loss for each sample in the batch.
Raises:
ValueError: If loss_funcs is empty, contains an unsupported loss function,
or if image is not provided when required for edge/smooth losses.
""" # noqa: E501
if len(loss_funcs) == 0:
raise ValueError("loss_funcs must contain at least one loss function")
total = torch.zeros(dense.shape[0], device=dense.device) # [N]
for loss_func in loss_funcs:
if loss_func == "l1":
# Compute L1 loss per sample using masked operations
l1_loss = torch.abs(dense - sparse)
l1_loss = l1_loss * mask # Apply mask
# Sum over HW dimensions and divide by number of valid points per sample
total += l1_loss.sum(dim=(1, 2, 3)) / mask.sum(dim=(1, 2, 3))
elif loss_func == "l2":
# Compute L2 loss per sample using masked operations
l2_loss = (dense - sparse) ** 2
l2_loss = l2_loss * mask # Apply mask
# Sum over HW dimensions and divide by number of valid points per sample
total += l2_loss.sum(dim=(1, 2, 3)) / mask.sum(dim=(1, 2, 3))
elif loss_func == "edge":
if image is None:
raise ValueError("image must be provided for edge loss")
# Convert to grayscale if needed
num_channels = image.shape[1]
if num_channels == 3:
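                # ITU-R BT.601 luma weights for RGB -> grayscale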
gray_image = (
0.299 * image[:, 0:1]
+ 0.587 * image[:, 1:2]
+ 0.114 * image[:, 2:3]
) # [N, 1, H, W]
elif num_channels == 1:
gray_image = image
else:
raise ValueError(f"Image must have 1 or 3 channels, got {num_channels}")
# Compute gradients for entire batch at once
grad_pred_x = torch.abs(dense[:, :, :, :-1] - dense[:, :, :, 1:])
grad_pred_y = torch.abs(dense[:, :, :-1, :] - dense[:, :, 1:, :])
grad_gray_x = torch.abs(gray_image[:, :, :, :-1] - gray_image[:, :, :, 1:])
grad_gray_y = torch.abs(gray_image[:, :, :-1, :] - gray_image[:, :, 1:, :])
# Compute edge loss per sample using reduction over spatial dimensions
edge_loss_x = torch.abs(grad_pred_x - grad_gray_x).mean(dim=(1, 2, 3))
edge_loss_y = torch.abs(grad_pred_y - grad_gray_y).mean(dim=(1, 2, 3))
total += edge_loss_x + edge_loss_y
elif loss_func == "smooth":
if image is None:
raise ValueError("image must be provided for smooth loss")
# Compute smoothness loss per sample using reduction over spatial dimensions
loss_h = torch.abs(dense[:, :, :-1, :] - dense[:, :, 1:, :]).mean(
dim=(1, 2, 3)
)
loss_w = torch.abs(dense[:, :, :, :-1] - dense[:, :, :, 1:]).mean(
dim=(1, 2, 3)
)
total += loss_h + loss_w
else:
raise ValueError(f"Unknown loss function: {loss_func}")
return total # Returns tensor of shape [N]
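

# Illustrative sketch (not in the original file): exercising compute_loss with
# synthetic data. The shapes, the 8-pixel sampling stride, and the chosen loss
# combination are assumptions used only for demonstration.
def _example_compute_loss() -> None:
    dense = torch.rand(2, 1, 64, 64)  # fake dense predictions in [0, 1)
    sparse = torch.zeros_like(dense)  # zeros mark unmeasured pixels
    sparse[:, :, ::8, ::8] = torch.rand(2, 1, 8, 8) + 0.1  # sparse measurements
    mask = sparse > 0
    image = torch.rand(2, 3, 64, 64)  # RGB guidance for "edge"/"smooth"
    losses = compute_loss(dense, sparse, mask, ["l1", "edge"], image=image)
    print(losses.shape)  # torch.Size([2]) -- one scalar loss per sample

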
class MarigoldDepthCompletionPipeline(MarigoldDepthPipeline):
"""
Pipeline for depth completion using the Marigold model.
Takes RGB image and sparse depth as input to produce dense depth maps.
Uses diffusion model to refine depth predictions while preserving sparse
depth constraints through latent optimization and affine transformation.
Supports batch processing with optional temporal consistency.
""" # noqa: E501
def __init__(
self,
unet: UNet2DConditionModel,
vae: AutoencoderKL,
scheduler: DDIMScheduler | LCMScheduler,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
prediction_type: str | None = None,
scale_invariant: bool | None = True,
shift_invariant: bool | None = True,
default_denoising_steps: int | None = None,
default_processing_resolution: int | None = None,
) -> None:
super().__init__(
unet,
vae,
scheduler,
text_encoder,
tokenizer,
prediction_type,
scale_invariant,
shift_invariant,
default_denoising_steps,
default_processing_resolution,
)
def _affine_to_metric(
self,
dense: torch.Tensor, # [N, 1, H, W]
scale: torch.Tensor, # [N, 1, H, W] or [N, 1, 1, 1]
shift: torch.Tensor, # [N, 1, H, W] or [N, 1, 1, 1]
sparse_range: torch.Tensor, # [N, 1, 1, 1]
sparse_min: torch.Tensor, # [N, 1, 1, 1]
) -> torch.Tensor:
"""
Converts the model's affine-invariant depth representation to metric depth values.
This method applies an affine transformation to convert the normalized depth predictions
from the model's internal representation to actual metric depth values that match
the scale and range of the provided sparse depth measurements. The transformation
uses learned scale and shift parameters to ensure the output depth values are
properly calibrated to the input sparse measurements.
Args:
dense (torch.Tensor): Normalized depth predictions from the model with shape [N, 1, H, W].
scale (torch.Tensor): Learned scaling factor with shape [N, 1, H, W] or [N, 1, 1, 1].
shift (torch.Tensor): Learned shift factor with shape [N, 1, H, W] or [N, 1, 1, 1].
sparse_range (torch.Tensor): Range (max-min) of sparse depth values with shape [N, 1, 1, 1].
sparse_min (torch.Tensor): Minimum value of sparse depth with shape [N, 1, 1, 1].
Returns:
torch.Tensor: Calibrated metric depth values with shape [N, 1, H, W] that match
the scale of the input sparse depth measurements.
""" # noqa: E501
return (scale**2) * sparse_range * dense + (shift**2) * sparse_min
def _latent_to_dense(
self,
latent: torch.Tensor, # [N, 4, EH, EW]
orig_res: tuple[int, int],
padding: tuple[int, int],
affine_invariant: bool = False,
affine_params: tuple[torch.Tensor, torch.Tensor] | None = None,
sparse_range: tuple[torch.Tensor, torch.Tensor] | None = None,
interp_mode: str = "bilinear",
) -> torch.Tensor:
"""
Converts latent representation to dense depth map.
This method decodes the latent representation from the diffusion model into a dense
depth map, applies unpadding, resizes to the original resolution, and optionally
applies affine transformation to match the scale of sparse depth measurements.
Args:
latent (torch.Tensor): Latent representation with shape [N, 4, EH, EW].
orig_res (tuple[int, int]): Original resolution (H, W) to resize to.
padding (tuple[int, int]): Padding values to remove.
affine_invariant (bool, optional): Whether to apply affine transformation.
When True, the output will be transformed to match the scale of sparse depth.
Defaults to False.
affine_params (tuple[torch.Tensor, torch.Tensor] | None, optional):
Scale and shift parameters for affine transformation. Required when
affine_invariant is True. Defaults to None.
sparse_range (tuple[torch.Tensor, torch.Tensor] | None, optional):
Min and max values of sparse depth. Required when affine_invariant is True.
Defaults to None.
interp_mode (str, optional): Interpolation mode for resizing.
Options include "bilinear", "bicubic", etc. Defaults to "bilinear".
Returns:
torch.Tensor: Dense depth map with shape [N, 1, H, W].
""" # noqa: E501
if affine_invariant:
if affine_params is None or sparse_range is None:
raise ValueError(
"scaling and sparse_range must be "
"provided when affine_invariant is True"
)
decoded = self.decode_prediction(latent) # [N, 1, PPH, PPW]
decoded = self.image_processor.unpad_image(decoded, padding)
decoded_resized = self.image_processor.resize_antialias(
decoded, orig_res, interp_mode
) # [N, 1, H, W]
if affine_invariant:
assert affine_params is not None and sparse_range is not None
scale, shift = affine_params
sparse_min, sparse_max = sparse_range
decoded_resized = self._affine_to_metric(
decoded_resized, scale, shift, sparse_max - sparse_min, sparse_min
)
return decoded_resized # [N, 1, H, W]
def __call__(
self,
imgs: torch.Tensor,
sparses: torch.Tensor,
max_depth: float,
min_depth: float = 0.0,
projection: str = "linear", # "linear", "log", "log10"
inv: bool = False,
norm: str = "minmax",
percentile: float = 0.05,
pred_latents_prev: torch.Tensor | None = None,
beta: float = 0.9,
steps: int = 50,
resolution: int = 768,
affine_invariant: bool = True,
opt: str = "adam",
lr: tuple[float, float] | None = None,
kld: bool = False,
kld_weight: float = 0.1,
kld_mode: str = "simple",
interp_mode: str = "bilinear",
loss_funcs: list[str] | None = None,
seed: int = 2024,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Executes depth completion on a batch of RGB images using sparse depth measurements.
This function implements the primary depth completion algorithm through a diffusion-based approach.
It refines depth predictions iteratively by optimizing latent representations via a denoising process
guided by sparse depth measurements.
Args:
imgs (torch.Tensor): A batch of RGB images with dimensions [N, C, H, W].
These are raw images, not normalized to the [0, 1] range.
sparses (torch.Tensor): A batch of sparse depth maps with dimensions [N, 1, H, W].
These maps should have zeros at missing positions and positive values at measurement points.
The depth values are raw and not normalized.
max_depth (float): The maximum depth value for normalization.
min_depth (float, optional): The minimum depth value for normalization. Defaults to 0.0.
projection (str, optional): The method for projecting depth values.
Options include "linear", "log", or "log10". Using "log" or "log10" transforms depth values
to log space before processing, which can enhance accuracy for scenes with large depth ranges.
Defaults to "linear".
inv (bool, optional): Indicates whether to apply inverse projection (1/depth).
When set to True, the model operates with inverse depth (disparity), which can enhance
accuracy for distant objects. Defaults to False.
norm (str, optional): The normalization method for input sparse depth maps.
Options include "const", "minmax", or "percentile". Defaults to "minmax".
percentile (float, optional): The percentile value for determining the depth range.
Used only when norm="percentile". Lower values (e.g., 0.05) exclude outliers
by using the 5th and 95th percentiles. Defaults to 0.05.
pred_latents_prev (torch.Tensor | None, optional): Previous prediction latents
with dimensions [N, 4, EH, EW] from a prior frame or iteration.
This enables temporal consistency when processing video sequences. Defaults to None.
beta (float, optional): The momentum factor for prediction latents between frames.
Must be within the range [0, 1]. Higher values give more weight to new latents,
while lower values retain more information from previous frames. Defaults to 0.9.
steps (int, optional): The number of denoising steps.
Higher values yield better quality but result in slower inference. Defaults to 50.
resolution (int, optional): The resolution for internal processing.
Higher values yield better quality but consume more memory. Defaults to 768.
affine_invariant (bool, optional): Indicates whether to use affine invariant depth completion.
When set to True, the model applies affine transformations to manage arbitrary depth scales
and shifts between the model's internal representation and the input sparse depth.
This allows the model to function with different depth sensors and units without retraining.
The model will automatically estimate the appropriate scale and shift parameters
to align its predictions with the input sparse measurements. Defaults to True.
opt (str, optional): The optimizer to use ("adam", "sgd", "adagrad", or "adadelta").
Defaults to "adam". Note that when opt="adadelta", the learning rate is fixed at 1.0
regardless of the lr parameter.
lr (tuple[float, float] | None, optional): Learning rates for (latent, scaling).
If None, defaults to (0.05, 0.005). For the "adadelta" optimizer, this parameter is ignored.
kld (bool, optional): Indicates whether to apply a KL divergence penalty to
keep prediction latents close to N(0,1). Defaults to False.
kld_weight (float, optional): The weight for the KL divergence penalty.
Used only when kld is True. Defaults to 0.1.
kld_mode (str, optional): The KL divergence mode. Options include:
- "simple": Uses a simplified penalty based on the squared L2 norm of latents.
This is the fastest but least accurate approximation of KL divergence.
- "strict": Computes the proper forward KL divergence between the latent distribution and N(0,1).
This is more accurate but slightly more computationally expensive.
Defaults to "simple".
interp_mode (str, optional): The interpolation mode for resizing.
Options include "bilinear", "bicubic", etc. Defaults to "bilinear".
loss_funcs (list[str] | None, optional): The loss functions to use.
If None, defaults to ["l1", "l2"]. Supported options include
"l1", "l2", "edge", and "smooth". When using "edge" or "smooth",
the RGB image is used to guide depth discontinuities. Defaults to None.
seed (int, optional): The random seed for initializing the diffusion process generator and ensuring reproducibility.
Defaults to 2024.
Returns:
tuple[torch.Tensor, torch.Tensor]:
- A dense depth prediction with dimensions [N, 1, H, W] in metric units (same as input sparse depth)
- Prediction latents with dimensions [N, 4, EH, EW] that can be used for temporal consistency
in subsequent frames
""" # noqa: E501
generator = torch.Generator(device=self.device).manual_seed(seed)
# Create empty text embedding if not created
if self.empty_text_embedding is None:
with torch.no_grad():
text_inputs = self.tokenizer(
"",
padding="do_not_pad",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
self.empty_text_embedding: torch.Tensor = self.text_encoder(
text_inputs.input_ids.to(self.device)
)[0]
# Check input shapes
if (
imgs.ndim != 4
or sparses.ndim != 4
or (imgs.shape[0] != sparses.shape[0])
or (imgs.shape[-2:] != sparses.shape[-2:])
):
raise ValueError(
"Shape of image must be [N, C, H, W] and shape of sparse must be "
f"[N, 1, H, W], but got image.shape: "
f"{imgs.shape} and sparse.shape: {sparses.shape}"
)
N, _, H, W = imgs.shape
EH = resolution * H // (8 * max(H, W))
EW = resolution * W // (8 * max(H, W))
if pred_latents_prev is not None:
if pred_latents_prev.ndim != 4 or pred_latents_prev.shape != (N, 4, EH, EW):
raise ValueError(
"Shape of pred_latents_prev must be [N, 4, EH, EW], but got "
f"{pred_latents_prev.shape}"
)
        # Check that beta is in [0, 1]
if beta < 0 or beta > 1:
raise ValueError(f"beta must be in [0, 1], but got {beta}")
# Check projection method
if projection not in ["linear", "log", "log10"]:
raise ValueError(f"Unknown projection method: {projection}")
        # min_depth must be positive when projection is "log"/"log10" or inv is True
if (projection in ["log", "log10"] or inv) and min_depth <= EPSILON:
raise ValueError(
f"min_depth must be > {EPSILON} when "
f"projection is 'log' or 'log10' or inv is True, "
f"but got {min_depth}"
)
# Set learning rates
if lr is None:
lr_latent = 0.05
lr_scaling = 0.005
else:
lr_latent, lr_scaling = lr
# Set loss functions
if loss_funcs is None:
loss_funcs = ["l1", "l2"]
else:
for func in loss_funcs:
if func not in SUPPORTED_LOSS_FUNCS:
raise ValueError(f"Unknown loss function: {func}")
# Tile empty text conditioning
batch_empty_text_embedding = self.empty_text_embedding.repeat(N, 1, 1)
# Create common prediction latents
pred_latents_common = torch.randn(
(N, 4, EH, EW),
device=imgs.device,
dtype=self.dtype,
generator=generator,
) # [N, 4, EH, EW]
# Preprocess input images
imgs_resized, padding, orig_res = self.image_processor.preprocess(
imgs,
processing_resolution=resolution,
device=self.device,
dtype=self.dtype,
) # [N, C, PPH, PPW]
orig_res = cast(tuple[int, int], orig_res)
# Get latent encodings
with torch.no_grad():
img_latents, _ = self.prepare_latents(
imgs_resized, None, generator, 1, N
) # [N, 4, EH, EW], [N, 4, EH, EW]
if pred_latents_prev is not None:
pred_latents = (
beta * pred_latents_common + (1 - beta) * pred_latents_prev
)
else:
pred_latents = pred_latents_common
# Calculate min & max depth values for each sample in the batch
masks = sparses > 0
if norm == "minmax":
min_depths, max_depths = utils.masked_minmax(sparses, masks, dims=(1, 2, 3))
min_depths = min_depths.view(-1, 1, 1, 1)
max_depths = max_depths.view(-1, 1, 1, 1)
elif norm == "percentile":
ranges = torch.stack(
[
torch.quantile(
s[m],
torch.tensor(
[percentile, 1 - percentile], device=sparses.device
),
)
for s, m in zip(sparses, masks, strict=True)
]
) # [N, 2]
min_depths = ranges[:, 0].view(-1, 1, 1, 1)
max_depths = ranges[:, 1].view(-1, 1, 1, 1)
elif norm == "const":
min_depths = torch.full((N, 1, 1, 1), min_depth, device=sparses.device)
max_depths = torch.full((N, 1, 1, 1), max_depth, device=sparses.device)
else:
raise ValueError(f"Unknown norm method: {norm}")
# Clamp depth values to [min_depth, max_depth]
if norm in ["minmax", "percentile"]:
min_depths = torch.clamp(min_depths, min=min_depth, max=max_depth)
max_depths = torch.clamp(max_depths, min=min_depth, max=max_depth)
sparses_clamped = torch.clamp(
sparses,
min=min_depths,
max=max_depths,
)
# Normalize sparse depth maps
proj_fn = get_projection_fn(projection)
min_depths_proj, max_depths_proj = proj_fn(min_depths), proj_fn(max_depths)
sparses_clamped_proj = proj_fn(sparses_clamped)
if inv:
min_depths_proj, max_depths_proj = 1 / max_depths_proj, 1 / min_depths_proj
sparses_clamped_proj = 1 / sparses_clamped_proj
sparses_normed = (sparses_clamped_proj - min_depths_proj) / (
max_depths_proj - min_depths_proj
)
sparses_min, sparses_max = utils.masked_minmax(
sparses_normed, masks, dims=(1, 2, 3)
)
sparses_min = sparses_min.view(-1, 1, 1, 1)
sparses_max = sparses_max.view(-1, 1, 1, 1)
sparse_ranges = (sparses_min, sparses_max) if affine_invariant else None
# Set current prediction latents as trainable params
pred_latents = torch.nn.Parameter(pred_latents) # [N, 4, EH, EW]
# Set scaling params
affine_params = (
(
# scale
torch.nn.Parameter(torch.ones(N, 1, 1, 1, device=self.device)),
# shift
torch.nn.Parameter(torch.zeros(N, 1, 1, 1, device=self.device)),
)
if affine_invariant
else None
)
# Set up optimizer
optimizer: Optimizer
param_groups = [
{"params": [pred_latents], "lr": lr_latent},
]
if affine_params is not None:
param_groups.append({"params": list(affine_params), "lr": lr_scaling})
if opt == "adam":
optimizer = Adam(param_groups)
elif opt == "sgd":
optimizer = SGD(param_groups)
elif opt == "adagrad":
optimizer = Adagrad(param_groups)
elif opt == "adadelta":
# NOTE: Adadelta uses a fixed learning rate of 1
for group in param_groups:
group["lr"] = 1
optimizer = Adadelta(param_groups)
else:
raise ValueError(f"Unknown optimizer: {opt}")
# Denoising loop
self.scheduler.set_timesteps(steps, device=self.device)
for t in self.scheduler.timesteps:
optimizer.zero_grad()
# Forward pass through the U-Net
latents = torch.cat([img_latents, pred_latents], dim=1) # [N, 8, EH, EW]
pred_noises: torch.Tensor = self.unet(
latents,
t,
encoder_hidden_states=batch_empty_text_embedding,
return_dict=False,
)[
0
] # [N, 4, EH, EW]
# Compute noise to later rescale the depth latent gradient
with torch.no_grad():
a_prod_t = cast(float, self.scheduler.alphas_cumprod[t])
b_prod_t = 1 - a_prod_t
pred_epsilons = (a_prod_t**0.5) * pred_noises + (
b_prod_t**0.5
) * pred_latents # [N, 4, EH, EW]
# Preview the final output depth with Tweedie's formula
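            # i.e. take the scheduler's pred_original_sample: its current estimate
            # of the fully denoised latent x0, given the noise prediction at step t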
previews = cast(
torch.Tensor,
self.scheduler.step(
pred_noises, t, pred_latents, generator=generator
).pred_original_sample,
) # [N, 4, EH, EW]
# Predict dense depth maps
denses_normed = self._latent_to_dense(
previews,
orig_res,
padding,
affine_invariant=affine_invariant,
affine_params=affine_params,
sparse_range=sparse_ranges,
interp_mode=interp_mode,
) # [N, 1, H, W]
denses_normed = denses_normed.clamp(min=0.0, max=1.0)
if projection != "linear":
denses_normed = denses_normed * (max_depths - min_depths) + min_depths
denses_normed = get_projection_fn(projection)(denses_normed)
if inv:
denses_normed = 1 / denses_normed
denses_normed = (denses_normed - min_depths_proj) / (
max_depths_proj - min_depths_proj
)
elif inv:
denses_normed = denses_normed * (max_depths - min_depths) + min_depths
denses_normed = 1 / denses_normed
denses_normed = (denses_normed - min_depths_proj) / (
max_depths_proj - min_depths_proj
)
losses = compute_loss(
denses_normed, sparses_normed, masks, loss_funcs, image=imgs
)
# NOTE: Add KL divergence penalty to keep
# the distribution of pred_latent close to N(0,1)
if kld:
kld_losses = utils.kld_stdnorm(
pred_latents, reduction="none", mode=kld_mode
).reshape(-1, 1, 1, 1)
losses = losses + kld_weight * kld_losses
# Backprop
losses.backward(torch.ones_like(losses)) # Preserve batch dimension
# NOTE: Scale grads of pred_latents by the norm of pred_epsilons
# for stable optimization
with torch.no_grad():
assert pred_latents.grad is not None
pred_epsilon_norms = torch.linalg.norm(
pred_epsilons.view(N, -1), dim=1
) # [N]
pred_latent_grad_norms = torch.linalg.norm(
pred_latents.grad.view(N, -1), dim=1
) # [N]
factors = pred_epsilon_norms / torch.clamp(
pred_latent_grad_norms, min=EPSILON
) # [N]
factors = factors.view(N, 1, 1, 1) # [N, 1, 1, 1]
# Scaling
pred_latents.grad *= factors # [N, 4, EH, EW]
            # Optimizer step: update latents (and affine params) with scaled grads
optimizer.step()
# Execute update of the latent with regular denoising diffusion step
with torch.no_grad():
pred_latents.data = self.scheduler.step(
pred_noises, t, pred_latents, generator=generator
).prev_sample
# Compute final dense depth maps
with torch.no_grad():
pred_latents_detached = pred_latents.detach()
denses_normed = self._latent_to_dense(
pred_latents_detached,
orig_res,
padding,
affine_invariant=affine_invariant,
affine_params=affine_params,
sparse_range=sparse_ranges,
interp_mode=interp_mode,
) # [N, 1, H, W]
            # Clamp to [0, 1] and map back to the metric depth range
denses_normed = torch.clamp(denses_normed, min=0, max=1)
denses = denses_normed * (max_depths - min_depths) + min_depths
return denses, pred_latents_detached
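

# Illustrative sketch (not in the original file): one way to run the pipeline
# end to end. Loading the checkpoint with from_pretrained, the device choice,
# and the synthetic inputs below are assumptions for demonstration only.
def _example_pipeline_usage() -> None:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    pipe = MarigoldDepthCompletionPipeline.from_pretrained(MARIGOLD_CKPT_ORIGINAL)
    pipe = pipe.to(device)
    # A fake RGB frame (raw 0-255 values) and a sparse depth map with ~1% coverage.
    img = torch.rand(1, 3, 480, 640, device=device) * 255.0
    sparse = torch.zeros(1, 1, 480, 640, device=device)
    valid = torch.rand_like(sparse) < 0.01
    sparse[valid] = torch.rand(int(valid.sum()), device=device) * 80.0
    dense, pred_latents = pipe(
        img, sparse, max_depth=80.0, steps=50, resolution=768
    )
    print(dense.shape)  # [1, 1, 480, 640] metric depth
    # pred_latents can be fed back as pred_latents_prev for the next video frame.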