Simple autodiff test

This commit is contained in:
Enrico Lumetti 2021-08-13 11:07:39 +02:00
parent 417e34ea37
commit a81fcf7485
3 changed files with 136 additions and 10 deletions

View File

@ -29,6 +29,12 @@ version = "3.1.23"
[[Artifacts]] [[Artifacts]]
uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
[[AutoGrad]]
deps = ["Libdl", "LinearAlgebra", "SpecialFunctions", "Statistics", "TimerOutputs"]
git-tree-sha1 = "16af9a724cf7cdcf8069b2a44300fba4c25f431b"
uuid = "6710c13c-97f1-543f-91c5-74e8f7d95b35"
version = "1.2.4"
[[BFloat16s]] [[BFloat16s]]
deps = ["LinearAlgebra", "Test"] deps = ["LinearAlgebra", "Test"]
git-tree-sha1 = "4af69e205efc343068dc8722b8dfec1ade89254a" git-tree-sha1 = "4af69e205efc343068dc8722b8dfec1ade89254a"
@ -80,10 +86,10 @@ uuid = "944b1d66-785c-5afd-91f1-9de20f533193"
version = "0.7.0" version = "0.7.0"
[[ColorSchemes]] [[ColorSchemes]]
deps = ["ColorTypes", "Colors", "FixedPointNumbers", "Random", "StaticArrays"] deps = ["ColorTypes", "Colors", "FixedPointNumbers", "Random"]
git-tree-sha1 = "ed268efe58512df8c7e224d2e170afd76dd6a417" git-tree-sha1 = "9995eb3977fbf67b86d0a0a0508e83017ded03f2"
uuid = "35d6a980-a343-548e-a6ea-1d62b119f2f4" uuid = "35d6a980-a343-548e-a6ea-1d62b119f2f4"
version = "3.13.0" version = "3.14.0"
[[ColorTypes]] [[ColorTypes]]
deps = ["FixedPointNumbers", "Random"] deps = ["FixedPointNumbers", "Random"]
@ -91,6 +97,12 @@ git-tree-sha1 = "024fe24d83e4a5bf5fc80501a314ce0d1aa35597"
uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f" uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f"
version = "0.11.0" version = "0.11.0"
[[ColorVectorSpace]]
deps = ["ColorTypes", "FixedPointNumbers", "LinearAlgebra", "SpecialFunctions", "Statistics", "TensorCore"]
git-tree-sha1 = "42a9b08d3f2f951c9b283ea427d96ed9f1f30343"
uuid = "c3611d14-8923-5661-9e6a-0046d554d3a4"
version = "0.9.5"
[[Colors]] [[Colors]]
deps = ["ColorTypes", "FixedPointNumbers", "Reexport"] deps = ["ColorTypes", "FixedPointNumbers", "Reexport"]
git-tree-sha1 = "417b0ed7b8b838aa6ca0a87aadf1bb9eb111ce40" git-tree-sha1 = "417b0ed7b8b838aa6ca0a87aadf1bb9eb111ce40"
@ -126,9 +138,9 @@ version = "1.7.0"
[[DataStructures]] [[DataStructures]]
deps = ["Compat", "InteractiveUtils", "OrderedCollections"] deps = ["Compat", "InteractiveUtils", "OrderedCollections"]
git-tree-sha1 = "4437b64df1e0adccc3e5d1adbc3ac741095e4677" git-tree-sha1 = "7d9d316f04214f7efdbb6398d545446e246eff02"
uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
version = "0.18.9" version = "0.18.10"
[[DataValueInterfaces]] [[DataValueInterfaces]]
git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6" git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6"
@ -198,6 +210,12 @@ git-tree-sha1 = "3cc57ad0a213808473eafef4845a74766242e05f"
uuid = "b22a6f82-2f65-5046-a5b2-351ab43fb4e5" uuid = "b22a6f82-2f65-5046-a5b2-351ab43fb4e5"
version = "4.3.1+4" version = "4.3.1+4"
[[FileIO]]
deps = ["Pkg", "Requires", "UUIDs"]
git-tree-sha1 = "256d8e6188f3f1ebfa1a5d17e072a0efafa8c5bf"
uuid = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549"
version = "1.10.1"
[[FillArrays]] [[FillArrays]]
deps = ["LinearAlgebra", "Random", "SparseArrays", "Statistics"] deps = ["LinearAlgebra", "Random", "SparseArrays", "Statistics"]
git-tree-sha1 = "8c8eac2af06ce35973c3eadb4ab3243076a408e7" git-tree-sha1 = "8c8eac2af06ce35973c3eadb4ab3243076a408e7"
@ -300,6 +318,12 @@ git-tree-sha1 = "7bf67e9a481712b3dbe9cb3dac852dc4b1162e02"
uuid = "7746bdde-850d-59dc-9ae8-88ece973131d" uuid = "7746bdde-850d-59dc-9ae8-88ece973131d"
version = "2.68.3+0" version = "2.68.3+0"
[[Graphics]]
deps = ["Colors", "LinearAlgebra", "NaNMath"]
git-tree-sha1 = "2c1cf4df419938ece72de17f368a021ee162762e"
uuid = "a2bd30eb-e257-5431-a919-1863eab51364"
version = "1.1.0"
[[Grisu]] [[Grisu]]
git-tree-sha1 = "53bb909d1151e57e2484c3d1b53e19552b887fb2" git-tree-sha1 = "53bb909d1151e57e2484c3d1b53e19552b887fb2"
uuid = "42e2da0e-8278-4e71-bc24-59509adca0fe" uuid = "42e2da0e-8278-4e71-bc24-59509adca0fe"
@ -322,6 +346,24 @@ git-tree-sha1 = "28e837ff3e7a6c3cdb252ce49fb412c8eb3caeef"
uuid = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173" uuid = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"
version = "0.1.0" version = "0.1.0"
[[ImageCore]]
deps = ["AbstractFFTs", "ColorVectorSpace", "Colors", "FixedPointNumbers", "Graphics", "MappedArrays", "MosaicViews", "OffsetArrays", "PaddedViews", "Reexport"]
git-tree-sha1 = "75f7fea2b3601b58f24ee83617b528e57160cbfd"
uuid = "a09fc81d-aa75-5fe9-8630-4744c3626534"
version = "0.9.1"
[[ImageMagick]]
deps = ["FileIO", "ImageCore", "ImageMagick_jll", "InteractiveUtils", "Libdl", "Pkg", "Random"]
git-tree-sha1 = "5bc1cb62e0c5f1005868358db0692c994c3a13c6"
uuid = "6218d12a-5da1-5696-b52f-db25d2ecc6d1"
version = "1.2.1"
[[ImageMagick_jll]]
deps = ["JpegTurbo_jll", "Libdl", "Libtiff_jll", "Pkg", "Zlib_jll", "libpng_jll"]
git-tree-sha1 = "1c0a2295cca535fabaf2029062912591e9b61987"
uuid = "c73af94c-d91f-53ed-93a7-00f77d67a9d7"
version = "6.9.10-12+3"
[[IniFile]] [[IniFile]]
deps = ["Test"] deps = ["Test"]
git-tree-sha1 = "098e4d2c533924c921f9f9847274f2ad89e018b8" git-tree-sha1 = "098e4d2c533924c921f9f9847274f2ad89e018b8"
@ -332,6 +374,11 @@ version = "0.5.0"
deps = ["Markdown"] deps = ["Markdown"]
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
[[IrrationalConstants]]
git-tree-sha1 = "f76424439413893a832026ca355fe273e93bce94"
uuid = "92d709cd-6900-40b7-9082-c6be49f344b6"
version = "0.1.0"
[[IterTools]] [[IterTools]]
git-tree-sha1 = "05110a2ab1fc5f932622ffea2a003221f4782c18" git-tree-sha1 = "05110a2ab1fc5f932622ffea2a003221f4782c18"
uuid = "c8e1da08-722c-5040-9ed9-7db0dc04731e" uuid = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
@ -342,6 +389,12 @@ git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856"
uuid = "82899510-4779-5014-852e-03e436cf321d" uuid = "82899510-4779-5014-852e-03e436cf321d"
version = "1.0.0" version = "1.0.0"
[[JLD2]]
deps = ["DataStructures", "FileIO", "MacroTools", "Mmap", "Pkg", "Printf", "Reexport", "TranscodingStreams", "UUIDs"]
git-tree-sha1 = "59ee430ac5dc87bc3eec833cc2a37853425750b4"
uuid = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
version = "0.4.13"
[[JLLWrappers]] [[JLLWrappers]]
deps = ["Preferences"] deps = ["Preferences"]
git-tree-sha1 = "642a199af8b68253517b80bd3bfd17eb4e84df6e" git-tree-sha1 = "642a199af8b68253517b80bd3bfd17eb4e84df6e"
@ -366,6 +419,12 @@ git-tree-sha1 = "07cb43290a840908a771552911a6274bc6c072c7"
uuid = "e5e0dc1b-0480-54bc-9374-aad01c23163d" uuid = "e5e0dc1b-0480-54bc-9374-aad01c23163d"
version = "0.8.4" version = "0.8.4"
[[Knet]]
deps = ["AutoGrad", "CUDA", "FileIO", "ImageCore", "ImageMagick", "JLD2", "Libdl", "LinearAlgebra", "NNlib", "Pkg", "Printf", "Random", "Serialization", "SpecialFunctions", "Statistics"]
git-tree-sha1 = "447d15dcfce6ee38f81b2cbd21289ea409fb59e5"
uuid = "1902f260-5fb4-5aff-8c31-6271790ab950"
version = "1.4.8"
[[LAME_jll]] [[LAME_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
git-tree-sha1 = "f6250b16881adf048549549fba48b1161acdac8c" git-tree-sha1 = "f6250b16881adf048549549fba48b1161acdac8c"
@ -483,10 +542,10 @@ deps = ["Libdl"]
uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
[[LogExpFunctions]] [[LogExpFunctions]]
deps = ["DocStringExtensions", "LinearAlgebra"] deps = ["DocStringExtensions", "IrrationalConstants", "LinearAlgebra"]
git-tree-sha1 = "7bd5f6565d80b6bf753738d2bc40a5dfea072070" git-tree-sha1 = "3d682c07e6dd250ed082f883dc88aee7996bf2cc"
uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
version = "0.2.5" version = "0.3.0"
[[Logging]] [[Logging]]
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
@ -497,6 +556,11 @@ git-tree-sha1 = "0fb723cd8c45858c22169b2e42269e53271a6df7"
uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
version = "0.5.7" version = "0.5.7"
[[MappedArrays]]
git-tree-sha1 = "e8b359ef06ec72e8c030463fe02efe5527ee5142"
uuid = "dbb5928d-eab1-5f90-85c2-b9b0edb7c900"
version = "0.4.1"
[[Markdown]] [[Markdown]]
deps = ["Base64"] deps = ["Base64"]
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
@ -531,6 +595,12 @@ version = "1.0.0"
[[Mmap]] [[Mmap]]
uuid = "a63ad114-7e13-5084-954f-fe012c677804" uuid = "a63ad114-7e13-5084-954f-fe012c677804"
[[MosaicViews]]
deps = ["MappedArrays", "OffsetArrays", "PaddedViews", "StackViews"]
git-tree-sha1 = "b34e3bc3ca7c94914418637cb10cc4d1d80d877d"
uuid = "e94cdb99-869f-56ef-bcf0-1ae2bcbe0389"
version = "0.3.3"
[[MozillaCACerts_jll]] [[MozillaCACerts_jll]]
uuid = "14a3606d-f60d-562e-9121-12d972cd8159" uuid = "14a3606d-f60d-562e-9121-12d972cd8159"
@ -554,6 +624,12 @@ version = "0.3.5"
[[NetworkOptions]] [[NetworkOptions]]
uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908"
[[OffsetArrays]]
deps = ["Adapt"]
git-tree-sha1 = "c0f4a4836e5f3e0763243b8324200af6d0e0f90c"
uuid = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
version = "1.10.5"
[[Ogg_jll]] [[Ogg_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
git-tree-sha1 = "7937eda4681660b4d6aeeecc2f7e1c81c8ee4e2f" git-tree-sha1 = "7937eda4681660b4d6aeeecc2f7e1c81c8ee4e2f"
@ -589,6 +665,12 @@ git-tree-sha1 = "b2a7af664e098055a7529ad1a900ded962bca488"
uuid = "2f80f16e-611a-54ab-bc61-aa92de5b98fc" uuid = "2f80f16e-611a-54ab-bc61-aa92de5b98fc"
version = "8.44.0+0" version = "8.44.0+0"
[[PaddedViews]]
deps = ["OffsetArrays"]
git-tree-sha1 = "646eed6f6a5d8df6708f15ea7e02a7a2c4fe4800"
uuid = "5432bcbf-9aad-5242-b902-cca2824c8663"
version = "0.5.10"
[[Parsers]] [[Parsers]]
deps = ["Dates"] deps = ["Dates"]
git-tree-sha1 = "477bf42b4d1496b454c10cce46645bb5b8a0cf2c" git-tree-sha1 = "477bf42b4d1496b454c10cce46645bb5b8a0cf2c"
@ -722,9 +804,15 @@ uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
[[SpecialFunctions]] [[SpecialFunctions]]
deps = ["ChainRulesCore", "LogExpFunctions", "OpenSpecFun_jll"] deps = ["ChainRulesCore", "LogExpFunctions", "OpenSpecFun_jll"]
git-tree-sha1 = "508822dca004bf62e210609148511ad03ce8f1d8" git-tree-sha1 = "a322a9493e49c5f3a10b50df3aedaf1cdb3244b7"
uuid = "276daf66-3868-5448-9aa4-cd146d93841b" uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
version = "1.6.0" version = "1.6.1"
[[StackViews]]
deps = ["OffsetArrays"]
git-tree-sha1 = "46e589465204cd0c08b4bd97385e4fa79a0c770c"
uuid = "cae243ae-269e-4f55-b966-ac2d0dc13c15"
version = "0.1.1"
[[Static]] [[Static]]
deps = ["IfElse"] deps = ["IfElse"]
@ -779,6 +867,12 @@ version = "1.5.0"
deps = ["ArgTools", "SHA"] deps = ["ArgTools", "SHA"]
uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e"
[[TensorCore]]
deps = ["LinearAlgebra"]
git-tree-sha1 = "1feb45f88d133a655e001435632f019a9a1bcdb6"
uuid = "62fd8b95-f654-4bbd-a8a5-9c27f68ccd50"
version = "0.1.1"
[[Test]] [[Test]]
deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] deps = ["InteractiveUtils", "Logging", "Random", "Serialization"]
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

View File

@ -5,4 +5,6 @@ version = "0.1.0"
[deps] [deps]
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Knet = "1902f260-5fb4-5aff-8c31-6271790ab950"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

30
scratchpad/autodiff.jl Normal file
View File

@ -0,0 +1,30 @@
using Zygote
using LinearAlgebra: norm
g(x::Vector{<:Real}) = 2x'x
v = [1, 3, 6]
@assert g'(v) == 4*v
function f(x)
if x > 0
return 30x
else
return 50x
end
end
@assert f'(3) == 30
@assert f'(-1) == 50
function affine(A, b)
return (x) -> A*x + b
end
A = [1 2; 3 4]
b = [30, 30]
h = affine(A, b)
@assert jacobian(h, [1,2]) == (A,)