Notice
Recent Posts
Recent Comments
일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | 2 | 3 | 4 | 5 | 6 | 7 |
8 | 9 | 10 | 11 | 12 | 13 | 14 |
15 | 16 | 17 | 18 | 19 | 20 | 21 |
22 | 23 | 24 | 25 | 26 | 27 | 28 |
29 | 30 | 31 |
Tags
- leetcode
- 맥북
- SW Expert Academy
- ubuntu
- Git
- 데이콘
- 프로그래머스
- AI 경진대회
- github
- Docker
- 캐치카페
- ChatGPT
- 파이썬
- 자연어처리
- 편스토랑
- dacon
- programmers
- gs25
- Real or Not? NLP with Disaster Tweets
- 우분투
- 편스토랑 우승상품
- 프로그래머스 파이썬
- 백준
- Kaggle
- hackerrank
- PYTHON
- Baekjoon
- 더현대서울 맛집
- 금융문자분석경진대회
- 코로나19
Archives
- Today
- Total
솜씨좋은장씨
[핸즈온머신러닝] 127페이지 MNIST 코드 - ValueError: The number of classes has to be greater than one; got 1 class 해결방법 본문
머신러닝 | 딥러닝/머신러닝 | 딥러닝
[핸즈온머신러닝] 127페이지 MNIST 코드 - ValueError: The number of classes has to be greater than one; got 1 class 해결방법
솜씨좋은장씨 2020. 4. 7. 16:23728x90
반응형
핸즈온 머신러닝 127 페이지의 이진 분류기 훈련 코드를 실습해보던 중
from sklearn.linear_model import SGDClassifier
sgd_clf = SGDClassifier(max_iter=5, random_state=42)
sgd_clf.fit(X_train, y_train_5)
코드를 실행하니 다음과 같은 오류를 얻게 되었습니다.
ValueError: The number of classes has to be greater than one; got 1 class
원인
이전에 MNIST 코드를 가져올때 fetch_mldata( ) 가 제대로 동작하지 않아
fetch_openml( )로 변경하여 데이터를 가져왔는데
바뀐 방법에 따라 가져오는 데이터의 형식도 달라져 문제가 생기는 것이었습니다.
해결방법은 두가지 입니다.
해결방법 1
데이터를 가져올 때 아래의 방법을 활용하여 fetch_mldata( )때와 동일한 형식으로 가져오도록 하는 방법
def sort_by_target(mnist):
reorder_train = np.array(sorted([(target, i) for i, target in enumerate(mnist.target[:60000])]))[:, 1]
reorder_test = np.array(sorted([(target, i) for i, target in enumerate(mnist.target[60000:])]))[:, 1]
mnist.data[:60000] = mnist.data[reorder_train]
mnist.target[:60000] = mnist.target[reorder_train]
mnist.data[60000:] = mnist.data[reorder_test + 60000]
mnist.target[60000:] = mnist.target[reorder_test + 60000]
try:
from sklearn.datasets import fetch_openml
mnist = fetch_openml('mnist_784', version=1, cache=True)
mnist.target = mnist.target.astype(np.int8) # fetch_openml() returns targets as strings
sort_by_target(mnist) # fetch_openml() returns an unsorted dataset
except ImportError:
from sklearn.datasets import fetch_mldata
mnist = fetch_mldata('MNIST original')
해결방법 2
데이터를 fetch_openml( ) 방법으로 가져온 후 astype(np.int8) 로 형식을 변경하는 방법
from sklearn.datasets import fetch_openml
mnist = fetch_openml('mnist_784', version=1, cache=True)
X, y = mnist['data'], mnist['target']
y = y.astype(np.int8)
위의 방법을 거치고 나면
from sklearn.linear_model import SGDClassifier
sgd_clf = SGDClassifier(max_iter=5, random_state=42)
sgd_clf.fit(X_train, y_train_5)
SGDClassifier(alpha=0.0001, average=False, class_weight=None,
early_stopping=False, epsilon=0.1, eta0=0.0, fit_intercept=True,
l1_ratio=0.15, learning_rate='optimal', loss='hinge', max_iter=5,
n_iter_no_change=5, n_jobs=None, penalty='l2', power_t=0.5,
random_state=42, shuffle=True, tol=0.001, validation_fraction=0.1,
verbose=0, warm_start=False)
이상없이 잘 실행되는 것을 볼 수 있습니다.
'머신러닝 | 딥러닝 > 머신러닝 | 딥러닝' 카테고리의 다른 글
Comments