관리 메뉴

솜씨좋은장씨

[핸즈온머신러닝] 127페이지 MNIST 코드 - ValueError: The number of classes has to be greater than one; got 1 class 해결방법 본문

머신러닝 | 딥러닝/머신러닝 | 딥러닝

[핸즈온머신러닝] 127페이지 MNIST 코드 - ValueError: The number of classes has to be greater than one; got 1 class 해결방법

솜씨좋은장씨 2020. 4. 7. 16:23
728x90
반응형

핸즈온 머신러닝 127 페이지의 이진 분류기 훈련 코드를 실습해보던 중

from sklearn.linear_model import SGDClassifier

sgd_clf = SGDClassifier(max_iter=5, random_state=42)
sgd_clf.fit(X_train, y_train_5)

코드를 실행하니 다음과 같은 오류를 얻게 되었습니다.

ValueError: The number of classes has to be greater than one; got 1 class

 

원인

 

[핸즈온머신러닝] 124페이지 MNIST 코드 - ImportError: cannot import name 'fetch_mldata' from 'sklearn.datasets' 해결방법

핸즈온 머신러닝 124페이지의 MNIST 코드를 실습해보던 중 from sklearn.datasets import fetch_mldata mnist = fetch_mldata('MNIST original') mnist MNIST 데이터를 import 하는 과정에서 -------------------..

somjang.tistory.com

이전에 MNIST 코드를 가져올때 fetch_mldata( ) 가 제대로 동작하지 않아

fetch_openml( )로 변경하여 데이터를 가져왔는데

바뀐 방법에 따라 가져오는 데이터의 형식도 달라져 문제가 생기는 것이었습니다.

 

해결방법은 두가지 입니다.

 

해결방법 1

데이터를 가져올 때 아래의 방법을 활용하여 fetch_mldata( )때와 동일한 형식으로 가져오도록 하는 방법

def sort_by_target(mnist):
    reorder_train = np.array(sorted([(target, i) for i, target in enumerate(mnist.target[:60000])]))[:, 1]
    reorder_test = np.array(sorted([(target, i) for i, target in enumerate(mnist.target[60000:])]))[:, 1]
    mnist.data[:60000] = mnist.data[reorder_train]
    mnist.target[:60000] = mnist.target[reorder_train]
    mnist.data[60000:] = mnist.data[reorder_test + 60000]
    mnist.target[60000:] = mnist.target[reorder_test + 60000]

try:
    from sklearn.datasets import fetch_openml
    mnist = fetch_openml('mnist_784', version=1, cache=True)
    mnist.target = mnist.target.astype(np.int8) # fetch_openml() returns targets as strings
    sort_by_target(mnist) # fetch_openml() returns an unsorted dataset
except ImportError:
    from sklearn.datasets import fetch_mldata
    mnist = fetch_mldata('MNIST original')

 

해결방법 2

데이터를 fetch_openml( ) 방법으로 가져온 후 astype(np.int8) 로 형식을 변경하는 방법

from sklearn.datasets import fetch_openml
mnist = fetch_openml('mnist_784', version=1, cache=True)

X, y = mnist['data'], mnist['target']

y = y.astype(np.int8)

 

위의 방법을 거치고 나면

from sklearn.linear_model import SGDClassifier

sgd_clf = SGDClassifier(max_iter=5, random_state=42)
sgd_clf.fit(X_train, y_train_5)
SGDClassifier(alpha=0.0001, average=False, class_weight=None,
              early_stopping=False, epsilon=0.1, eta0=0.0, fit_intercept=True,
              l1_ratio=0.15, learning_rate='optimal', loss='hinge', max_iter=5,
              n_iter_no_change=5, n_jobs=None, penalty='l2', power_t=0.5,
              random_state=42, shuffle=True, tol=0.001, validation_fraction=0.1,
              verbose=0, warm_start=False)

이상없이 잘 실행되는 것을 볼 수 있습니다.

728x90
반응형
0 Comments
댓글쓰기 폼