Fix CUDA compile error with bf16 #2122

Merged: 2 commits merged into pytorch:main on May 7, 2025

Conversation

metascroy (Contributor):

Summary: T222166791

Differential Revision: D73562284

pytorch-bot bot commented Apr 24, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2122

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 77e226c with merge base 31d17c0:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label on Apr 24, 2025
facebook-github-bot (Contributor):

This pull request was exported from Phabricator. Differential Revision: D73562284

@@ -70,9 +70,9 @@ constexpr float power_of_two(int n) {
  return (n == 0) ? 1.0f : 2.0f * power_of_two(n - 1);
}

#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800

Contributor:
Why does this look reversed to me? Should this be if defined and >= 800?

metascroy (Contributor Author):

I thought that was weird as well. But I see it defined in this way in multiple files in torchao (they could all be wrong).

But what I'm doing in this PR is matching the #if guard on the include (lines 27-29):

#include <cuda_fp16.h>
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
#include <cuda_bf16.h>
#endif

with this function.

__nv_bfloat16 is defined in cuda_bf16.h, but that include is currently guarded by #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800.
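
For concreteness, here is a minimal sketch (the helper name is hypothetical, not taken from this PR) of what matching guards look like: any compilation pass that skips cuda_bf16.h also skips the code that uses the type, so nothing refers to an undefined __nv_bfloat16.

#include <cuda_fp16.h>
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
#include <cuda_bf16.h>
#endif

// Guarded the same way as the include above: a compilation pass that does
// not see cuda_bf16.h also does not see this helper.
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
__host__ __device__ inline float bf16_to_float(__nv_bfloat16 x) {
  return __bfloat162float(x);
}
#endif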

Contributor:

I see; this does look weird.

Contributor:

cc @gau-nernst, can you take a look? Is this intentional?

Collaborator:

What was the original error? (i.e., why is this PR needed?)

__CUDA_ARCH__ is only defined for device (CUDA) code. #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 is therefore true for host code OR for CUDA code compiled for sm >= 80. There are rules about how we should use it (https://docs.nvidia.com/cuda/cuda-c-programming-guide/#cuda-arch); otherwise it will lead to undefined behavior.

For the example you mentioned above, we also need to include the cuda_bf16.h header in host code; otherwise host code won't have access to the BF16 typedef. Similarly, __global__ functions (CUDA kernels) must be visible in both host code and device code (and have the same signature). Hence, in many cases it's easier to have all the functions defined and leave the implementation empty if __CUDA_ARCH__ < xxx.
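
A minimal sketch of that pattern (kernel name hypothetical, not from this PR): the header is included and the kernel is defined unconditionally, so host and device code see the same signature, and only the body is compiled away on older architectures.

#include <cuda_bf16.h>  // always included; host code needs the type too

__global__ void copy_bf16(const __nv_bfloat16* in, __nv_bfloat16* out, int n) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = in[i];
#else
  // Empty body for sm < 80: the kernel still exists with the same
  // signature, so host-side launch code compiles everywhere.
#endif
}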

metascroy (Contributor Author):

So rather than this change, I should remove the guards around cuda_bf16.h in lines 28-30, i.e., change:

#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
#include <cuda_bf16.h>
#endif

to:

#include <cuda_bf16.h>

Is that right?

metascroy (Contributor Author):

I went ahead and removed the guard on the header (#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800).

metascroy (Contributor Author):

Any more concerns here, @jerryzh168 @gau-nernst?

facebook-github-bot (Contributor):

@metascroy has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@metascroy added the topic: bug fix and topic: not user facing labels on May 6, 2025
@facebook-github-bot merged commit e5d9a97 into pytorch:main on May 7, 2025 (20 of 22 checks passed)
andrewor14 pushed a commit that referenced this pull request on May 9, 2025:
Differential Revision: D73562284

Pull Request resolved: #2122
Labels: CLA Signed, fb-exported, topic: bug fix, topic: not user facing