# coding=utf-8
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------

import pytest
import functools
from devtools_testutils.aio import recorded_by_proxy_async
from devtools_testutils import set_bodiless_matcher
from azure.core.exceptions import HttpResponseError
from azure.ai.formrecognizer._models import CustomFormModel
from azure.ai.formrecognizer.aio import FormTrainingClient
from preparers import FormRecognizerPreparer
from asynctestcase import AsyncFormRecognizerTest
from preparers import GlobalClientPreparer as _GlobalClientPreparer


# Decorator factory used by the tests below: ``_GlobalClientPreparer`` with the
# async ``FormTrainingClient`` pre-bound as its first positional argument, so
# each test only has to supply ``client_kwargs`` (e.g. the API version).
FormTrainingClientPreparer = functools.partial(_GlobalClientPreparer, FormTrainingClient)


class TestTrainingAsync(AsyncFormRecognizerTest):
    """Async tests for ``FormTrainingClient.begin_training`` on API versions 2.0 and 2.1."""

    @staticmethod
    def _assert_trained_model(model):
        """Assert that ``model`` looks like a successfully trained custom model.

        Checks the model-level metadata, every training document, and every
        submodel together with its fields.
        """
        assert model.model_id
        assert model.training_started_on
        assert model.training_completed_on
        assert model.errors == []
        assert model.status == "ready"
        for doc in model.training_documents:
            assert doc.name
            assert doc.page_count
            assert doc.status
            assert doc.errors == []
        for sub in model.submodels:
            assert sub.form_type
            assert sub.accuracy is not None
            for field in sub.fields.values():
                assert field.accuracy is not None
                assert field.name

    @staticmethod
    async def _assert_files_filter_behavior(client, container_sas_url):
        """Exercise the ``include_subfolders`` / ``prefix`` options of ``begin_training``.

        Runs three unlabeled trainings against ``container_sas_url`` using
        ``client``: one traversing subfolders, one restricted to the
        ``subfolder`` prefix, and one with a prefix that is expected to fail
        with ``HttpResponseError``.
        """
        async with client:
            poller = await client.begin_training(
                training_files_url=container_sas_url,
                use_training_labels=False,
                include_subfolders=True,
            )
            model = await poller.result()
            assert len(model.training_documents) == 6
            assert model.training_documents[-1].name == "subfolder/Form_6.jpg"  # we traversed subfolders

            poller = await client.begin_training(
                training_files_url=container_sas_url,
                use_training_labels=False,
                prefix="subfolder",
                include_subfolders=True,
            )
            model = await poller.result()
            assert len(model.training_documents) == 1
            assert model.training_documents[0].name == "subfolder/Form_6.jpg"  # we filtered for only subfolders

            with pytest.raises(HttpResponseError) as e:
                poller = await client.begin_training(
                    training_files_url=container_sas_url,
                    use_training_labels=False,
                    prefix="xxx",
                )
                await poller.result()
            assert e.value.error.code
            assert e.value.error.message

    @FormRecognizerPreparer()
    @FormTrainingClientPreparer(client_kwargs={"api_version": "2.0"})
    @recorded_by_proxy_async
    async def test_training_with_labels_v2(self, client, formrecognizer_storage_container_sas_url_v2, **kwargs):
        """Train with labels on API 2.0 and validate the dict round-tripped model."""
        set_bodiless_matcher()
        async with client:
            poller = await client.begin_training(
                training_files_url=formrecognizer_storage_container_sas_url_v2,
                use_training_labels=True,
            )
            model = await poller.result()

        # Round-trip through to_dict/from_dict so serialization is covered too.
        model = CustomFormModel.from_dict(model.to_dict())

        self._assert_trained_model(model)

    @FormRecognizerPreparer()
    @FormTrainingClientPreparer(client_kwargs={"api_version": "2.0"})
    @recorded_by_proxy_async
    async def test_training_multipage_with_labels_v2(self, client, formrecognizer_multipage_storage_container_sas_url_v2, **kwargs):
        """Train with labels on the multipage container using API 2.0."""
        set_bodiless_matcher()
        async with client:
            poller = await client.begin_training(
                training_files_url=formrecognizer_multipage_storage_container_sas_url_v2,
                use_training_labels=True,
            )
            model = await poller.result()

        self._assert_trained_model(model)

    @FormRecognizerPreparer()
    @FormTrainingClientPreparer(client_kwargs={"api_version": "2.0"})
    @recorded_by_proxy_async
    async def test_training_without_labels_v2(self, client, formrecognizer_storage_container_sas_url_v2, **kwargs):
        """Train on API 2.0 and validate the dict round-tripped model.

        NOTE(review): despite the test name, this passes
        ``use_training_labels=True`` -- confirm whether that is intended.
        """
        set_bodiless_matcher()
        async with client:
            poller = await client.begin_training(
                training_files_url=formrecognizer_storage_container_sas_url_v2,
                use_training_labels=True,
            )
            model = await poller.result()

        # Round-trip through to_dict/from_dict so serialization is covered too.
        model = CustomFormModel.from_dict(model.to_dict())

        self._assert_trained_model(model)

    @FormRecognizerPreparer()
    @FormTrainingClientPreparer(client_kwargs={"api_version": "2.0"})
    @recorded_by_proxy_async
    async def test_training_multipage_without_labels_v2(self, client, formrecognizer_multipage_storage_container_sas_url_v2, **kwargs):
        """Train on the multipage container using API 2.0.

        NOTE(review): despite the test name, this passes
        ``use_training_labels=True`` -- confirm whether that is intended.
        """
        set_bodiless_matcher()
        async with client:
            poller = await client.begin_training(
                training_files_url=formrecognizer_multipage_storage_container_sas_url_v2,
                use_training_labels=True,
            )
            model = await poller.result()

        self._assert_trained_model(model)

    @FormRecognizerPreparer()
    @FormTrainingClientPreparer(client_kwargs={"api_version": "2.0"})
    @recorded_by_proxy_async
    async def test_training_with_files_filter_v2(self, client, formrecognizer_storage_container_sas_url_v2, **kwargs):
        """Check ``include_subfolders`` and ``prefix`` filtering on API 2.0."""
        set_bodiless_matcher()
        await self._assert_files_filter_behavior(client, formrecognizer_storage_container_sas_url_v2)

    @FormRecognizerPreparer()
    @FormTrainingClientPreparer(client_kwargs={"api_version": "2.1"})
    @recorded_by_proxy_async
    async def test_training_with_labels_v21(self, client, formrecognizer_storage_container_sas_url_v2, **kwargs):
        """Train a named, labeled model on API 2.1 and validate the round-tripped model."""
        set_bodiless_matcher()
        async with client:
            poller = await client.begin_training(
                training_files_url=formrecognizer_storage_container_sas_url_v2,
                use_training_labels=True,
                model_name="my labeled model",
            )
            model = await poller.result()

        # Round-trip through to_dict/from_dict so serialization is covered too.
        model = CustomFormModel.from_dict(model.to_dict())

        assert model.model_name == "my labeled model"
        self._assert_trained_model(model)

    @FormRecognizerPreparer()
    @FormTrainingClientPreparer(client_kwargs={"api_version": "2.1"})
    @recorded_by_proxy_async
    async def test_training_multipage_with_labels_v21(self, client, formrecognizer_multipage_storage_container_sas_url_v2, **kwargs):
        """Train with labels on the multipage container using API 2.1."""
        set_bodiless_matcher()
        async with client:
            poller = await client.begin_training(
                training_files_url=formrecognizer_multipage_storage_container_sas_url_v2,
                use_training_labels=True,
            )
            model = await poller.result()

        self._assert_trained_model(model)

    @FormRecognizerPreparer()
    @FormTrainingClientPreparer(client_kwargs={"api_version": "2.1"})
    @recorded_by_proxy_async
    async def test_training_without_labels_v21(self, client, formrecognizer_storage_container_sas_url_v2, **kwargs):
        """Train a named model on API 2.1 and validate the round-tripped model.

        NOTE(review): despite the test name, this passes
        ``use_training_labels=True`` and names the model "my labeled model" --
        confirm whether that is intended.
        """
        set_bodiless_matcher()
        async with client:
            poller = await client.begin_training(
                training_files_url=formrecognizer_storage_container_sas_url_v2,
                use_training_labels=True,
                model_name="my labeled model",
            )
            model = await poller.result()

        # Round-trip through to_dict/from_dict so serialization is covered too.
        model = CustomFormModel.from_dict(model.to_dict())

        assert model.model_name == "my labeled model"
        self._assert_trained_model(model)

    @FormRecognizerPreparer()
    @FormTrainingClientPreparer(client_kwargs={"api_version": "2.1"})
    @recorded_by_proxy_async
    async def test_training_multipage_without_labels_v21(self, client, formrecognizer_multipage_storage_container_sas_url_v2, **kwargs):
        """Train on the multipage container using API 2.1.

        NOTE(review): despite the test name, this passes
        ``use_training_labels=True`` -- confirm whether that is intended.
        """
        set_bodiless_matcher()
        async with client:
            poller = await client.begin_training(
                training_files_url=formrecognizer_multipage_storage_container_sas_url_v2,
                use_training_labels=True,
            )
            model = await poller.result()

        self._assert_trained_model(model)

    @FormRecognizerPreparer()
    @FormTrainingClientPreparer(client_kwargs={"api_version": "2.1"})
    @recorded_by_proxy_async
    async def test_training_with_files_filter_v21(self, client, formrecognizer_storage_container_sas_url_v2, **kwargs):
        """Check ``include_subfolders`` and ``prefix`` filtering on API 2.1."""
        set_bodiless_matcher()
        await self._assert_files_filter_behavior(client, formrecognizer_storage_container_sas_url_v2)

    @FormRecognizerPreparer()
    @FormTrainingClientPreparer(client_kwargs={"api_version": "2.0"})
    async def test_training_with_model_name_bad_api_version(self, **kwargs):
        """Passing ``model_name`` with an API 2.0 client must raise ``ValueError``.

        This test is not wrapped in ``recorded_by_proxy_async``; presumably the
        error is raised client-side before any request is sent -- confirm.
        """
        client = kwargs.pop("client")
        with pytest.raises(ValueError) as excinfo:
            async with client:
                poller = await client.begin_training(
                    training_files_url="url",
                    use_training_labels=True,
                    model_name="not supported in v2.0",
                )
                await poller.result()
        assert "'model_name' is only available for API version V2_1 and up" in str(excinfo.value)