본문 바로가기
Spark/Spark와 머신 러닝

Spark 시작하기04 - [Spark와 머신 러닝] 3장 스파크를 이용한 데이터 수집, 프로세싱, 준비

by java개발자 2016. 4. 3.

python으로 짜여진 소스코드를 java8 로 작성하였다.

콘솔 로그 출력하기에 유용한 기능을 유틸로 만들었다.

package org.test.sparkNmachineLearning3;

import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

import org.apache.log4j.PropertyConfigurator;

public class LogUtil {

	List<String> currentFile = null;
	
	public boolean isDebug = true;
	private int stackTraceDepth;
	
	public void setDebug(boolean isDebug){
		this.isDebug = isDebug;
	}
	public LogUtil(boolean isDebug, int stackTraceDepth, String filePath){
		this.isDebug = isDebug;
		this.stackTraceDepth = stackTraceDepth;
		
		//log4j 설정(로그레벨 WARN)... 일단 경로는 하드코딩-_-;;
		PropertyConfigurator.configure("D:\\workspace\\spark\\sparkNmachineLearning\\src\\resources\\log4j.properties");
		
		if(currentFile == null){
			currentFile = new ArrayList<String>();
			BufferedReader br = null;
			try {
			
				br = new BufferedReader(new FileReader(filePath));
				String temp = null;
				currentFile.add("");	//0번째 자리 빈값으로...
				while((temp = br.readLine()) != null){
					int index = temp.indexOf("=");
					if(index != -1){
						currentFile.add(temp.substring(0, index+1).trim() + " ");
					}else{
						currentFile.add("");
					}
				}
			} catch (IOException e) {
				e.printStackTrace();
			} finally{
				try {
					br.close();
				} catch (IOException e) {
					e.printStackTrace();
				}
			}
		}
	}
	/**
	 * 변수에 값을 할당할때, 사용하면, 일단 화면에 한번 출력해주고, 할당한다.
	 * StackTraceElement를 사용하면, 어느 라인에서 할당했는지 알 수 있음.
	 * 디버그용으로 좋음!!
	 */
	public <T> T s(T o){	
		if(isDebug){
			Thread th = Thread.currentThread();
			StackTraceElement[] lists = th.getStackTrace();
//			int lineNumber = lists[2].getLineNumber();
			int lineNumber = lists[this.stackTraceDepth].getLineNumber();
			String variable = currentFile.get(lineNumber);
			System.out.print("\n" +  + lineNumber + ": " + variable + o);
		}
		return o;
	}
}
소스코드
package org.test.sparkNmachineLearning3;

import java.util.List;
import java.util.Map;

import org.apache.spark.api.java.JavaDoubleRDD;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.test.JavaGraph.MyHistogramPlot;
import org.test.JavaGraph.MySimpleBarPlot;

import scala.Tuple2;

public class Ch3_JavaApp {
	
	LogUtil lu = new LogUtil(true, 3, "D:\\workspace\\spark\\sparkNmachineLearning\\src\\main\\java\\org\\test\\sparkNmachineLearning3\\Ch3_JavaApp.java");
	JavaSparkContext sc = new JavaSparkContext("local[2]", "First Spark App");
	
	public static void main(String...strings){
		Long start = System.currentTimeMillis();
		Ch3_JavaApp c = new Ch3_JavaApp();
//		c.proc1();	//사용자 데이터
//		c.proc2();	//영화 데이터
		c.proc3();	//평점 데이터
		
		Long end = System.currentTimeMillis();
		System.out.println(end - start + " ms spend...");
	}
	public <T> T s(T o){
		return lu.s(o);
	}
	public void proc1(){
		/**
		 * 사용자 데이터 집합 탐색
		 */
//		1|24|M|technician|85711
//		2|53|F|other|94043
//		3|23|M|writer|32067
//		4|24|M|technician|43537
//		5|33|F|other|15213
        JavaRDD<String> user_data = sc.textFile("D:\\workspace\\spark\\sparkNmachineLearning\\src\\main\\java\\org\\test\\sparkNmachineLearning3\\data\\ml-100k\\u.user");
        s(user_data.first());
        
        JavaRDD<String[]> user_fields = user_data.map(a -> a.split("\\|"));	//단순히 |를 쓰면, character 1개씩 쪼개진다.
        JavaRDD<String> users = user_fields.map(a -> a[0]);
        s(users.count());
        s(users.take(13));
        JavaRDD<String> genders = user_fields.map(a -> a[2]).distinct();
        s(genders.count());
        s(genders.collect());
        JavaRDD<String> occupations = user_fields.map(a ->a[3]).distinct();
        s(occupations.count());
        s(occupations.collect());
        JavaRDD<String> zipcodes = user_fields.map(a -> a[4]).distinct();
        s(zipcodes.count());
        
        JavaDoubleRDD ages = user_fields.mapToDouble(a -> Double.parseDouble(a[1]));
        List<Double> ages_list = s(ages.collect());
        //사용자 나이의 분포 - 히스토그램 그래프
		new MyHistogramPlot(2.0, 40, ages_list);
		
		int method = 2;	// or 2
		if(method == 1){
			//방법1
			JavaPairRDD<String,Integer> count_by_occupation = user_fields.mapToPair(a -> new Tuple2<String, Integer>(a[3],1));
			count_by_occupation = count_by_occupation.reduceByKey((a,b) -> a+b);
			//숫자 기준으로 정렬위해(성능은???) - false:내림차순
			count_by_occupation = count_by_occupation.mapToPair(x->x.swap()).sortByKey(false).mapToPair(x->x.swap());	
			List<Tuple2<String, Integer>> list_occupation = s(count_by_occupation.take(20));	//상위 20개
			
			//Tuple2는 스칼라 클래스라서 그래프 유틸쪽에 넘길 수가 없다..
			String[] xNameList = new String[list_occupation.size()];
			Long[] yList = new Long[list_occupation.size()];
			for(int i = 0 ; i<list_occupation.size() ; i++){
				xNameList[i] = list_occupation.get(i)._1;
				yList[i] = (long)list_occupation.get(i)._2;
			}
			//사용자 직업의 분포 xy 그래프
			new MySimpleBarPlot().initArray2(xNameList, yList);
		}else if(method == 2){
			//방법2 - countByValue 사용하기
			// countByValue를 사용하면, 더이상의 RDD action을 하지 못하고, java collection으로 빠져나온다-_-;
			JavaRDD<String> count_by_occupation = user_fields.map(a -> a[3]);
			Map<String, Long> map = count_by_occupation.countByValue();	//how to sort??
			
			//사용자 직업의 분포 xy 그래프
			new MySimpleBarPlot().initMap2(map);
		}
	}
	public void proc2(){
		/**
		 * 영화 데이터 집합 탐색
		 */
//		1|Toy Story (1995)|01-Jan-1995||http://us.imdb.com/M/title-exact?Toy%20Story%20(1995)|0|0|0|1|1|1|0|0|0|0|0|0|0|0|0|0|0|0|0
//		2|GoldenEye (1995)|01-Jan-1995||http://us.imdb.com/M/title-exact?GoldenEye%20(1995)|0|1|1|0|0|0|0|0|0|0|0|0|0|0|0|0|1|0|0
//		3|Four Rooms (1995)|01-Jan-1995||http://us.imdb.com/M/title-exact?Four%20Rooms%20(1995)|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|1|0|0
//		4|Get Shorty (1995)|01-Jan-1995||http://us.imdb.com/M/title-exact?Get%20Shorty%20(1995)|0|1|0|0|0|1|0|0|1|0|0|0|0|0|0|0|0|0|0
//		5|Copycat (1995)|01-Jan-1995||http://us.imdb.com/M/title-exact?Copycat%20(1995)|0|0|0|0|0|0|1|0|1|0|0|0|0|0|0|0|1|0|0
        JavaRDD<String> movie_data = sc.textFile("D:\\workspace\\spark\\sparkNmachineLearning\\src\\main\\java\\org\\test\\sparkNmachineLearning3\\data\\ml-100k\\u.item");
        s(movie_data.first());	// 1|Toy Story (1995)|01-Jan-1995||http://us.imdb.com/M/title-exact?Toy%20Story%20(1995)|0|0|0|1|1|1|0|0|0|0|0|0|0|0|0|0|0|0|0
        s(movie_data.count());	// 1682
        
        JavaRDD<String[]> movie_fields = movie_data.map(a -> a.split("\\|"));
        JavaRDD<String> years = s(movie_fields.map(a -> a[2]));
        years = years.map(str -> {
        	try{
        		return str.substring(str.length()-4);
        	}catch(Exception e){
//        		s(e.getMessage());		//serializable exception 발생.
        		System.out.println(e.getMessage());
        		return "1900";
        	}
        });
        JavaRDD<String> years_filtered = years.filter(a -> !a.equals("1900"));
//        Map<Integer, Long> movie_ages = s(years_filtered.map(a -> 1998 - Integer.parseInt(a)).countByValue());	//sort를 할 수 없다.-_-;
        JavaRDD<Integer> movie_ages = years_filtered.map(a -> 1998 - Integer.parseInt(a));
        JavaPairRDD<Integer, Integer> years_filtered2 = movie_ages.mapToPair(a -> new Tuple2<Integer, Integer>(a, 1));
        years_filtered2 = years_filtered2.reduceByKey((a,b) -> a+b);
        years_filtered2 = years_filtered2.mapToPair(x->x.swap()).sortByKey(true).mapToPair(x->x.swap());		//key오름차순
		List<Tuple2<Integer, Integer>> list = years_filtered2.collect();
		
		Integer[] xNameList = new Integer[list.size()];
		Long[] yList = new Long[list.size()];
		for(int i = 0 ; i<list.size() ; i++){
			xNameList[i] = list.get(i)._1;
			yList[i] = (long)list.get(i)._2;
		}
        //영화 나이의 분포 xy 그래프
		new MySimpleBarPlot().initArray1(xNameList, yList);
	}
	
	public void proc3(){
		/**
		 * 평점 데이터 집합 탐색
		 */
//		196	242	3	881250949
//		186	302	3	891717742
//		22	377	1	878887116
//		244	51	2	880606923
//		166	346	1	886397596
        JavaRDD<String> rating_data_raw = sc.textFile("D:\\workspace\\spark\\sparkNmachineLearning\\src\\main\\java\\org\\test\\sparkNmachineLearning3\\data\\ml-100k\\u.data");
        s(rating_data_raw.first());	//	196	242	3	881250949
        int num_ratings = s((int)rating_data_raw.count());	//	100000
        
        JavaRDD<String[]> rating_data = rating_data_raw.map(s -> s.split("\t"));
        JavaRDD<Integer> ratings = rating_data.map(a -> Integer.parseInt(a[2]));
		int max_rating = s(ratings.reduce((a,b) -> a > b ? a : b));
		int min_rating = s(ratings.reduce((a,b) -> a < b ? a : b));
		float mean_rating = s(ratings.reduce((a,b) -> a+b) / (float)num_ratings);
		
		//TODO
		//spark 통계 함수 Statistics... 그런데 Vector 데이터여야 한다.-_-;

		// 평점의 분포
		Map<Integer, Long> count_by_rating = s(ratings.countByValue());
		new MySimpleBarPlot().initMap1(count_by_rating, true, true);
		
		//사용자별 평점의 개수 - 이것만을 위한다면 굳이 책에서 나오는 것처럼 복잡하게 할 필요가 없다.(위와 같은 방법으로.. 가능)
		JavaRDD<Integer> user_ratings = rating_data.map(a -> Integer.parseInt(a[0]));
		Map<Integer, Long> count_by_user_ratings = s(user_ratings.countByValue());
		new MySimpleBarPlot().initMap1(count_by_user_ratings, false, false);
	}
}


그래프는 java의 그래프 라이브러리 중... GRAL 을 이용하였다. http://trac.erichseifert.de/gral/wiki/Comparison



<사용자 나이의 분포>


3장은 파이썬으로 작성되었기에..

파이썬을 JAVA8 소스로 작성하면서 파이썬의 라이브러리와 비슷한 것을 찾는 것이 쉽지는 않았다.


java의 GRAL 그래프 라이브러리가 꽤 이쁘다...

내부적으로 java의 swing을 이용한다. 속도도 나쁘지 않다.

라이브러리 API 도 꽤 직관적이라서 사용하는데 무리는 없다.


gral-examples 을 변형해서 갖가지 그래프를 만들어 낼 수 있다.

https://github.com/eseifert/gral




<사용자 직업의 분포>





<영화 나이의 분포>



<평점의 분포>



<사용자마다 준 평점 분포>