From a81fcf74850af79544a80811466614729419e958 Mon Sep 17 00:00:00 2001 From: Enrico Lumetti Date: Fri, 13 Aug 2021 11:07:39 +0200 Subject: [PATCH] Simple autodiff test --- Manifest.toml | 114 +++++++++++++++++++++++++++++++++++++---- Project.toml | 2 + scratchpad/autodiff.jl | 30 +++++++++++ 3 files changed, 136 insertions(+), 10 deletions(-) create mode 100644 scratchpad/autodiff.jl diff --git a/Manifest.toml b/Manifest.toml index 4aa8785..4511d5a 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -29,6 +29,12 @@ version = "3.1.23" [[Artifacts]] uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" +[[AutoGrad]] +deps = ["Libdl", "LinearAlgebra", "SpecialFunctions", "Statistics", "TimerOutputs"] +git-tree-sha1 = "16af9a724cf7cdcf8069b2a44300fba4c25f431b" +uuid = "6710c13c-97f1-543f-91c5-74e8f7d95b35" +version = "1.2.4" + [[BFloat16s]] deps = ["LinearAlgebra", "Test"] git-tree-sha1 = "4af69e205efc343068dc8722b8dfec1ade89254a" @@ -80,10 +86,10 @@ uuid = "944b1d66-785c-5afd-91f1-9de20f533193" version = "0.7.0" [[ColorSchemes]] -deps = ["ColorTypes", "Colors", "FixedPointNumbers", "Random", "StaticArrays"] -git-tree-sha1 = "ed268efe58512df8c7e224d2e170afd76dd6a417" +deps = ["ColorTypes", "Colors", "FixedPointNumbers", "Random"] +git-tree-sha1 = "9995eb3977fbf67b86d0a0a0508e83017ded03f2" uuid = "35d6a980-a343-548e-a6ea-1d62b119f2f4" -version = "3.13.0" +version = "3.14.0" [[ColorTypes]] deps = ["FixedPointNumbers", "Random"] @@ -91,6 +97,12 @@ git-tree-sha1 = "024fe24d83e4a5bf5fc80501a314ce0d1aa35597" uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f" version = "0.11.0" +[[ColorVectorSpace]] +deps = ["ColorTypes", "FixedPointNumbers", "LinearAlgebra", "SpecialFunctions", "Statistics", "TensorCore"] +git-tree-sha1 = "42a9b08d3f2f951c9b283ea427d96ed9f1f30343" +uuid = "c3611d14-8923-5661-9e6a-0046d554d3a4" +version = "0.9.5" + [[Colors]] deps = ["ColorTypes", "FixedPointNumbers", "Reexport"] git-tree-sha1 = "417b0ed7b8b838aa6ca0a87aadf1bb9eb111ce40" @@ -126,9 +138,9 @@ version = "1.7.0" [[DataStructures]] deps = ["Compat", "InteractiveUtils", "OrderedCollections"] -git-tree-sha1 = "4437b64df1e0adccc3e5d1adbc3ac741095e4677" +git-tree-sha1 = "7d9d316f04214f7efdbb6398d545446e246eff02" uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" -version = "0.18.9" +version = "0.18.10" [[DataValueInterfaces]] git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6" @@ -198,6 +210,12 @@ git-tree-sha1 = "3cc57ad0a213808473eafef4845a74766242e05f" uuid = "b22a6f82-2f65-5046-a5b2-351ab43fb4e5" version = "4.3.1+4" +[[FileIO]] +deps = ["Pkg", "Requires", "UUIDs"] +git-tree-sha1 = "256d8e6188f3f1ebfa1a5d17e072a0efafa8c5bf" +uuid = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549" +version = "1.10.1" + [[FillArrays]] deps = ["LinearAlgebra", "Random", "SparseArrays", "Statistics"] git-tree-sha1 = "8c8eac2af06ce35973c3eadb4ab3243076a408e7" @@ -300,6 +318,12 @@ git-tree-sha1 = "7bf67e9a481712b3dbe9cb3dac852dc4b1162e02" uuid = "7746bdde-850d-59dc-9ae8-88ece973131d" version = "2.68.3+0" +[[Graphics]] +deps = ["Colors", "LinearAlgebra", "NaNMath"] +git-tree-sha1 = "2c1cf4df419938ece72de17f368a021ee162762e" +uuid = "a2bd30eb-e257-5431-a919-1863eab51364" +version = "1.1.0" + [[Grisu]] git-tree-sha1 = "53bb909d1151e57e2484c3d1b53e19552b887fb2" uuid = "42e2da0e-8278-4e71-bc24-59509adca0fe" @@ -322,6 +346,24 @@ git-tree-sha1 = "28e837ff3e7a6c3cdb252ce49fb412c8eb3caeef" uuid = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173" version = "0.1.0" +[[ImageCore]] +deps = ["AbstractFFTs", "ColorVectorSpace", "Colors", "FixedPointNumbers", "Graphics", "MappedArrays", "MosaicViews", "OffsetArrays", "PaddedViews", "Reexport"] +git-tree-sha1 = "75f7fea2b3601b58f24ee83617b528e57160cbfd" +uuid = "a09fc81d-aa75-5fe9-8630-4744c3626534" +version = "0.9.1" + +[[ImageMagick]] +deps = ["FileIO", "ImageCore", "ImageMagick_jll", "InteractiveUtils", "Libdl", "Pkg", "Random"] +git-tree-sha1 = "5bc1cb62e0c5f1005868358db0692c994c3a13c6" +uuid = "6218d12a-5da1-5696-b52f-db25d2ecc6d1" +version = "1.2.1" + +[[ImageMagick_jll]] +deps = ["JpegTurbo_jll", "Libdl", "Libtiff_jll", "Pkg", "Zlib_jll", "libpng_jll"] +git-tree-sha1 = "1c0a2295cca535fabaf2029062912591e9b61987" +uuid = "c73af94c-d91f-53ed-93a7-00f77d67a9d7" +version = "6.9.10-12+3" + [[IniFile]] deps = ["Test"] git-tree-sha1 = "098e4d2c533924c921f9f9847274f2ad89e018b8" @@ -332,6 +374,11 @@ version = "0.5.0" deps = ["Markdown"] uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" +[[IrrationalConstants]] +git-tree-sha1 = "f76424439413893a832026ca355fe273e93bce94" +uuid = "92d709cd-6900-40b7-9082-c6be49f344b6" +version = "0.1.0" + [[IterTools]] git-tree-sha1 = "05110a2ab1fc5f932622ffea2a003221f4782c18" uuid = "c8e1da08-722c-5040-9ed9-7db0dc04731e" @@ -342,6 +389,12 @@ git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856" uuid = "82899510-4779-5014-852e-03e436cf321d" version = "1.0.0" +[[JLD2]] +deps = ["DataStructures", "FileIO", "MacroTools", "Mmap", "Pkg", "Printf", "Reexport", "TranscodingStreams", "UUIDs"] +git-tree-sha1 = "59ee430ac5dc87bc3eec833cc2a37853425750b4" +uuid = "033835bb-8acc-5ee8-8aae-3f567f8a3819" +version = "0.4.13" + [[JLLWrappers]] deps = ["Preferences"] git-tree-sha1 = "642a199af8b68253517b80bd3bfd17eb4e84df6e" @@ -366,6 +419,12 @@ git-tree-sha1 = "07cb43290a840908a771552911a6274bc6c072c7" uuid = "e5e0dc1b-0480-54bc-9374-aad01c23163d" version = "0.8.4" +[[Knet]] +deps = ["AutoGrad", "CUDA", "FileIO", "ImageCore", "ImageMagick", "JLD2", "Libdl", "LinearAlgebra", "NNlib", "Pkg", "Printf", "Random", "Serialization", "SpecialFunctions", "Statistics"] +git-tree-sha1 = "447d15dcfce6ee38f81b2cbd21289ea409fb59e5" +uuid = "1902f260-5fb4-5aff-8c31-6271790ab950" +version = "1.4.8" + [[LAME_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] git-tree-sha1 = "f6250b16881adf048549549fba48b1161acdac8c" @@ -483,10 +542,10 @@ deps = ["Libdl"] uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" [[LogExpFunctions]] -deps = ["DocStringExtensions", "LinearAlgebra"] -git-tree-sha1 = "7bd5f6565d80b6bf753738d2bc40a5dfea072070" +deps = ["DocStringExtensions", "IrrationalConstants", "LinearAlgebra"] +git-tree-sha1 = "3d682c07e6dd250ed082f883dc88aee7996bf2cc" uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" -version = "0.2.5" +version = "0.3.0" [[Logging]] uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" @@ -497,6 +556,11 @@ git-tree-sha1 = "0fb723cd8c45858c22169b2e42269e53271a6df7" uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" version = "0.5.7" +[[MappedArrays]] +git-tree-sha1 = "e8b359ef06ec72e8c030463fe02efe5527ee5142" +uuid = "dbb5928d-eab1-5f90-85c2-b9b0edb7c900" +version = "0.4.1" + [[Markdown]] deps = ["Base64"] uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" @@ -531,6 +595,12 @@ version = "1.0.0" [[Mmap]] uuid = "a63ad114-7e13-5084-954f-fe012c677804" +[[MosaicViews]] +deps = ["MappedArrays", "OffsetArrays", "PaddedViews", "StackViews"] +git-tree-sha1 = "b34e3bc3ca7c94914418637cb10cc4d1d80d877d" +uuid = "e94cdb99-869f-56ef-bcf0-1ae2bcbe0389" +version = "0.3.3" + [[MozillaCACerts_jll]] uuid = "14a3606d-f60d-562e-9121-12d972cd8159" @@ -554,6 +624,12 @@ version = "0.3.5" [[NetworkOptions]] uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" +[[OffsetArrays]] +deps = ["Adapt"] +git-tree-sha1 = "c0f4a4836e5f3e0763243b8324200af6d0e0f90c" +uuid = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" +version = "1.10.5" + [[Ogg_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] git-tree-sha1 = "7937eda4681660b4d6aeeecc2f7e1c81c8ee4e2f" @@ -589,6 +665,12 @@ git-tree-sha1 = "b2a7af664e098055a7529ad1a900ded962bca488" uuid = "2f80f16e-611a-54ab-bc61-aa92de5b98fc" version = "8.44.0+0" +[[PaddedViews]] +deps = ["OffsetArrays"] +git-tree-sha1 = "646eed6f6a5d8df6708f15ea7e02a7a2c4fe4800" +uuid = "5432bcbf-9aad-5242-b902-cca2824c8663" +version = "0.5.10" + [[Parsers]] deps = ["Dates"] git-tree-sha1 = "477bf42b4d1496b454c10cce46645bb5b8a0cf2c" @@ -722,9 +804,15 @@ uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" [[SpecialFunctions]] deps = ["ChainRulesCore", "LogExpFunctions", "OpenSpecFun_jll"] -git-tree-sha1 = "508822dca004bf62e210609148511ad03ce8f1d8" +git-tree-sha1 = "a322a9493e49c5f3a10b50df3aedaf1cdb3244b7" uuid = "276daf66-3868-5448-9aa4-cd146d93841b" -version = "1.6.0" +version = "1.6.1" + +[[StackViews]] +deps = ["OffsetArrays"] +git-tree-sha1 = "46e589465204cd0c08b4bd97385e4fa79a0c770c" +uuid = "cae243ae-269e-4f55-b966-ac2d0dc13c15" +version = "0.1.1" [[Static]] deps = ["IfElse"] @@ -779,6 +867,12 @@ version = "1.5.0" deps = ["ArgTools", "SHA"] uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" +[[TensorCore]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "1feb45f88d133a655e001435632f019a9a1bcdb6" +uuid = "62fd8b95-f654-4bbd-a8a5-9c27f68ccd50" +version = "0.1.1" + [[Test]] deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/Project.toml b/Project.toml index 1f865b5..cf826cb 100644 --- a/Project.toml +++ b/Project.toml @@ -5,4 +5,6 @@ version = "0.1.0" [deps] Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" +Knet = "1902f260-5fb4-5aff-8c31-6271790ab950" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/scratchpad/autodiff.jl b/scratchpad/autodiff.jl new file mode 100644 index 0000000..eb7af86 --- /dev/null +++ b/scratchpad/autodiff.jl @@ -0,0 +1,30 @@ +using Zygote +using LinearAlgebra: norm + +g(x::Vector{<:Real}) = 2x'x + +v = [1, 3, 6] +@assert g'(v) == 4*v + +function f(x) + if x > 0 + return 30x + else + return 50x + end +end + +@assert f'(3) == 30 +@assert f'(-1) == 50 + +function affine(A, b) + return (x) -> A*x + b +end + +A = [1 2; 3 4] +b = [30, 30] + +h = affine(A, b) + +@assert jacobian(h, [1,2]) == (A,) +