 import math
 import torch
 from typing import List
+from ipex_llm.utils.common import invalidInputError


 def merge_linear(linears: List[torch.nn.Linear]) -> torch.nn.Linear:
@@ -303,3 +304,56 @@ def scaled_dot_product_attention(query: torch.Tensor, key: torch.Tensor,
     )
     attn_output = attn_output.to(dtype)  # workaround ipex 2.1's bug
     return attn_output
+
+
+def linear_forward(x: torch.Tensor, weight: torch.Tensor, qtype: int, out_features: int):
+    if weight.device.type == "xpu":
+        new_shape = x.shape[:-1] + (out_features,)
+        x = x.to(weight.device, dtype=torch.float16)
+        x_2d = x.contiguous().view(-1, x.shape[-1])
+        import xe_linear
+        x = xe_linear.forward_new(x_2d, weight, qtype, out_features)
+        x = x.view(new_shape)
+        return x
+    else:
+        invalidInputError(False,
+                          "Unsupported device type: only weights on an xpu device are supported.")
+
+
+def quantize_linear(weight: torch.Tensor, in_features: int, precision: str):
+    from ipex_llm.transformers.low_bit_linear import FP4Params
+    from ipex_llm.ggml.quantize import ggml_tensor_qtype
+
+    invalidInputError(precision in ggml_tensor_qtype.keys(),
+                      f"{precision} is not supported, "
+                      f"only {list(ggml_tensor_qtype.keys())} are supported now.")
+    qtype = ggml_tensor_qtype[precision]
+    paramsLowBit = FP4Params(data=weight.data,
+                             requires_grad=False,
+                             quantized=False,
+                             _shape=None,
+                             convert_shape_only=False,
+                             qtype=qtype,
+                             in_features=in_features,
+                             enable_scale_search=False).to("cpu")
+    return paramsLowBit, qtype
+
+
+def moe_group_topk(scores: torch.Tensor, e_score_correction_bias: torch.Tensor,
+                   n_group: int, topk_group: int, top_k: int, norm_topk_prob: bool,
+                   routed_scaling_factor: float):
+    import xe_addons
+    topk_idx, topk_weight = xe_addons.moe_group_topk(
+        scores, e_score_correction_bias,
+        n_group, 2, topk_group, top_k,
+        top_k > 1 and norm_topk_prob, 1e-20, routed_scaling_factor
+    )
+    return topk_idx, topk_weight
| 352 | + |
| 353 | + |
| 354 | +def rotary_two_with_cache_inplaced(query_states: torch.Tensor, key_states: torch.Tensor, |
| 355 | + cos: torch.Tensor, sin: torch.Tensor, |
| 356 | + half_layout: bool): |
| 357 | + import xe_addons |
| 358 | + xe_addons.rotary_two_with_cache_inplaced(query_states, key_states, |
| 359 | + cos, sin, half_layout) |
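
For context, here is a minimal usage sketch of how `quantize_linear` and `linear_forward` compose. It assumes these helpers are exposed from `ipex_llm.transformers.models.common` (the file being patched), that `"sym_int4"` is an available `ggml_tensor_qtype` key, and that an XPU device with the `xe_linear` extension is present; the layer sizes are illustrative only.

```python
import torch
from ipex_llm.transformers.models.common import quantize_linear, linear_forward

# illustrative layer sizes; any (out_features, in_features) weight works
in_features, out_features = 4096, 11008
weight = torch.randn(out_features, in_features, dtype=torch.float16)

# pack the fp16 weight into a low-bit ggml format on CPU
# ("sym_int4" is an assumed precision key)
packed, qtype = quantize_linear(weight, in_features, "sym_int4")

# linear_forward requires the packed weight on an XPU device
packed = packed.to("xpu")

x = torch.randn(1, 32, in_features, dtype=torch.float16)
out = linear_forward(x, packed, qtype, out_features)
print(out.shape)  # expected: torch.Size([1, 32, 11008])
```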
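And a sketch of the MoE routing helper with DeepSeek-V3-style hyperparameters; the expert counts and scaling factor below are assumptions taken from that model family, not values fixed by this patch. The hard-coded `2` passed to `xe_addons.moe_group_topk` appears to be the per-group top-2 score sum used to rank groups in DeepSeek-V3 routing, and `1e-20` the epsilon used when renormalizing the selected weights.

```python
import torch
from ipex_llm.transformers.models.common import moe_group_topk

# assumed DeepSeek-V3-style routing configuration
n_routed_experts = 256
n_group, topk_group, top_k = 8, 4, 8
norm_topk_prob, routed_scaling_factor = True, 2.5

# dtypes and shapes here are assumptions about what the fused kernel expects
num_tokens = 16
scores = torch.rand(num_tokens, n_routed_experts, device="xpu")  # router scores
bias = torch.zeros(n_routed_experts, device="xpu")               # e_score_correction_bias

topk_idx, topk_weight = moe_group_topk(scores, bias, n_group, topk_group,
                                       top_k, norm_topk_prob, routed_scaling_factor)
# topk_idx: (num_tokens, top_k) expert indices; topk_weight: matching routing weights
```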