第6课 深度学习相关比赛中程序问题


运行如下程序
def predict_next(input_array):
    x = numpy.reshape(input_array, (1, seq_length, 1))
    x = x / float(n_vocab)
    y = model.predict(x)
    return y

def string_to_index(raw_input):
    res = []
    for c in raw_input[(len(raw_input)-seq_length):]:
        res.append(char_to_int[c])
    return res

def y_to_char(y):
    largest_index = y.argmax()
    c = int_to_char[largest_index]
    return c

def generate_article(init, rounds=500):
    in_string = init.lower()
    for i in range(rounds):
        n = y_to_char(predict_next(string_to_index(in_string)))
        in_string += n
    return in_string

init = 'Professor Michael S. Hart is the originator of the Project'
article = generate_article(init)
print(article)


----------------------------------------------------------------------------
出现如下问题:
ValueError                                Traceback (most recent call last)
<ipython-input-13-0deb69fb8b05> in <module>()
      1 init = 'Professor Michael S. Hart is the originator of the Project'
----> 2 article = generate_article(init)
      3 print(article)

<ipython-input-12-2b0f33322d1b> in generate_article(init, rounds)
      2     in_string = init.lower()
      3     for i in range(rounds):
----> 4         n = y_to_char(predict_next(string_to_index(in_string)))
      5         in_string += n
      6     return in_string

<ipython-input-11-6d1024d85642> in predict_next(input_array)
      1 def predict_next(input_array):
----> 2     x = numpy.reshape(input_array, (1, seq_length, 1))
      3     x = x / float(n_vocab)
      4     y = model.predict(x)
      5     return y

C:\Users\wang\Anaconda2\lib\site-packages\numpy\core\fromnumeric.pyc in reshape(a, newshape, order)
    221         reshape = a.reshape
    222     except AttributeError:
--> 223         return _wrapit(a, 'reshape', newshape, order=order)
    224     return reshape(newshape, order=order)
    225 

C:\Users\wang\Anaconda2\lib\site-packages\numpy\core\fromnumeric.pyc in _wrapit(obj, method, *args, **kwds)
     45     except AttributeError:
     46         wrap = None
---> 47     result = getattr(asarray(obj), method)(*args, **kwds)
     48     if wrap:
     49         if not isinstance(result, mu.ndarray):

ValueError: total size of new array must be unchanged


请教
已邀请:

要回复问题请先登录注册