관리 메뉴

솜씨좋은장씨

[핸즈온머신러닝] 124페이지 MNIST 코드 - ImportError: cannot import name 'fetch_mldata' from 'sklearn.datasets' 해결방법 본문

머신러닝 | 딥러닝/머신러닝 | 딥러닝

[핸즈온머신러닝] 124페이지 MNIST 코드 - ImportError: cannot import name 'fetch_mldata' from 'sklearn.datasets' 해결방법

솜씨좋은장씨 2020. 4. 7. 15:51
728x90
반응형

핸즈온 머신러닝 124페이지의 MNIST 코드를 실습해보던 중

from sklearn.datasets import fetch_mldata

mnist = fetch_mldata('MNIST original')

mnist

MNIST 데이터를 import 하는 과정에서

---------------------------------------------------------------------------
ImportError                               Traceback (most recent call last)
<ipython-input-3-11eb5a5519e0> in <module>
----> 1 from sklearn.datasets import fetch_mldata
      2 
      3 mnist = fetch_mldata('MNIST original')
      4 
      5 mnist

ImportError: cannot import name 'fetch_mldata' from 'sklearn.datasets' 
(C:\Users\users\anaconda3\lib\site-packages\sklearn\datasets\__init__.py)

다음과 같은 오류가 발생하여 찾아보니 scikit-learn 0.20 이후 부터는

fetch_mldata( )는 더이상 지원하지 않는 것을 알 수 있었습니다.

 

fetch_mldata( ) 대신 fetch_openml( ) 이라는 이름으로 변경되었습니다.

 

따라서 앞서 오류가 났던 코드를

from sklearn.datasets import fetch_openml
mnist = fetch_openml('mnist_784', version=1, cache=True)
mnist

위와 같이 변경하면 MNIST 데이터를 불러옵니다.

 

그런데 여기서 기존의 fetch_mldata( )는 target을 기준으로 정렬된 데이터를 제공하였으나

fetch_openml( ) 이 제공하는 데이터는 그렇지 않다고 합니다.

 

그리고 데이터의 형식도 unit8과 float64로 서로 달라 추후 모델을 학습할 때 오류가 날 수 있습니다.

 

fetch_mldata( )와 동일한 결과를 얻고 싶다면 아래의 방법을 참고하여 실행하면 됩니다.

def sort_by_target(mnist):
    reorder_train = np.array(sorted([(target, i) for i, target in enumerate(mnist.target[:60000])]))[:, 1]
    reorder_test = np.array(sorted([(target, i) for i, target in enumerate(mnist.target[60000:])]))[:, 1]
    mnist.data[:60000] = mnist.data[reorder_train]
    mnist.target[:60000] = mnist.target[reorder_train]
    mnist.data[60000:] = mnist.data[reorder_test + 60000]
    mnist.target[60000:] = mnist.target[reorder_test + 60000]

try:
    from sklearn.datasets import fetch_openml
    mnist = fetch_openml('mnist_784', version=1, cache=True)
    mnist.target = mnist.target.astype(np.int8) # fetch_openml() returns targets as strings
    sort_by_target(mnist) # fetch_openml() returns an unsorted dataset
except ImportError:
    from sklearn.datasets import fetch_mldata
    mnist = fetch_mldata('MNIST original')

 

참고링크

 

mnist dataset · Issue #301 · ageron/handson-ml

hi, I just noticed that the mnist dataset was removed from the sklearn and tensorflow basic datasets. therefore, it brings the trouble in doing the example of the chapter 3. I will be grateful if y...

github.com

Comments