[MLC-机器学习编译]03作业: TensorIR练习
笔记
- `parallel` 需要在 `decompose_reduction` 之前调用 (参考 TA 解答).
- Definition of `compute_at`: given a producer and a consumer, `compute_at` computes part of the producer's region under one of the consumer's loops.
- Definition of `reverse_compute_at`: given a producer and a consumer, `reverse_compute_at` computes part of the consumer's region under one of the producer's loops.
- `decompose_reduction` 之后循环结构发生了一些变化, 需要重新调用 `get_loops`.
Code
import IPython
import numpy as np
import tvm
from tvm.ir.module import IRModule
from tvm.script import tir as T
# Inputs for the element-wise example: `a` counts up from 0 and `b` counts
# down from 16, so each pair of matching elements sums to 16.
a = np.arange(16).reshape((4, 4))
b = np.arange(16, 0, -1).reshape((4, 4))
# Reference result computed with numpy.
c_np = np.add(a, b)
c_np
array([[16, 16, 16, 16],
[16, 16, 16, 16],
[16, 16, 16, 16],
[16, 16, 16, 16]])
# low-level numpy version
def lnumpy_add(a: np.ndarray, b: np.ndarray, c: np.ndarray):
    """Add ``a`` and ``b`` element by element with explicit loops, into ``c``.

    This is the "low-level numpy" counterpart of ``c = a + b``: it only uses
    scalar indexing and a doubly nested loop, which is the structure the
    TensorIR version mirrors.

    Args:
        a: First operand, indexable as ``a[i, j]`` over the extent of ``c``.
        b: Second operand, indexable as ``b[i, j]`` over the extent of ``c``.
        c: Output buffer; it is overwritten in place and nothing is returned.
    """
    # Take the loop bounds from the output buffer instead of hard-coding 4x4,
    # so the same routine works for any 2-D problem size.
    rows, cols = c.shape[:2]
    for i in range(rows):
        for j in range(cols):
            c[i, j] = a[i, j] + b[i, j]
# Allocate an uninitialized 4x4 int64 output buffer and fill it with the
# loop-based implementation.
c_lnumpy = np.empty((4, 4), dtype=np.int64)
lnumpy_add(a, b, c_lnumpy)
# Notebook cell result: display the filled buffer.
c_lnumpy
array([[16, 16, 16, 16],
[16, 16, 16, 16],
[16, 16, 16, 16],
[16, 16, 16, 16]])
# TensorIR version
@tvm.script.ir_module
class MyAdd:
    # TensorIR module holding one primitive function, `add`, that computes
    # C = A + B element by element on 4x4 int64 buffers.
    @T.prim_func
    def add(A: T.Buffer[(4, 4), "int64"],
            B: T.Buffer[(4, 4), "int64"],
            C: T.Buffer[(4, 4), "int64"]):
        T.func_attr({"global_symbol": "add"})
        for i, j in T.grid(4, 4):
            with T.block("C"):
                # Both block axes are spatial and bound directly to the loops.
                vi = T.axis.spatial(4, i)
                vj = T.axis.spatial(4, j)
                C[vi, vj] = A[vi, vj] + B[vi, vj]
# Build the module for the "llvm" target.
rt_lib = tvm.build(MyAdd, target="llvm")
# Wrap the numpy inputs, plus an uninitialized int64 output buffer, for TVM.
a_tvm = tvm.nd.array(a)
b_tvm = tvm.nd.array(b)
c_tvm = tvm.nd.array(np.empty((4, 4), dtype=np.int64))
# Look up the "add" function by name and run it; the result lands in c_tvm.
rt_lib["add"](a_tvm, b_tvm, c_tvm)
# Compare the TVM output with the numpy reference c_np.
np.testing.assert_allclose(c_tvm.numpy(), c_np, rtol=1e-5)
# Inputs for the broadcast example: `a` is 4x4, `b` is a length-4 vector
# that numpy broadcasts across the rows of `a`.
a = np.arange(16).reshape((4, 4))
b = np.arange(4, 0, -1)
# Reference result computed with numpy.
c_np = np.add(a, b)
c_np
array([[ 4, 4, 4, 4],
[ 8, 8, 8, 8],
[12, 12, 12, 12],
[16, 16, 16, 16]])
@tvm.script.ir_module
class MyAdd:
    # TensorIR module for the broadcast add: C[i, j] = A[i, j] + B[j], with
    # A and C of shape 4x4 and B a length-4 vector.
    @T.prim_func
    def add(A: T.Buffer[(4, 4), "int64"],
            # `(4)` is just the integer 4; `(4,)` spells the 1-D shape as a
            # tuple, matching the way the other buffers are declared.
            B: T.Buffer[(4,), "int64"],
            C: T.Buffer[(4, 4), "int64"]):
        T.func_attr({"global_symbol": "add", "tir.noalias": True})
        for i, j in T.grid(4, 4):
            with T.block("C"):
                vi = T.axis.spatial(4, i)
                vj = T.axis.spatial(4, j)
                # B is indexed by the column axis only, which is the broadcast.
                C[vi, vj] = A[vi, vj] + B[vj]
# Build the broadcast-add module for the "llvm" target.
rt_lib = tvm.build(MyAdd, target="llvm")
# Wrap the 4x4 input, the length-4 vector and an uninitialized 4x4 int64
# output buffer for TVM.
a_tvm = tvm.nd.array(a)
b_tvm = tvm.nd.array(b)
c_tvm = tvm.nd.array(np.empty((4, 4), dtype=np.int64))
# Run the "add" function; the result lands in c_tvm.
rt_lib["add"](a_tvm, b_tvm, c_tvm)
# Compare the TVM output with the numpy broadcast reference c_np.
np.testing.assert_allclose(c_tvm.numpy(), c_np, rtol=1e-5)
# Problem sizes: batch, input channels, input height/width, output channels,
# and (square) kernel size.
N, CI, H, W, CO, K = 1, 1, 8, 8, 2, 3
# The output is K - 1 smaller than the input along each spatial dimension.
OUT_H = H - K + 1
OUT_W = W - K + 1
data = np.arange(N * CI * H * W).reshape((N, CI, H, W))
weight = np.arange(CO * CI * K * K).reshape((CO, CI, K, K))
# torch version
import torch
data_torch = torch.Tensor(data)
weight_torch = torch.Tensor(weight)
# Run the reference convolution, then convert the result back to an int64
# numpy array so it can be compared with the int64 TensorIR output.
conv_torch = torch.nn.functional.conv2d(data_torch, weight_torch).numpy().astype(np.int64)
conv_torch
array([[[[ 474, 510, 546, 582, 618, 654],
[ 762, 798, 834, 870, 906, 942],
[1050, 1086, 1122, 1158, 1194, 1230],
[1338, 1374, 1410, 1446, 1482, 1518],
[1626, 1662, 1698, 1734, 1770, 1806],
[1914, 1950, 1986, 2022, 2058, 2094]],
[[1203, 1320, 1437, 1554, 1671, 1788],
[2139, 2256, 2373, 2490, 2607, 2724],
[3075, 3192, 3309, 3426, 3543, 3660],
[4011, 4128, 4245, 4362, 4479, 4596],
[4947, 5064, 5181, 5298, 5415, 5532],
[5883, 6000, 6117, 6234, 6351, 6468]]]])
@tvm.script.ir_module
class MyConv:
    # TensorIR module for a 2-D convolution of a (1, 1, 8, 8) input with a
    # (2, 1, 3, 3) weight, producing a (1, 2, 6, 6) output; all int64.
    @T.prim_func
    def conv(inputs: T.Buffer[(1, 1, 8, 8), "int64"],
             weight: T.Buffer[(2, 1, 3, 3), "int64"],
             output: T.Buffer[(1, 2, 6, 6), "int64"]):
        T.func_attr({"global_symbol": "conv", "tir.noalias": True})
        for b, ci, co, i, j, ki, kj in T.grid(1, 1, 2, 6, 6, 3, 3):
            with T.block("C"):
                # Batch, output channel and output position are spatial (S);
                # input channel and the two kernel offsets are reduced (R).
                vb, vci, vco, vi, vj, vki, vkj = T.axis.remap(
                    "SRSSSRR", [b, ci, co, i, j, ki, kj]
                )
                with T.init():
                    output[vb, vco, vi, vj] = T.int64(0)
                output[vb, vco, vi, vj] = (
                    output[vb, vco, vi, vj]
                    + inputs[vb, vci, vi + vki, vj + vkj] * weight[vco, vci, vki, vkj]
                )
# Build the convolution module for the "llvm" target.
rt_lib = tvm.build(MyConv, target="llvm")
# Wrap the input, the weight and an uninitialized int64 output buffer of
# shape (N, CO, OUT_H, OUT_W) for TVM.
data_tvm = tvm.nd.array(data)
weight_tvm = tvm.nd.array(weight)
conv_tvm = tvm.nd.array(np.empty((N, CO, OUT_H, OUT_W), dtype=np.int64))
# Run the "conv" function; the result lands in conv_tvm.
rt_lib["conv"](data_tvm, weight_tvm, conv_tvm)
# Compare the TVM output with the torch reference conv_torch.
np.testing.assert_allclose(conv_tvm.numpy(), conv_torch, rtol=1e-5)
@tvm.script.ir_module
class MyAdd:
    # Element-wise 4x4 int64 add, redefined here as the starting point for
    # the schedule transformations applied to it afterwards.
    @T.prim_func
    def add(A: T.Buffer[(4, 4), "int64"],
            B: T.Buffer[(4, 4), "int64"],
            C: T.Buffer[(4, 4), "int64"]):
        T.func_attr({"global_symbol": "add"})
        for i, j in T.grid(4, 4):
            with T.block("C"):
                vi = T.axis.spatial(4, i)
                vj = T.axis.spatial(4, j)
                C[vi, vj] = A[vi, vj] + B[vi, vj]
# Create a schedule for MyAdd and grab block "C" with its two loops.
sch = tvm.tir.Schedule(MyAdd)
block = sch.get_block("C", func_name="add")
i, j = sch.get_loops(block)
# Split loop i into an outer and an inner loop with factors [2, 2].
i0, i1 = sch.split(i, factors=[2, 2])
# Annotate the loops: outer i parallel, inner i unrolled, j vectorized.
sch.parallel(i0)
sch.unroll(i1)
sch.vectorize(j)
# Display the transformed module as a script.
IPython.display.Code(sch.mod.script(), language="python")
@tvm.script.ir_module
class Module:
@tir.prim_func
def add(A: tir.Buffer[(4, 4), "int64"], B: tir.Buffer[(4, 4), "int64"], C: tir.Buffer[(4, 4), "int64"]) -> None:
# function attr dict
tir.func_attr({"global_symbol": "add"})
# body
# with tir.block("root")
for i_0 in tir.parallel(2):
for i_1 in tir.unroll(2):
for j in tir.vectorized(4):
with tir.block("C"):
vi = tir.axis.spatial(4, i_0 * 2 + i_1)
vj = tir.axis.spatial(4, j)
tir.reads(A[vi, vj], B[vi, vj])
tir.writes(C[vi, vj])
C[vi, vj] = A[vi, vj] + B[vi, vj]
def lnumpy_mm_relu_v2(A: np.ndarray, B: np.ndarray, C: np.ndarray):
    """Batched matrix multiply followed by ReLU, written with explicit loops.

    For every batch index ``n`` this computes ``max(A[n] @ B[n], 0)`` and
    stores it in ``C[n]``. Loop bounds are taken from the argument shapes, so
    the 16 x 128 x 128 problem of the exercise is one instance among others.

    Args:
        A: Left operands, shape ``(batch, rows, inner)``.
        B: Right operands, shape ``(batch, inner, cols)``.
        C: Output buffer, shape ``(batch, rows, cols)``; overwritten in
            place. Nothing is returned.
    """
    batch, rows, inner = A.shape
    cols = B.shape[2]
    # Intermediate accumulator, kept in float32 like the original version.
    Y = np.empty((batch, rows, cols), dtype="float32")
    for n in range(batch):
        for i in range(rows):
            for j in range(cols):
                # Zero the accumulator before reducing over k. Doing it here,
                # rather than on the k == 0 iteration, keeps Y defined even
                # when the reduction axis is empty.
                Y[n, i, j] = 0
                for k in range(inner):
                    Y[n, i, j] = Y[n, i, j] + A[n, i, k] * B[n, k, j]
    for n in range(batch):
        for i in range(rows):
            for j in range(cols):
                C[n, i, j] = max(Y[n, i, j], 0)
@tvm.script.ir_module
class MyBmmRelu:
    # TensorIR module for batched matmul + ReLU on 16 x 128 x 128 float32
    # buffers: block "Y" accumulates A @ B per batch, block "C" applies
    # max(., 0).
    @T.prim_func
    def bmm_relu(A: T.Buffer[(16, 128, 128), "float32"],
                 B: T.Buffer[(16, 128, 128), "float32"],
                 C: T.Buffer[(16, 128, 128), "float32"]):
        T.func_attr({"global_symbol": "bmm_relu", "tir.noalias": True})
        # Intermediate buffer holding the matmul result before the ReLU.
        Y = T.alloc_buffer((16, 128, 128), dtype="float32")
        for n, i, j, k in T.grid(16, 128, 128, 128):
            with T.block("Y"):
                # n, i, j are spatial axes; k is the reduction axis.
                vn, vi, vj, vk = T.axis.remap("SSSR", [n, i, j, k])
                with T.init():
                    Y[vn, vi, vj] = T.float32(0)
                Y[vn, vi, vj] = Y[vn, vi, vj] + A[vn, vi, vk] * B[vn, vk, vj]
        for n, i, j in T.grid(16, 128, 128):
            with T.block("C"):
                vn, vi, vj = T.axis.remap("SSS", [n, i, j])
                C[vn, vi, vj] = T.max(Y[vn, vi, vj], T.float32(0))
# Create a schedule for the untransformed module and display its script.
sch = tvm.tir.Schedule(MyBmmRelu)
IPython.display.Code(sch.mod.script(), language="python")
# Also please validate your result
# Inputs: `a` counts up and `b` counts down, both float32 and reshaped to
# (16, 128, 128).
a = np.arange(-8 * 128 * 128, 8 * 128 * 128, dtype=np.float32).reshape(16, 128, 128)
b = np.arange(8 * 128 * 128, -8 * 128 * 128, -1, dtype=np.float32).reshape(16, 128, 128)
# Reference result from the loop-based numpy implementation.
c_lnumpy = np.empty((16, 128, 128), dtype=np.float32)
lnumpy_mm_relu_v2(a, b, c_lnumpy)
# Build MyBmmRelu for the "llvm" target and run it on the same inputs.
rt_lib = tvm.build(MyBmmRelu, target="llvm")
a_tvm = tvm.nd.array(a)
b_tvm = tvm.nd.array(b)
c_tvm = tvm.nd.array(np.empty(c_lnumpy.shape, dtype=np.float32))
rt_lib["bmm_relu"](a_tvm, b_tvm, c_tvm)
# Compare the TVM output with the numpy reference.
np.testing.assert_allclose(c_tvm.numpy(), c_lnumpy, rtol=1e-5)
@tvm.script.ir_module
class TargetModule:
    # Reference form of bmm_relu that the scheduling exercise must reproduce:
    # parallel batch loop, j split into 16 x 8, k split into 32 x 4, the
    # reduction split into init/update blocks, and the ReLU block placed
    # under the outer j loop.
    @T.prim_func
    def bmm_relu(A: T.Buffer[(16, 128, 128), "float32"], B: T.Buffer[(16, 128, 128), "float32"], C: T.Buffer[(16, 128, 128), "float32"]) -> None:
        T.func_attr({"global_symbol": "bmm_relu", "tir.noalias": True})
        Y = T.alloc_buffer([16, 128, 128], dtype="float32")
        for i0 in T.parallel(16):
            for i1, i2_0 in T.grid(128, 16):
                # Zero the 8 accumulators of this tile.
                for ax0_init in T.vectorized(8):
                    with T.block("Y_init"):
                        n, i = T.axis.remap("SS", [i0, i1])
                        j = T.axis.spatial(128, i2_0 * 8 + ax0_init)
                        Y[n, i, j] = T.float32(0)
                # Accumulate over k (32 x 4) with the inner j loop innermost.
                for ax1_0 in T.serial(32):
                    for ax1_1 in T.unroll(4):
                        for ax0 in T.serial(8):
                            with T.block("Y_update"):
                                n, i = T.axis.remap("SS", [i0, i1])
                                j = T.axis.spatial(128, i2_0 * 8 + ax0)
                                k = T.axis.reduce(128, ax1_0 * 4 + ax1_1)
                                Y[n, i, j] = Y[n, i, j] + A[n, i, k] * B[n, k, j]
                # Apply the ReLU to the same 8 elements.
                for i2_1 in T.vectorized(8):
                    with T.block("C"):
                        n, i = T.axis.remap("SS", [i0, i1])
                        j = T.axis.spatial(128, i2_0 * 8 + i2_1)
                        C[n, i, j] = T.max(Y[n, i, j], T.float32(0))
sch = tvm.tir.Schedule(MyBmmRelu)
# TODO: transformations
# Hints: you can use
# `IPython.display.Code(sch.mod.script(), language="python")`
# or `print(sch.mod.script())`
# to show the current program at any time during the transformation.
# Step 1. Get blocks
Y = sch.get_block("Y", func_name="bmm_relu")
C = sch.get_block("C", func_name="bmm_relu")
# Step 2. Get loops
b, i, j, k = sch.get_loops(Y)
# Step 3. parallel
# The batch loop is marked parallel before the reduction is decomposed.
sch.parallel(b)
# Step 4. Organize the loops
# Split k into 32 x 4 and j into 16 x 8, then move block C under the outer
# j loop so that it consumes each tile of Y right after it is produced.
k0, k1 = sch.split(k, [32, 4])
j0, j1 = sch.split(j, [16, 8])
sch.reorder(j0, j1, k0, k1)
sch.reverse_compute_at(C, j0)
# Step 5. decompose reduction
Y_init = sch.decompose_reduction(Y, j1)
Y_update = sch.get_block("Y_update", func_name="bmm_relu")
# The loop structure changed, so fetch the loops of each block again rather
# than reusing the handles obtained before the decomposition.
_, _, _, j1_i = sch.get_loops(Y_init)
_, _, _, j1_u, k0_u, k1_u = sch.get_loops(Y_update)
# Make the inner j loop the innermost loop of the update block.
sch.reorder(k0_u, k1_u, j1_u)
_, _, _, ax0 = sch.get_loops(C)
# Step 6. vectorize / unroll
sch.vectorize(j1_i)
sch.vectorize(ax0)
# Unroll through the refreshed handle of the inner k loop, not the `k1`
# handle that predates decompose_reduction.
sch.unroll(k1_u)
IPython.display.Code(sch.mod.script(), language="python")
@tvm.script.ir_module
class Module:
@tir.prim_func
def bmm_relu(A: tir.Buffer[(16, 128, 128), "float32"], B: tir.Buffer[(16, 128, 128), "float32"], C: tir.Buffer[(16, 128, 128), "float32"]) -> None:
# function attr dict
tir.func_attr({"global_symbol": "bmm_relu", "tir.noalias": True})
# body
# with tir.block("root")
Y = tir.alloc_buffer([16, 128, 128], dtype="float32")
for n in tir.parallel(16):
for i, j_0 in tir.grid(128, 16):
for j_1_init in tir.vectorized(8):
with tir.block("Y_init"):
vn, vi = tir.axis.remap("SS", [n, i])
vj = tir.axis.spatial(128, j_0 * 8 + j_1_init)
tir.reads()
tir.writes(Y[vn, vi, vj])
Y[vn, vi, vj] = tir.float32(0)
for k_0 in tir.serial(32):
for k_1 in tir.unroll(4):
for j_1 in tir.serial(8):
with tir.block("Y_update"):
vn, vi = tir.axis.remap("SS", [n, i])
vj = tir.axis.spatial(128, j_0 * 8 + j_1)
vk = tir.axis.reduce(128, k_0 * 4 + k_1)
tir.reads(Y[vn, vi, vj], A[vn, vi, vk], B[vn, vk, vj])
tir.writes(Y[vn, vi, vj])
Y[vn, vi, vj] = Y[vn, vi, vj] + A[vn, vi, vk] * B[vn, vk, vj]
for ax0 in tir.vectorized(8):
with tir.block("C"):
vn, vi = tir.axis.remap("SS", [n, i])
vj = tir.axis.spatial(128, j_0 * 8 + ax0)
tir.reads(Y[vn, vi, vj])
tir.writes(C[vn, vi, vj])
C[vn, vi, vj] = tir.max(Y[vn, vi, vj], tir.float32(0))
# Check that the scheduled module is structurally equal to the reference
# TargetModule; "Pass" is printed only if the assertion does not fail.
tvm.ir.assert_structural_equal(sch.mod, TargetModule)
print("Pass")
Pass
# Build both the original module and the scheduled one for the "llvm" target.
before_rt_lib = tvm.build(MyBmmRelu, target="llvm")
after_rt_lib = tvm.build(sch.mod, target="llvm")
# Random float32 inputs; c_tvm is the output buffer.
a_tvm = tvm.nd.array(np.random.rand(16, 128, 128).astype("float32"))
b_tvm = tvm.nd.array(np.random.rand(16, 128, 128).astype("float32"))
c_tvm = tvm.nd.array(np.random.rand(16, 128, 128).astype("float32"))
# Run the transformed function once before timing.
after_rt_lib["bmm_relu"](a_tvm, b_tvm, c_tvm)
# Time "bmm_relu" on the CPU device for the original build...
before_timer = before_rt_lib.time_evaluator("bmm_relu", tvm.cpu())
print("Before transformation:")
print(before_timer(a_tvm, b_tvm, c_tvm))
# ...and for the transformed build.
f_timer = after_rt_lib.time_evaluator("bmm_relu", tvm.cpu())
print("After transformation:")
print(f_timer(a_tvm, b_tvm, c_tvm))
Before transformation:
Execution time summary:
mean (ms) median (ms) max (ms) min (ms) std (ms)
58.5732 58.5732 58.5732 58.5732 0.0000
After transformation:
Execution time summary:
mean (ms) median (ms) max (ms) min (ms) std (ms)
14.9765 14.9765 14.9765 14.9765 0.0000
Comments