分享
三行代码  ›  专栏  ›  技术社区  ›  miho

如何在Tensorflow 2.2中训练多输入的Keras模型?

  •  1
  • miho  · 技术社区  · 5 天前

    Tensorflow documentation about models with multiple inputs

    import tensorflow as tf
    from tensorflow.keras import Input, Model, models, layers
    
    
    def build_model():
        input1 = Input(shape=(50,), dtype=tf.int32, name='x1')
        input2 = Input(shape=(1,), dtype=tf.float32, name='x2')
        y1 = layers.Embedding(1000, 10, input_length=50)(input1)
        y1 = layers.Flatten()(y1)
        y = layers.Concatenate(axis=1)([y1, input2])
        y = layers.Dense(1)(y)
        return Model(inputs=[input1, input2], outputs=y)
    

    建立这种模式也很好:

    model = build_model()
    model.compile(loss='mse')
    model.summary()
    

    summary() this gist .

    def make_dummy_data():
        X1 = tf.data.Dataset.from_tensor_slices(tf.random.uniform([100, 50], maxval=1000, dtype=tf.int32))
        X2 = tf.data.Dataset.from_tensor_slices(tf.random.uniform([100, 1], dtype=tf.float32))
        X = tf.data.Dataset.zip((X1, X2)).map(lambda x1, x2: {'x1': x1, 'x2': x2})
        y_true = tf.data.Dataset.from_tensor_slices(tf.random.uniform([100, 1], dtype=tf.float32))
        return X, y_true
    
    
    X, y_true = make_dummy_data()
    Xy = tf.data.Dataset.zip((X, y_true))
    model.fit(Xy, batch_size=32)
    

    ……但现在 fit() 失败,并显示不可理解的错误消息(请参阅 full message here

    WARNING:tensorflow:Model was constructed with shape (None, 50) for input Tensor("x1:0", shape=(None, 50), dtype=int32), but it was called on an input with incompatible shape (50, 1).
    

    嗯,1号的额外尺寸是从哪里来的?还有,我该怎么摆脱它呢?

    还有一件事:通过删除 Embedding -图层确实会突然使模型运行。

    如果你想玩弄上面的样品,我准备了 a notebook on Google Colab for it . 感谢任何帮助。

    1 回复  |  直到 5 天前
        1
  •  1
  •   jdehesa    5 天前

    作为 fit 国家:

    batch_size
    整数或 None . 每次渐变更新的采样数。如果未指定, 批量大小 将默认为32。不指定 批量大小 如果您的数据是以数据集、生成器或 keras.utils.Sequence 实例(因为它们生成批处理)。

    (50, 1)

    你可以这样简单地修复它:

    Xy = tf.data.Dataset.zip((X, y_true)).batch(32)
    model.fit(Xy)