Commit 2f78afc

Refactor some functions to ipex_llm.transformers.models.common (#13091)
* add quantize_linear & linear_forward
* add moe_group_topk
* rotary_two_with_cache_inplaced
* fix code style
* update related models
1 parent 73198d5 commit 2f78afc

File tree

3 files changed: +64 -10 lines changed

python/llm/src/ipex_llm/transformers/models/common.py

+54

@@ -17,6 +17,7 @@
 import math
 import torch
 from typing import List
+from ipex_llm.utils.common import invalidInputError


 def merge_linear(linears: List[torch.nn.Linear]) -> torch.nn.Linear:

@@ -303,3 +304,56 @@ def scaled_dot_product_attention(query: torch.Tensor, key: torch.Tensor,
     )
     attn_output = attn_output.to(dtype)  # workaround ipex 2.1's bug
     return attn_output
+
+
+def linear_forward(x: torch.Tensor, weight: torch.Tensor, qtype: int, out_features: int):
+    if weight.device.type == "xpu":
+        new_shape = x.shape[:-1] + (out_features,)
+        x = x.to(weight.device, dtype=torch.float16)
+        x_2d = x.contiguous().view(-1, x.shape[-1])
+        import xe_linear
+        x = xe_linear.forward_new(x_2d, weight, qtype, out_features)
+        x = x.view(new_shape)
+        return x
+    else:
+        invalidInputError(False,
+                          "Unsupported device type: only support weight on xpu device.")
+
+
+def quantize_linear(weight: torch.Tensor, in_features: int, precision: str):
+    from ipex_llm.transformers.low_bit_linear import FP4Params
+    from ipex_llm.ggml.quantize import ggml_tensor_qtype
+
+    invalidInputError(precision in ggml_tensor_qtype.keys(),
+                      f"{precision} is not supported, "
+                      f"only {ggml_tensor_qtype.keys()} are supported now.")
+    qtype = ggml_tensor_qtype[precision]
+    paramsLowBit = FP4Params(data=weight.data,
+                             requires_grad=False,
+                             quantized=False,
+                             _shape=None,
+                             convert_shape_only=False,
+                             qtype=qtype,
+                             in_features=in_features,
+                             enable_scale_search=False).to("cpu")
+    return paramsLowBit, qtype
+
+
+def moe_group_topk(scores: torch.Tensor, e_score_correction_bias: torch.Tensor,
+                   n_group: int, topk_group: int, top_k: int, norm_topk_prob: float,
+                   routed_scaling_factor: float):
+    import xe_addons
+    topk_idx, topk_weight = xe_addons.moe_group_topk(
+        scores, e_score_correction_bias,
+        n_group, 2, topk_group, top_k,
+        top_k > 1 and norm_topk_prob, 1e-20, routed_scaling_factor
+    )
+    return topk_idx, topk_weight
+
+
+def rotary_two_with_cache_inplaced(query_states: torch.Tensor, key_states: torch.Tensor,
+                                   cos: torch.Tensor, sin: torch.Tensor,
+                                   half_layout: bool):
+    import xe_addons
+    xe_addons.rotary_two_with_cache_inplaced(query_states, key_states,
+                                             cos, sin, half_layout)
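
Taken together, the two new linear helpers cover the quantize-then-run path: quantize_linear packs a float weight into a low-bit parameter (created on CPU) together with its ggml qtype, and linear_forward dispatches the matmul through xe_linear for weights that live on an XPU device. Below is a minimal usage sketch, not part of this commit; it assumes "sym_int4" is one of the ggml_tensor_qtype precision keys, that moving the packed parameter to "xpu" performs the actual quantization (as in ipex_llm's low-bit linear path), and it uses made-up layer sizes. It needs an Intel GPU build of ipex_llm where xe_linear is available.

    import torch
    from ipex_llm.transformers.models.common import quantize_linear, linear_forward

    # Hypothetical dense layer to run in low precision on an Intel GPU.
    linear = torch.nn.Linear(4096, 11008, bias=False)

    # Pack the float weight into a low-bit parameter; the packed tensor starts on CPU.
    packed_weight, qtype = quantize_linear(linear.weight, linear.in_features, "sym_int4")

    # linear_forward only accepts weights on an XPU device; moving the packed
    # parameter there is assumed to trigger the actual quantization.
    packed_weight = packed_weight.to("xpu")

    # Inputs are cast to float16 and flattened to 2D inside linear_forward.
    x = torch.randn(1, 32, 4096)
    out = linear_forward(x, packed_weight, qtype, linear.out_features)
    # out has shape (1, 32, 11008) and lives on the XPU device.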

python/llm/src/ipex_llm/transformers/models/deepseek.py

+8 -8

@@ -228,11 +228,11 @@ def deepseek_attention_forward(
             [k_nope, k_pe.expand([-1, self.num_heads, -1, -1])],
             dim=-1
         )
-        import xe_addons
         cos, sin = position_embeddings
-        xe_addons.rotary_two_with_cache_inplaced(query_states[:, :, :, self.qk_nope_head_dim:],
-                                                 key_states[:, :, :, self.qk_nope_head_dim:],
-                                                 cos, sin, True)
+        from ipex_llm.transformers.models.common import rotary_two_with_cache_inplaced
+        rotary_two_with_cache_inplaced(query_states[:, :, :, self.qk_nope_head_dim:],
+                                       key_states[:, :, :, self.qk_nope_head_dim:],
+                                       cos, sin, True)
     else:
         q_nope, q_pe = torch.split(
             q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1

@@ -279,11 +279,11 @@ def fuse_gate_forward(self, x: torch.Tensor):
         )
         scores = logits.sigmoid()

-        import xe_addons
-        topk_idx, topk_weight = xe_addons.moe_group_topk(
+        from ipex_llm.transformers.models.common import moe_group_topk
+        topk_idx, topk_weight = moe_group_topk(
             scores, self.e_score_correction_bias,
-            self.n_group, 2, self.topk_group, self.top_k,
-            self.top_k > 1 and self.norm_topk_prob, 1e-20, self.routed_scaling_factor
+            self.n_group, self.topk_group, self.top_k,
+            self.norm_topk_prob, self.routed_scaling_factor
         )
     else:
         topk_idx, topk_weight = self(x)
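
Note that the new moe_group_topk helper is a thin pass-through over xe_addons.moe_group_topk: it pins the per-group scoring width (2) and the 1e-20 epsilon, and folds the top_k > 1 condition into the normalization flag, which is why the call site above drops those arguments. A rough sketch of the equivalence, with hypothetical gate hyperparameters standing in for the module attributes (again this needs an Intel GPU build where xe_addons is available):

    import torch
    from ipex_llm.transformers.models.common import moe_group_topk

    # Hypothetical DeepSeek-style gate configuration; values are for illustration only.
    n_group, topk_group, top_k = 8, 4, 8
    norm_topk_prob, routed_scaling_factor = True, 2.5

    scores = torch.rand(4, 256, device="xpu")                 # sigmoid router scores
    e_score_correction_bias = torch.zeros(256, device="xpu")  # per-expert bias

    # Equivalent to the old direct call:
    #   xe_addons.moe_group_topk(scores, e_score_correction_bias,
    #                            n_group, 2, topk_group, top_k,
    #                            top_k > 1 and norm_topk_prob, 1e-20,
    #                            routed_scaling_factor)
    topk_idx, topk_weight = moe_group_topk(
        scores, e_score_correction_bias,
        n_group, topk_group, top_k,
        norm_topk_prob, routed_scaling_factor)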

python/llm/src/ipex_llm/transformers/models/glm.py

+2 -2

@@ -98,9 +98,9 @@ def glm_attention_forward(

     cos, sin = position_embeddings
     if query_states.device.type == "xpu":
-        import xe_addons
         make_cache_contiguous_inplaced(cos, sin)
-        xe_addons.rotary_two_with_cache_inplaced(query_states, key_states, cos, sin, True)
+        from ipex_llm.transformers.models.common import rotary_two_with_cache_inplaced
+        rotary_two_with_cache_inplaced(query_states, key_states, cos, sin, True)
     else:
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
