/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.ignite.internal.processors.rest;

import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.OutputStream;
import java.net.InetAddress;
import java.net.Socket;
import java.nio.ByteBuffer;
import java.util.UUID;
import java.util.concurrent.atomic.AtomicBoolean;
import org.apache.ignite.configuration.ConnectorConfiguration;
import org.apache.ignite.configuration.IgniteConfiguration;
import org.apache.ignite.internal.client.marshaller.jdk.GridClientJdkMarshaller;
import org.apache.ignite.internal.processors.rest.client.message.GridClientHandshakeRequest;
import org.apache.ignite.internal.processors.rest.client.message.GridClientMessage;
import org.apache.ignite.internal.util.IgniteUtils;
import org.apache.ignite.internal.util.lang.GridAbsPredicate;
import org.apache.ignite.internal.util.typedef.internal.U;
import org.apache.ignite.testframework.GridTestUtils;
import org.apache.ignite.testframework.junits.WithSystemProperty;
import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest;
import org.junit.Test;

import static org.apache.ignite.IgniteSystemProperties.IGNITE_ENABLE_OBJECT_INPUT_FILTER_AUTOCONFIGURATION;
import static org.apache.ignite.IgniteSystemProperties.IGNITE_MARSHALLER_BLACKLIST;
import static org.apache.ignite.IgniteSystemProperties.IGNITE_MARSHALLER_WHITELIST;
import static org.apache.ignite.internal.processors.rest.protocols.tcp.GridMemcachedMessage.IGNITE_HANDSHAKE_FLAG;
import static org.apache.ignite.internal.processors.rest.protocols.tcp.GridMemcachedMessage.IGNITE_REQ_FLAG;

/**
 * Tests for the marshaller whitelist and blacklist that guard the TCP REST endpoint against the
 * deserialization vulnerability.
 */
@WithSystemProperty(key = IGNITE_ENABLE_OBJECT_INPUT_FILTER_AUTOCONFIGURATION, value = "false")
public class TcpRestUnmarshalVulnerabilityTest extends GridCommonAbstractTest {
    /** Marshaller used to serialize the exploit payload. */
    private static final GridClientJdkMarshaller MARSH = new GridClientJdkMarshaller();

    /** Directory (relative to the Ignite home) holding the class list files used by the tests. */
    private static final String CLASS_LIST_DIR = "modules/core/src/test/config/";

    /** Size of the zero-filled header written in front of the payload in the request packet. */
    private static final int REQ_HDR_LEN = 40;

    /**
     * Set to {@code true} by {@code Exploit.readObject} once the payload has been deserialized.
     * Since this is a static field, the check relies on the node being started in this JVM.
     */
    private static final AtomicBoolean SHARED = new AtomicBoolean();

    /** REST connector port, taken from a default {@link ConnectorConfiguration}. */
    private static int port;

    /** REST connector host, taken from a default {@link ConnectorConfiguration}. */
    private static String host;

    /** {@inheritDoc} */
    @Override protected IgniteConfiguration getConfiguration(String igniteInstanceName) throws Exception {
        IgniteConfiguration cfg = super.getConfiguration(igniteInstanceName);

        ConnectorConfiguration connCfg = new ConnectorConfiguration();

        // Remember where the connector listens so that attack() can connect to it.
        port = connCfg.getPort();
        host = connCfg.getHost();

        cfg.setConnectorConfiguration(connCfg);

        return cfg;
    }

    /** {@inheritDoc} */
    @Override protected void beforeTest() throws Exception {
        super.beforeTest();

        SHARED.set(false);

        System.clearProperty(IGNITE_MARSHALLER_WHITELIST);
        System.clearProperty(IGNITE_MARSHALLER_BLACKLIST);

        IgniteUtils.clearClassCache();
    }

    /** {@inheritDoc} */
    @Override protected void afterTest() throws Exception {
        // The list properties are set manually by the tests (not via @WithSystemProperty), so they must be
        // cleared here; otherwise the last test's lists would leak into whatever runs next in this JVM.
        System.clearProperty(IGNITE_MARSHALLER_WHITELIST);
        System.clearProperty(IGNITE_MARSHALLER_BLACKLIST);

        IgniteUtils.clearClassCache();

        super.afterTest();
    }

    /**
     * @throws Exception If failed.
     */
    @Test
    public void testNoLists() throws Exception {
        testExploit(true);
    }

    /**
     * @throws Exception If failed.
     */
    @Test
    public void testWhiteListIncluded() throws Exception {
        System.setProperty(IGNITE_MARSHALLER_WHITELIST, classListPath("class_list_exploit_included.txt"));

        testExploit(true);
    }

    /**
     * @throws Exception If failed.
     */
    @Test
    public void testWhiteListExcluded() throws Exception {
        System.setProperty(IGNITE_MARSHALLER_WHITELIST, classListPath("class_list_exploit_excluded.txt"));

        testExploit(false);
    }

    /**
     * @throws Exception If failed.
     */
    @Test
    public void testBlackListIncluded() throws Exception {
        System.setProperty(IGNITE_MARSHALLER_BLACKLIST, classListPath("class_list_exploit_included.txt"));

        testExploit(false);
    }

    /**
     * @throws Exception If failed.
     */
    @Test
    public void testBlackListExcluded() throws Exception {
        System.setProperty(IGNITE_MARSHALLER_BLACKLIST, classListPath("class_list_exploit_excluded.txt"));

        testExploit(true);
    }

    /**
     * @throws Exception If failed.
     */
    @Test
    public void testBothListIncluded() throws Exception {
        String path = classListPath("class_list_exploit_included.txt");

        System.setProperty(IGNITE_MARSHALLER_WHITELIST, path);
        System.setProperty(IGNITE_MARSHALLER_BLACKLIST, path);

        testExploit(false);
    }

    /**
     * Resolves a class list file located in {@link #CLASS_LIST_DIR}.
     *
     * @param fileName Class list file name.
     * @return Path of the resolved file.
     */
    private static String classListPath(String fileName) {
        return U.resolveIgnitePath(CLASS_LIST_DIR + fileName).getPath();
    }

    /**
     * Starts a node, sends the exploit payload to its REST connector and checks whether the payload
     * got deserialized.
     *
     * @param positive {@code true} if the payload is expected to be deserialized, {@code false} if the
     *      configured class lists are expected to prevent it.
     * @throws Exception If failed.
     */
    private void testExploit(boolean positive) throws Exception {
        try {
            startGrid();

            attack(marshal(new Exploit()).array());

            // The node handles the request independently of this thread, so poll the flag for a while.
            boolean res = GridTestUtils.waitForCondition(new GridAbsPredicate() {
                @Override public boolean apply() {
                    return SHARED.get();
                }
            }, 3000L);

            if (positive)
                assertTrue(res);
            else
                assertFalse(res);
        }
        finally {
            stopAllGrids();
        }
    }

    /**
     * @param obj Object to marshal with the JDK client marshaller.
     * @return Marshalled bytes.
     * @throws IOException If marshalling failed.
     */
    private static ByteBuffer marshal(Object obj) throws IOException {
        return MARSH.marshal(obj, 0);
    }

    /**
     * Connects to the node's REST connector, performs the client handshake selecting the JDK marshaller
     * and then sends {@code data} as the body of a request packet.
     *
     * @param data Marshalled payload.
     * @throws IOException If an I/O error occurs.
     */
    private void attack(byte[] data) throws IOException {
        InetAddress addr = InetAddress.getByName(host);

        try (
            Socket sock = new Socket(addr, port);
            OutputStream os = new BufferedOutputStream(sock.getOutputStream())
        ) {
            // Handshake request.
            os.write(IGNITE_HANDSHAKE_FLAG);

            GridClientHandshakeRequest req = new GridClientHandshakeRequest();
            req.marshallerId(GridClientJdkMarshaller.ID);
            os.write(req.rawBytes());
            os.flush();

            // Handshake response. This stream is released when the socket is closed.
            InputStream is = new BufferedInputStream(sock.getInputStream());

            // The bytes are discarded; the read only waits for the server's reply.
            // NOTE(review): a single read() may return fewer than 146 bytes.
            is.read(new byte[146]);

            // The packet length covers the header plus the payload.
            int len = data.length + REQ_HDR_LEN;

            os.write(IGNITE_REQ_FLAG); // Package type.
            os.write((byte)(len >> 24)); // Package length, most significant byte first.
            os.write((byte)(len >> 16));
            os.write((byte)(len >> 8));
            os.write((byte)(len));
            os.write(new byte[REQ_HDR_LEN]); // Stream header (zero-filled).
            os.write(data); // Exploit.
            os.flush();
        }
    }

    /** Message whose deserialization sets the {@code SHARED} flag. */
    private static class Exploit implements GridClientMessage {
        /**
         * Java serialization hook invoked by {@link ObjectInputStream} when an instance is deserialized.
         * It only records that deserialization happened.
         *
         * @param is Input stream.
         * @throws ClassNotFoundException Declared by the serialization hook contract.
         * @throws IOException Declared by the serialization hook contract.
         */
        private void readObject(ObjectInputStream is) throws ClassNotFoundException, IOException {
            SHARED.set(true);
        }

        /** {@inheritDoc} */
        @Override public long requestId() {
            return 0;
        }

        /** {@inheritDoc} */
        @Override public void requestId(long reqId) {
            // No-op.
        }

        /** {@inheritDoc} */
        @Override public UUID clientId() {
            return null;
        }

        /** {@inheritDoc} */
        @Override public void clientId(UUID id) {
            // No-op.
        }

        /** {@inheritDoc} */
        @Override public UUID destinationId() {
            return null;
        }

        /** {@inheritDoc} */
        @Override public void destinationId(UUID id) {
            // No-op.
        }

        /** {@inheritDoc} */
        @Override public byte[] sessionToken() {
            return new byte[0];
        }

        /** {@inheritDoc} */
        @Override public void sessionToken(byte[] sesTok) {
            // No-op.
        }
    }
}
