-
Notifications
You must be signed in to change notification settings - Fork 46
/
Copy pathshonan_live.scala
111 lines (84 loc) · 2.2 KB
/
shonan_live.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
package scala.lms.tutorial
import scala.lms.common._
import scala.reflect.SourceContext
// sbt ~testOnly *ShonanLive*
/** Live-demo playground for the Shonan HMM matrix-vector-product example.
  *
  * The single test computes `a * v` with plain (unstaged) Scala and compares
  * the result through `check`. The commented-out snippets inside this class
  * are the starting point for staging the same computation with LMS.
  */
class ShonanLiveTest extends TutorialFunSuite {
  // Prefix for this suite's check files (presumably combined with the test
  // name by TutorialFunSuite — confirm there).
  val under = "0-live-"

  /* playground for live demo */
  test("shonan-hmm-live") {
    val A = scala.Array
    // Rows deliberately mix dense, all-zero and sparse shapes: the staged
    // version specializes its inner loop per row based on this static data.
    val a =
      A(A(1, 1, 1, 1, 1), // dense
        A(0, 0, 0, 0, 0), // null
        A(0, 0, 1, 0, 0), // sparse
        A(0, 0, 0, 0, 0),
        A(0, 0, 1, 0, 1))

    /** Unstaged matrix-vector product: element i of the result is the dot
      * product of row `a(i)` with `v`.
      *
      * The inner loop runs over the length of row i (not over the number of
      * rows), so rectangular matrices are handled as long as `v` has at least
      * as many elements as each row.
      *
      * The `: Range` ascription and the explicit `.apply(j)` mirror the staged
      * version of this function. They are presumably there so the body can later
      * be staged without rewriting, when `a(i)` becomes a staged array whose
      * `apply` likely takes an extra implicit argument — confirm against LMS.
      *
      * @param a matrix given as an array of rows
      * @param v input vector
      * @return a fresh array with one entry per row of `a`
      */
    def matrix_vector_prod(a: Array[Array[Int]], v: Array[Int]) = {
      val n = a.length
      val v1 = new Array[Int](n)
      for (i <- (0 until n): Range) {
        // Bound by this row's own length; using `n` assumed a square matrix.
        for (j <- (0 until a(i).length)) {
          v1(i) = v1(i) + a(i).apply(j) * v(j)
        }
      }
      v1
    }

    // let's run it on some static input:
    val v = A(1,1,1,1,1)
    val v1 = matrix_vector_prod(a, v)
    val result = v1.mkString(",")
    check("shonan-hmm-live", result)
  }

  /*
  val snippet = new LMS_Driver[Array[Int],Array[Int]] {
    def snippet(v: Rep[Array[Int]]) = {
      println("hello")
      v
    }
  }
  exec("shonan-hmm-live", snippet.code)
  */

  /*
  DEMO:
  1) add compile snippet
  2) add conditional
  3) stage mv prod
  - staticData(a)
  - NewArray[Int](n)
  4) Range vs Rep[Range]
  5) unrollIf
  */

  /** Shorthand base class for the demo snippets. It simply extends
    * `DslDriver[A,B]`, with the `Manifest` context bounds supplying the type
    * evidence for `A` and `B` (presumably required by DslDriver).
    */
  abstract class LMS_Driver[A:Manifest,B:Manifest] extends DslDriver[A,B]
}
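/*
Reference end state of the live demo: matrix_vector_prod staged with
staticData, NewArray, and unrollIf (Range vs Rep[Range]).
It uses the test-local `a` and the class member ShonanLiveTest.LMS_Driver.
Uncommenting it here at file level would therefore not compile. To try it, paste it inside
test("shonan-hmm-live") after `a` is defined, then run it with
exec("shonan-hmm-live", snippet.code).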
val snippet = new LMS_Driver[Array[Int],Array[Int]] {
def unrollIf(c: Boolean)(r: Range) = new {
def foreach(f: Rep[Int] => Rep[Unit]): Rep[Unit] = {
if (c) {
for (j <- r) {
f(j)
}
} else {
for (j <- (r.start until r.end):Rep[Range]) {
f(j)
}
}
}
}
def matrix_vector_prod(a0: Array[Array[Int]], v: Rep[Array[Int]]) = {
val n = a0.length
val v1 = NewArray[Int](n)
val a = staticData(a0)
for (i <- (0 until n):Range) {
val sparse = a0(i).count(_ != 0) < 3
for (j <- unrollIf(sparse)(0 until n)) {
v1(i) = v1(i) + a(i).apply(j) * v(j)
}
}
v1
}
def snippet(v: Rep[Array[Int]]) = {
matrix_vector_prod(a,v)
}
}
*/