Saturday, May 29, 2010

Perceptron in Clojure

The algorithm, a single-layer perceptron with a step activation, trained with the perceptron learning rule:


;; Step (threshold) activation: fires (1) only for a strictly positive sum.
(defn g [sum]
  (if (> sum 0) 1 0))

;; weights holds one weight vector per output unit; returns one output per unit.
(defn perceptron-predict
  [weights input g]
  (map #(g (reduce + 0 (map * % input))) weights))

;; One online update for a single example: w_ik += eta * (t_k - y_k) * x_i.
(defn perceptron-training-step
  [weights input target eta g]
  (let [ys (perceptron-predict weights input g)]
    (map (fn [w_i_k y_k t_k]
           (map #(+ %1 (* eta (- t_k y_k) %2)) w_i_k input))
         weights ys target)))

;; Makes passes over all examples until a pass leaves the weights unchanged
;; or max-iteration passes have been made, printing predictions as it goes.
(defn perceptron-train
  [weights inputs targets eta g max-iteration]
  (prn (map #(perceptron-predict weights % g) inputs))
  (let [new-weights
        (reduce (fn [w [input target]] (perceptron-training-step w input target eta g))
                weights
                (map list inputs targets))]
    (println (str "ITERATION: " (vec new-weights) " / (max iterations: " max-iteration ")"))
    (if (or (< max-iteration 2) (= weights new-weights))
      (do
        (prn (map #(perceptron-predict new-weights % g) inputs))
        new-weights)
      (recur new-weights inputs targets eta g (dec max-iteration)))))
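
In equation form, the code above is the classic perceptron learning rule. Here x_i is the i-th input component (x_0 = -1 being the bias), w_{ik} is the weight from input i to output unit k (the k-th inner vector of weights), t_k is the target, y_k is the prediction and η is eta:

$$
y_k = g\Big(\sum_i w_{ik}\,x_i\Big), \qquad
g(s) = \begin{cases} 1 & s > 0 \\ 0 & s \le 0 \end{cases}
$$

$$
w_{ik} \leftarrow w_{ik} + \eta\,(t_k - y_k)\,x_i
$$

perceptron-training-step applies this to one example; perceptron-train folds it over the examples in order (so each example sees the weights already updated by the previous one) and repeats the pass. A weight moves only when t_k ≠ y_k, and then by ±η x_i, which is why every weight in the traces below changes in steps of 0.25.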


Let's see if it can learn a simple OR function (the leading -1 in each input is the bias input):


user=> (perceptron-train [[3 2 1]] [[-1 0 0] [-1 0 1] [-1 1 0] [-1 1 1]] [[0] [1] [1] [1]] 0.25 g 10)
((0) (0) (0) (0))
ITERATION: [(2.5 2.25 1.25)] / (max iterations: 10)
((0) (0) (0) (1))
ITERATION: [(2.0 2.5 1.5)] / (max iterations: 9)
((0) (0) (1) (1))
ITERATION: [(1.75 2.5 1.75)] / (max iterations: 8)
((0) (0) (1) (1))
ITERATION: [(1.5 2.5 2.0)] / (max iterations: 7)
((0) (1) (1) (1))
ITERATION: [(1.5 2.5 2.0)] / (max iterations: 6)
((0) (1) (1) (1))
((1.5 2.5 2.0))
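
Because the first input component is always -1, the first weight acts as a threshold: with input (-1, x_1, x_2) and weights (w_0, w_1, w_2), the unit outputs 1 exactly when w_1 x_1 + w_2 x_2 > w_0. Plugging in the converged weights (1.5 2.5 2.0) gives the decision rule 2.5 x_1 + 2.0 x_2 > 1.5, which reproduces OR on all four cases:

$$
\begin{aligned}
(x_1, x_2) = (0,0):&\quad 0 > 1.5 \ \text{false} &&\Rightarrow y = 0 \\
(0,1):&\quad 2.0 > 1.5 &&\Rightarrow y = 1 \\
(1,0):&\quad 2.5 > 1.5 &&\Rightarrow y = 1 \\
(1,1):&\quad 4.5 > 1.5 &&\Rightarrow y = 1
\end{aligned}
$$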


What about XOR? It's not linearly separable, so a single-layer perceptron cannot learn it and training should not converge.
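
To see why, take input (-1, x_1, x_2), weights (w_0, w_1, w_2) and g firing only for a strictly positive sum. The four XOR cases would require:

$$
\begin{aligned}
(0,0) \mapsto 0:&\quad -w_0 \le 0 &&\Rightarrow w_0 \ge 0 \\
(0,1) \mapsto 1:&\quad w_2 - w_0 > 0 \\
(1,0) \mapsto 1:&\quad w_1 - w_0 > 0 \\
(1,1) \mapsto 0:&\quad w_1 + w_2 - w_0 \le 0
\end{aligned}
$$

Adding the two middle inequalities gives w_1 + w_2 > 2w_0 ≥ w_0 (because w_0 ≥ 0), which contradicts the last one, so no weight vector classifies all four cases.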


user=> (perceptron-train [[3 2 1]] [[-1 0 0] [-1 0 1] [-1 1 0] [-1 1 1]] [[0] [1] [1] [0]] 0.25 g 10)
((0) (0) (0) (0))
ITERATION: [(2.75 2.0 1.0)] / (max iterations: 10)
((0) (0) (0) (1))
ITERATION: [(2.5 2.0 1.0)] / (max iterations: 9)
((0) (0) (0) (1))
ITERATION: [(2.25 2.0 1.0)] / (max iterations: 8)
((0) (0) (0) (1))
ITERATION: [(2.0 2.0 1.0)] / (max iterations: 7)
((0) (0) (0) (1))
ITERATION: [(2.0 1.75 1.0)] / (max iterations: 6)
((0) (0) (0) (1))
ITERATION: [(1.75 1.75 1.0)] / (max iterations: 5)
((0) (0) (0) (1))
ITERATION: [(1.75 1.5 1.0)] / (max iterations: 4)
((0) (0) (0) (1))
ITERATION: [(1.5 1.5 1.0)] / (max iterations: 3)
((0) (0) (0) (1))
ITERATION: [(1.5 1.25 1.0)] / (max iterations: 2)
((0) (0) (0) (1))
ITERATION: [(1.25 1.25 1.0)] / (max iterations: 1)
((0) (0) (0) (1))
((1.25 1.25 1.0))
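
This run stops only because max-iteration runs out, and the final weights (1.25 1.25 1.0) actually compute AND, getting three of the four XOR cases wrong. As a rough empirical check (not a proof), here is a minimal self-contained sketch that brute-forces every weight vector on a small grid, using the same -1 bias input and the same step activation as g, and counts how many classify OR and XOR perfectly. The grid (-2.0 to 2.0 in steps of 0.25) and the names step, classifies-all?, grid and solutions are hypothetical choices made for illustration only.

```clojure
;; Brute-force search for single-unit weights [w0 w1 w2] that classify
;; all four 2-input cases. The grid bounds and step are arbitrary.

;; Step activation, same behaviour as g in the post: 1 if sum > 0, else 0.
(defn step [s] (if (> s 0) 1 0))

;; The four inputs, with the constant -1 bias as the first component.
(def inputs [[-1 0 0] [-1 0 1] [-1 1 0] [-1 1 1]])

(defn classifies-all?
  "True if weight vector w yields the target for every input."
  [w targets]
  (every? true?
          (map (fn [x t] (= t (step (reduce + (map * w x)))))
               inputs targets)))

;; Candidate weight values -2.0, -1.75, ..., 2.0 (0.25 is exact in
;; binary floating point, so the values are exact).
(def grid (range -2.0 2.25 0.25))

(defn solutions
  "All grid weight vectors that classify every input correctly."
  [targets]
  (for [w0 grid, w1 grid, w2 grid
        :let [w [w0 w1 w2]]
        :when (classifies-all? w targets)]
    w))

(println "OR solutions:" (count (solutions [0 1 1 1])))
(println "XOR solutions:" (count (solutions [0 1 1 0])))
```

The OR count is positive (for example [0.5 1.0 1.0] works), while the XOR count is 0, as expected for a function that is not linearly separable.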
