본문 바로가기
python 및 머신러닝/집단지성 프로그래밍

[Programming Collective Intelligence] - 집단지성 프로그래밍 4장 7절 클릭 학습

by java개발자 2015. 9. 2.
# -*- coding: utf-8 -*-
'''
Created on 2015. 9. 2.

@author: Administrator
'''
'''
  4-07절 클릭 학습
  > 인공 신경망 이용(ANN)
  
  이전 실습과 차이점 : 현재 작업을 하기 위해 우선되어야 하는 작업을 명시하였다.. practice0()....
  
  인공신경망 - MLP(다중층 인식망)네트워크
    > 이해가 안된다-_-;;
    
  전방전파
    입력층(단어층) -> 은닉층 -> 출력층 의 순으로 가중치가 전파된다.(강도strength를 이용해서)
  역전파
    반대로...
    
    
  질문
    2개 이상의 trainquery을 1회 이상 실행했을 때, wordhidden 테이블에 의도치 않은 데이터가 생긴다. 
    getallhiddenids함수의 hiddenurl 테이블에서 특정 url과 관련된 모든 hiddenid를 가져올때, 의도치 않게 가져오게 된다.다른 hiddenid를 .... 상관없나? 그래서 118p의 여러 유연한 검색이 가능하다.
    
    tanh, dtanh 를 왜 쓰는지 모르겠다.
      전방전파, 역전파학습은....좋으나... 왜 그렇게 계산하는지는 모르겠다...
    
    
  전방전파 : uv(url value) 데이터를 찾는게 목적이다.
  역전파학습 : ws, us(word score, url score) 를 찾는게 목적이다.
    클래스 내부 변수를 선언 및 이용하기 위해서, 역전파학습 알고리즘 내부에 전방전파 알고리즘도 포함되어 있다.
    
  mynet.trainquery 할 시 전방전파+역전파학습으로 DB에 strength 값을 갱신하고,
  mynet.getresult 로 DB의 strength 값을 이용해서 결과를 도출한다.
    
    
'''
# 역전파 관련...

from math import tanh

from myutil import consolePrintWithLineNumber as c

import sqlite3 as sqlite


def dtanh(y):
  return 1.0 - y * y
    
class searchnet:
    # 데이터베이스 설정 109p
    def __init__(self, dbname):
      self.con = sqlite.connect(dbname)
  
    def __del__(self):
      self.con.close()

    def maketables(self):
      self.con.execute('drop table if exists hiddennode') 
      self.con.execute('drop table if exists wordhidden') 
      self.con.execute('drop table if exists hiddenurl')
       
      self.con.execute('create table hiddennode(create_key)')
      self.con.execute('create table wordhidden(fromid,toid,strength)')
      self.con.execute('create table hiddenurl(fromid,toid,strength)')
      self.con.commit()
      
    def getstrength(self, fromid, toid, layer):
      if layer == 0: table = 'wordhidden'
      else: table = 'hiddenurl'
      
      res = self.con.execute('select strength from %s where fromid=%d and toid=%d' % (table, fromid, toid)).fetchone()
      if res == None: 
        if layer == 0: return -0.2    # 단어층-은닉층
        if layer == 1: return 0    # 은닉층-출력층
      return res[0]
    
    def setstrength(self, fromid, toid, layer, strength):
      if layer == 0: table = 'wordhidden'
      else: table = 'hiddenurl'
      
      res = self.con.execute('select rowid from %s where fromid=%d and toid=%d' % (table, fromid, toid)).fetchone()
      if res == None: 
        sql = 'insert into %s (fromid,toid,strength) values (%d,%d,%f)' % (table, fromid, toid, strength)
        self.con.execute(sql)
      else:
        rowid = res[0]
        self.con.execute('update %s set strength=%f where rowid=%d' % (table, strength, rowid))
        
    def generatehiddennode(self, wordids, urls):
      if len(wordids) > 3: return None

      sorted_words = [str(id) for id in wordids]
      sorted_words.sort()
      createkey = '_'.join(sorted_words)
      res = self.con.execute("select rowid from hiddennode where create_key='%s'" % createkey).fetchone()

      if res == None:
        cur = self.con.execute("insert into hiddennode (create_key) values ('%s')" % createkey)
        hiddenid = cur.lastrowid

        for wordid in wordids:
          self.setstrength(wordid, hiddenid, layer=0, strength=1.0 / len(wordids))
        for urlid in urls:
          self.setstrength(hiddenid, urlid, layer=1, strength=0.1)
        self.con.commit()

    # 전방전파 112p
    def getallhiddenids(self, wordids, urlids):
      dic = {}
      for wordid in wordids:
        cur = self.con.execute('select toid from wordhidden where fromid=%d' % wordid)    # 단어층에서 은닉번호를 찾고
        for row in cur: 
          dic[row[0]] = 1
      for urlid in urlids:
        cur = self.con.execute('select fromid from hiddenurl where toid=%d' % urlid)    # 출력층에서 은닉번호를 찾는다.
        for row in cur: 
          dic[row[0]] = 1
      return dic.keys()
    
    # 내부 데이터가 어떻게 변하는지 확인하기 위해
    def checkvalue(self):
      c(self.wordids, self.hiddenids, self.urlids)
      c(self.wv, self.hv, self.uv)
      c(self.ws, self.us)
    
    def setupnetwork(self, wordids, urlids):
      self.wordids = wordids
      self.hiddenids = self.getallhiddenids(wordids, urlids)    # 은닉층 id
      self.urlids = urlids
      
      # 노드들의 기본 출력값
      # word value, hidden value, url value
      self.wv = [1.0] * len(self.wordids)    # 단어층 [1.0, 1.0]
      self.hv = [1.0] * len(self.hiddenids)    # 은닉층 [1.0]
      self.uv = [1.0] * len(self.urlids)    # 출력층 [1.0, 1.0, 1.0]
      
      # word strength, url strength (DB데이터)
      self.ws = [[self.getstrength(wordid, hiddenid, 0) for hiddenid in self.hiddenids] for wordid in self.wordids]    # 단어층 강도strength 행렬
      self.us = [[self.getstrength(hiddenid, urlid, 1) for urlid in self.urlids] for hiddenid in self.hiddenids]    # 출력층 강도 행렬
      
#       self.checkvalue()
      '''
      ws:    
        [[0.5],
         [0.5]]
      us:
        [[0.1, 0.1, 0.1]]
      '''
        
    def feedforward(self):
      for i in range(len(self.wordids)):    # 단어층
        self.wv[i] = 1.0
      
      # 은닉 노드 활성화
      for j in range(len(self.hiddenids)):    # 은닉층
        sum = 0.0
        for i in range(len(self.wordids)):
          sum = sum + self.wv[i] * self.ws[i][j]    # 단어의 출력값 * 단어의 strength
        self.hv[j] = tanh(sum)
      
      # 출력 노드 활성화
      for k in range(len(self.urlids)):    # 출력층
        sum = 0.0
        for j in range(len(self.hiddenids)):
          sum = sum + self.hv[j] * self.us[j][k]
        self.uv[k] = tanh(sum)
      
      return self.uv[:]    # self.uv 랑 같은 표현        
        
    def getresult(self, wordids, urlids):
      self.setupnetwork(wordids, urlids)
      return self.feedforward()        
            
    # 역전파 학습 114p
    def backPropagate(self, clickurls, N=0.5):
      # 출력 오류 계산
      output_deltas = [0.0] * len(self.urlids)
      for k in range(len(self.urlids)):
        error = clickurls[k] - self.uv[k]
        output_deltas[k] = dtanh(self.uv[k]) * error
        
      # 은닉층의 오류 계산
      hidden_deltas = [0.0] * len(self.hiddenids)
      for j in range(len(self.hiddenids)):
        error = 0.0
        for k in range(len(self.urlids)):
          error = error + output_deltas[k] * self.us[j][k]
        hidden_deltas[j] = dtanh(self.hv[j]) * error
        
      # 출력 가중치 갱신
      for j in range(len(self.hiddenids)):
        for k in range(len(self.urlids)):
          change = output_deltas[k] * self.hv[j]
          self.us[j][k] = self.us[j][k] + N * change
      # 입력 가중치 갱신
      for i in range(len(self.wordids)):
        for j in range(len(self.hiddenids)):
          change = hidden_deltas[j] * self.wv[i]
          self.ws[i][j] = self.ws[i][j] + N * change
      
    def trainquery(self, wordids, urlids, selectedurl):
      # 사전 작업 
      self.generatehiddennode(wordids, urlids)
      self.setupnetwork(wordids, urlids)
      self.feedforward()
      
      clickurls = [0.0] * len(urlids)    # 선택하지 않은 url은 0이 되도록
      clickurls[urlids.index(selectedurl)] = 1.0    # 선택한 url은 1.0이 되도록
      error = self.backPropagate(clickurls)
      self.updatedatabase()

    '''
          소스 일부 수정 파이썬 2.x -> 3.4 때문에
          에러 메시지 : 'dict_keys' object does not support indexing
          참고 :http://stackoverflow.com/questions/17322668/xi-xj-xj-xi-typeerror-dict-keys-object-does-not-support-indexing
          원인 : 파이썬3에서는 dictionary.keys() 함수를 사용하면 dict_keys 라는 오브젝트를 반환한다. 이것을 바로 사용하지는 못하고, list(.)함수로 묶어줘야 꺼내 사용할 수 있다. 
      >
          소스 수정 : self.hiddenids[j] -> list(self.hiddenids)[j]
    '''
    def updatedatabase(self):
      for i in range(len(self.wordids)):
        for j in range(len(self.hiddenids)):
          self.setstrength(self.wordids[i], list(self.hiddenids)[j], 0, self.ws[i][j])
      for j in range(len(self.hiddenids)):
        for k in range(len(self.urlids)):
          self.setstrength(list(self.hiddenids)[j], self.urlids[k], 1, self.us[j][k])
      self.con.commit()    
      
    ##########################################################################
    
# 111p 데이터베이스 설정(DB에 데이터 초기값 저장)
def practice0():
  mynet = searchnet('nn.db')
  mynet.maketables()
  wWorld, wRiver, wBank, wApple, wAndroid, wWindows = 101, 102, 103, 104, 105, 106
  uWorldBank, uRiver, uEarth, uIphone, uGoogle, uMs = 201, 202, 203, 204, 205, 206
  mynet.generatehiddennode([wWorld, wBank], [uWorldBank, uRiver, uEarth])
  mynet.generatehiddennode([wApple, wAndroid, wWindows], [uIphone, uGoogle, uMs])
  c([r for r in mynet.con.execute('select * from wordhidden')])    # .fetchall() 안해도 되네??
  c([r for r in mynet.con.execute('select * from hiddenurl')])
# practice0()

# 114p 전방 전파(DB에 영향 없음)
def practice1():
  practice0()
  
  mynet = searchnet('nn.db')
  wWorld, wRiver, wBank, wApple, wAndroid, wWindows = 101, 102, 103, 104, 105, 106
  uWorldBank, uRiver, uEarth, uIphone, uGoogle, uMs = 201, 202, 203, 204, 205, 206
  c(mynet.getresult([wWorld, wBank], [uWorldBank, uRiver, uEarth]))
#   c(mynet.getresult([wApple, wAndroid, wWindows], [uIphone, uGoogle, uMs]))
# practice1()

# 117p 역전파 학습(DB에 저장)
def practice2():
  practice0()
  
  mynet = searchnet('nn.db')
  wWorld, wRiver, wBank = 101, 102, 103
  uWorldBank, uRiver, uEarth = 201, 202, 203
  mynet.trainquery([wWorld, wBank], [uWorldBank, uRiver, uEarth], uWorldBank)
  c(mynet.getresult([wWorld, wBank], [uWorldBank, uRiver, uEarth]))
# practice2()
    
# 117p 학습 테스트
def practice3():

  mynet = searchnet('nn.db')
  mynet.maketables()
  wWorld, wRiver, wBank = 101, 102, 103
  uWorldBank, uRiver, uEarth = 201, 202, 203  
  allurls = [uWorldBank, uRiver, uEarth]
  
  # 학습을 많이 할 수록 strength 값이 분명해진다.
  for i in range(1):
    mynet.trainquery([wWorld, wBank], allurls, uWorldBank)    # 101_103으로 검색해서 나온 결과 중, 201을 클릭
    mynet.trainquery([wRiver, wBank], allurls, uRiver)    # 102_103으로 검색해서 나온 결과 중, 202를 클릭
    mynet.trainquery([wWorld], allurls, uEarth)
    
  c(mynet.getresult([wWorld, wBank], allurls))                     
  c(mynet.getresult([wRiver, wBank], allurls))
  c(mynet.getresult([wBank], allurls))                     
practice3()














282 >  [0.28413531533496345, 0.3141065454874244, 0.4577475993646154]
283 >  [0.13376582338538368, 0.35679063543142087, -0.11281807502885814]
284 >  [0.18592887432136948, 0.2778315828040204, 0.02544613051187923]

관련 ERD

 

 

 

7절 클릭 학습

전방전파,역전파학습

ws : word strength

wv : word value

hv : hidden value

us : url strength

uv : url value

각 데이터가 갱신될 때 의존관계 dependency 를 나열하였다.