From b3f138fa1069a6ec680e9777424b695224810eaf Mon Sep 17 00:00:00 2001 From: baonudesifeizhai Date: Thu, 3 Jul 2025 17:22:15 -0400 Subject: [PATCH] fix: resolve mypy type errors and editorconfig issues - Add type annotations for json_output and agent_thoughts variables - Fix trailing whitespace in modified files - Update .gitignore to not ignore api/core directory --- .editorconfig | 1 + .gitignore | 1 - CHANGELOG.md | 71 ++ LICENSE | 43 +- README.md | 697 ++++++++++++------ api/.gitignore | 11 + api/app.py | 4 - api/app_factory.py | 2 - api/configs/app_config.py | 5 - api/configs/deploy/__init__.py | 4 - api/configs/enterprise/__init__.py | 1 - api/configs/extra/notion_config.py | 4 - api/configs/extra/sentry_config.py | 2 - api/configs/feature/__init__.py | 97 --- .../feature/hosted_service/__init__.py | 19 - api/configs/middleware/__init__.py | 21 - api/configs/middleware/cache/redis_config.py | 16 - .../storage/aliyun_oss_storage_config.py | 6 - .../storage/amazon_s3_storage_config.py | 6 - .../storage/azure_blob_storage_config.py | 3 - .../storage/baidu_obs_storage_config.py | 3 - .../storage/google_cloud_storage_config.py | 1 - .../storage/huawei_obs_storage_config.py | 3 - .../middleware/storage/oci_storage_config.py | 4 - .../storage/supabase_storage_config.py | 2 - .../storage/tencent_cos_storage_config.py | 4 - .../storage/volcengine_tos_storage_config.py | 4 - .../middleware/vdb/baidu_vector_config.py | 6 - api/configs/middleware/vdb/chroma_config.py | 5 - .../middleware/vdb/couchbase_config.py | 4 - .../middleware/vdb/elasticsearch_config.py | 3 - .../middleware/vdb/huawei_cloud_config.py | 2 - api/configs/middleware/vdb/milvus_config.py | 6 - api/configs/middleware/vdb/myscale_config.py | 5 - .../middleware/vdb/oceanbase_config.py | 5 - .../middleware/vdb/opengauss_config.py | 7 - .../middleware/vdb/opensearch_config.py | 8 - api/configs/middleware/vdb/oracle_config.py | 6 - api/configs/middleware/vdb/pgvector_config.py | 7 - 
.../middleware/vdb/pgvectors_config.py | 4 - api/configs/middleware/vdb/qdrant_config.py | 5 - api/configs/middleware/vdb/relyt_config.py | 4 - .../middleware/vdb/tablestore_config.py | 3 - .../middleware/vdb/tencent_vector_config.py | 8 - .../middleware/vdb/tidb_on_qdrant_config.py | 11 - .../middleware/vdb/tidb_vector_config.py | 4 - api/configs/middleware/vdb/upstash_config.py | 1 - .../middleware/vdb/vastbase_vector_config.py | 6 - api/configs/middleware/vdb/vikingdb_config.py | 6 - api/configs/middleware/vdb/weaviate_config.py | 3 - api/configs/observability/otel/otel_config.py | 11 - .../apollo/__init__.py | 3 - .../remote_settings_sources/apollo/client.py | 8 - .../apollo/python_3x.py | 3 - .../remote_settings_sources/nacos/__init__.py | 4 - .../nacos/http_request.py | 4 - .../remote_settings_sources/nacos/utils.py | 7 - api/constants/__init__.py | 6 - api/constants/languages.py | 2 - api/constants/mimetypes.py | 1 - api/constants/tts_auto_play_timeout.py | 1 - api/contexts/__init__.py | 7 - api/contexts/wrapper.py | 4 - api/controllers/__init__.py | 1 - api/controllers/common/fields.py | 2 - api/controllers/common/helpers.py | 8 - api/controllers/console/__init__.py | 9 - api/controllers/console/admin.py | 23 - api/controllers/console/apikey.py | 13 - .../console/app/advanced_prompt_template.py | 1 - api/controllers/console/app/agent.py | 2 - api/controllers/console/app/annotation.py | 19 - api/controllers/console/app/app.py | 38 - api/controllers/console/app/app_import.py | 6 - api/controllers/console/app/audio.py | 6 - api/controllers/console/app/completion.py | 12 - api/controllers/console/app/conversation.py | 42 -- .../console/app/conversation_variables.py | 4 - api/controllers/console/app/generator.py | 7 - api/controllers/console/app/message.py | 24 - api/controllers/console/app/model_config.py | 16 - api/controllers/console/app/ops_trace.py | 4 - api/controllers/console/app/site.py | 9 - api/controllers/console/app/statistic.py | 104 --- 
api/controllers/console/app/workflow.py | 83 --- .../console/app/workflow_app_log.py | 4 - .../console/app/workflow_draft_variable.py | 21 - api/controllers/console/app/workflow_run.py | 8 - .../console/app/workflow_statistic.py | 51 -- api/controllers/console/app/wraps.py | 9 - api/controllers/console/auth/activate.py | 8 - .../console/auth/data_source_bearer_auth.py | 2 - .../console/auth/data_source_oauth.py | 5 - .../console/auth/forgot_password.py | 21 - api/controllers/console/auth/login.py | 20 - api/controllers/console/auth/oauth.py | 19 - api/controllers/console/billing/billing.py | 2 - api/controllers/console/billing/compliance.py | 2 - .../console/datasets/data_source.py | 7 - api/controllers/console/datasets/datasets.py | 52 -- .../console/datasets/datasets_document.py | 105 --- .../console/datasets/datasets_segments.py | 22 - api/controllers/console/datasets/external.py | 23 - .../console/datasets/hit_testing.py | 2 - .../console/datasets/hit_testing_base.py | 3 - api/controllers/console/datasets/metadata.py | 9 - api/controllers/console/explore/audio.py | 5 - api/controllers/console/explore/completion.py | 14 - .../console/explore/conversation.py | 13 - .../console/explore/installed_app.py | 18 - api/controllers/console/explore/message.py | 14 - api/controllers/console/explore/parameter.py | 6 - .../console/explore/recommended_app.py | 4 - .../console/explore/saved_message.py | 10 - api/controllers/console/explore/workflow.py | 5 - api/controllers/console/explore/wraps.py | 9 - api/controllers/console/extension.py | 13 - api/controllers/console/files.py | 7 - api/controllers/console/init_validate.py | 5 - api/controllers/console/remote_files.py | 7 - api/controllers/console/setup.py | 5 - api/controllers/console/tag/tags.py | 14 - api/controllers/console/version.py | 5 - api/controllers/console/workspace/__init__.py | 5 - api/controllers/console/workspace/account.py | 43 -- .../console/workspace/agent_providers.py | 2 - 
api/controllers/console/workspace/endpoint.py | 25 - .../workspace/load_balancing_config.py | 17 - api/controllers/console/workspace/members.py | 12 - .../console/workspace/model_providers.py | 24 - api/controllers/console/workspace/models.py | 48 -- api/controllers/console/workspace/plugin.py | 59 -- .../console/workspace/tool_providers.py | 91 --- .../console/workspace/workspace.py | 29 - api/controllers/console/wraps.py | 16 - api/controllers/files/__init__.py | 2 - api/controllers/files/image_preview.py | 14 - api/controllers/files/tool_files.py | 7 - api/controllers/files/upload.py | 11 - api/controllers/inner_api/__init__.py | 1 - api/controllers/inner_api/mail.py | 1 - api/controllers/inner_api/plugin/plugin.py | 1 - api/controllers/inner_api/plugin/wraps.py | 15 - .../inner_api/workspace/workspace.py | 9 - api/controllers/inner_api/wraps.py | 13 - api/controllers/service_api/__init__.py | 1 - api/controllers/service_api/app/annotation.py | 5 - api/controllers/service_api/app/app.py | 4 - api/controllers/service_api/app/audio.py | 4 - api/controllers/service_api/app/completion.py | 15 - .../service_api/app/conversation.py | 10 - api/controllers/service_api/app/file.py | 6 - api/controllers/service_api/app/message.py | 8 - api/controllers/service_api/app/site.py | 3 - api/controllers/service_api/app/workflow.py | 12 - .../service_api/dataset/dataset.py | 57 -- .../service_api/dataset/document.py | 56 -- .../service_api/dataset/hit_testing.py | 2 - .../service_api/dataset/metadata.py | 9 - .../service_api/dataset/segment.py | 30 - .../service_api/dataset/upload_file.py | 1 - .../service_api/workspace/models.py | 2 - api/controllers/service_api/wraps.py | 32 - api/controllers/web/__init__.py | 3 - api/controllers/web/app.py | 15 - api/controllers/web/audio.py | 4 - api/controllers/web/completion.py | 14 - api/controllers/web/conversation.py | 13 - api/controllers/web/files.py | 6 - api/controllers/web/forgot_password.py | 20 - api/controllers/web/login.py | 13 
- api/controllers/web/message.py | 14 - api/controllers/web/passport.py | 20 - api/controllers/web/remote_files.py | 7 - api/controllers/web/saved_message.py | 9 - api/controllers/web/site.py | 7 - api/controllers/web/workflow.py | 5 - api/controllers/web/wraps.py | 12 - api/core/agent/base_agent_runner.py | 51 -- api/core/agent/cot_agent_runner.py | 39 - api/core/agent/cot_chat_agent_runner.py | 11 - api/core/agent/cot_completion_agent_runner.py | 9 - api/core/agent/fc_agent_runner.py | 49 -- .../agent/output_parser/cot_output_parser.py | 20 - api/core/agent/plugin_entities.py | 2 - api/core/agent/prompt/template.py | 24 - api/core/agent/strategy/plugin.py | 2 - api/core/file/file_manager.py | 8 - .../helper/code_executor/code_executor.py | 14 - .../javascript/javascript_transformer.py | 3 - .../jinja2/jinja2_transformer.py | 9 - .../python3/python3_transformer.py | 4 - .../code_executor/template_transformer.py | 1 - api/core/helper/download.py | 1 - api/core/helper/encrypter.py | 1 - api/core/helper/marketplace.py | 1 - api/core/helper/model_provider_cache.py | 4 - api/core/helper/moderation.py | 8 - api/core/helper/position_helper.py | 7 - api/core/helper/ssrf_proxy.py | 9 - api/core/helper/tool_parameter_cache.py | 3 - api/core/helper/tool_provider_cache.py | 4 - api/core/helper/url_signer.py | 4 - api/core/hosting_configuration.py | 26 - api/core/indexing_runner.py | 54 -- api/core/llm_generator/llm_generator.py | 60 -- .../output_parser/structured_output.py | 38 - .../suggested_questions_after_answer.py | 1 - api/core/llm_generator/prompts.py | 29 - api/core/memory/token_buffer_memory.py | 21 - .../model_runtime/callbacks/base_callback.py | 4 - .../callbacks/logging_callback.py | 13 - .../model_runtime/entities/llm_entities.py | 5 - .../entities/message_entities.py | 8 - .../model_runtime/entities/model_entities.py | 3 - .../entities/provider_entities.py | 3 - .../model_providers/__base/ai_model.py | 26 - .../__base/large_language_model.py | 33 - 
.../__base/moderation_model.py | 3 - .../model_providers/__base/rerank_model.py | 1 - .../__base/speech2text_model.py | 2 - .../__base/text_embedding_model.py | 9 - .../__base/tokenizers/gpt2_tokenzier.py | 2 - .../model_providers/__base/tts_model.py | 3 - .../model_providers/model_provider_factory.py | 50 -- .../schema_validators/common_validator.py | 13 - .../model_credential_schema_validator.py | 5 - .../provider_credential_schema_validator.py | 2 - api/core/model_runtime/utils/encoders.py | 5 - api/core/moderation/api/api.py | 11 - api/core/moderation/base.py | 10 - api/core/moderation/factory.py | 3 - api/core/moderation/input_moderation.py | 7 - api/core/moderation/keywords/keywords.py | 12 - .../openai_moderation/openai_moderation.py | 8 - api/core/moderation/output_moderation.py | 18 - api/core/plugin/backwards_invocation/app.py | 16 - .../plugin/backwards_invocation/encrypt.py | 1 - api/core/plugin/backwards_invocation/model.py | 25 - api/core/plugin/backwards_invocation/node.py | 4 - api/core/plugin/backwards_invocation/tool.py | 2 - api/core/plugin/entities/parameters.py | 5 - api/core/plugin/entities/plugin.py | 1 - api/core/plugin/entities/plugin_daemon.py | 1 - api/core/plugin/entities/request.py | 4 - api/core/plugin/impl/agent.py | 11 - api/core/plugin/impl/base.py | 11 - api/core/plugin/impl/debugging.py | 1 - api/core/plugin/impl/dynamic_select.py | 2 - api/core/plugin/impl/endpoint.py | 1 - api/core/plugin/impl/model.py | 25 - api/core/plugin/impl/oauth.py | 8 - api/core/plugin/impl/plugin.py | 3 - api/core/plugin/impl/tool.py | 19 - api/core/provider_manager.py | 122 --- api/core/tools/__base/tool.py | 15 - api/core/tools/__base/tool_provider.py | 16 - api/core/tools/builtin_tool/provider.py | 17 - .../builtin_tool/providers/_positions.py | 1 - .../builtin_tool/providers/audio/tools/asr.py | 2 - .../builtin_tool/providers/audio/tools/tts.py | 3 - .../providers/code/tools/simple_code.py | 3 - .../providers/time/tools/current_time.py | 1 - 
.../time/tools/localtime_to_timestamp.py | 2 - .../time/tools/timestamp_to_localtime.py | 3 - .../time/tools/timezone_conversion.py | 1 - .../providers/time/tools/weekday.py | 3 - .../providers/webscraper/tools/webscraper.py | 2 - api/core/tools/builtin_tool/tool.py | 13 - api/core/tools/custom_tool/provider.py | 13 - api/core/tools/custom_tool/tool.py | 27 - api/core/tools/entities/api_entities.py | 1 - api/core/tools/entities/file_entities.py | 1 - api/core/tools/entities/tool_entities.py | 9 - api/core/tools/entities/values.py | 2 - api/core/tools/plugin_tool/provider.py | 3 - api/core/tools/plugin_tool/tool.py | 5 - api/core/tools/signature.py | 4 - api/core/tools/tool_engine.py | 25 - api/core/tools/tool_file_manager.py | 29 - api/core/tools/tool_label_manager.py | 14 - api/core/tools/tool_manager.py | 84 --- api/core/tools/utils/configuration.py | 29 - .../dataset_multi_retriever_tool.py | 13 - .../dataset_retriever_base_tool.py | 1 - .../dataset_retriever_tool.py | 6 - .../tools/utils/dataset_retriever_tool.py | 6 - api/core/tools/utils/message_transformer.py | 7 - .../tools/utils/model_invocation_utils.py | 20 - api/core/tools/utils/parser.py | 46 -- api/core/tools/utils/text_processing_utils.py | 2 - api/core/tools/utils/web_reader_tool.py | 13 - .../utils/workflow_configuration_sync.py | 5 - api/core/tools/utils/yaml_utils.py | 1 - api/core/tools/workflow_as_tool/provider.py | 21 - api/core/tools/workflow_as_tool/tool.py | 20 - api/core/variables/segments.py | 1 - api/core/variables/types.py | 4 - api/core/workflow/nodes/agent/agent_node.py | 9 +- api/core/workflow/nodes/tool/tool_node.py | 6 +- api/factories/agent_factory.py | 1 - api/factories/file_factory.py | 39 - api/factories/variable_factory.py | 15 - api/migrations/env.py | 3 +- ...ef91f18_rename_api_provider_description.py | 2 +- ...5ba0e_add_workflow_tool_label_and_tool_.py | 2 +- ...9b_update_appmodelconfig_and_add_table_.py | 2 +- .../053da0c1d756_add_api_tool_privacy.py | 2 +- 
...84c228_remove_tool_id_from_model_invoke.py | 2 +- ...c1af8d_add_dataset_permission_tenant_id.py | 2 +- api/migrations/versions/16830a790f0f_.py | 2 +- ...16fa53d9faec_add_provider_model_support.py | 2 +- ...ab037c40_add_keyworg_table_storage_type.py | 2 +- ...442fc_modify_provider_model_name_length.py | 2 +- ...ae959a_update_tools_original_url_length.py | 2 +- ...3fcf12ba_support_conversation_variables.py | 2 +- ...7ff0dc_add_conversations_dialogue_count.py | 2 +- ...0956-0251a1c768cc_add_tidb_auth_binding.py | 2 +- ...001-a6be81136580_app_and_site_icon_type.py | 2 +- ...ename_workflow__conversation_variables_.py | 2 +- ...d_add_created_by_and_updated_by_to_app_.py | 2 +- ...add_use_icon_as_answer_icon_fields_for_.py | 2 +- ...bb251_add_parent_message_id_to_messages.py | 2 +- ...-6af6a521a53e_update_retrieval_resource.py | 2 +- ...434-33f5fac87f29_external_knowledge_api.py | 2 +- ...ase_max_length_of_builtin_tool_provider.py | 2 +- ...744d88ed6_fix_wrong_service_api_history.py | 2 +- ...a11becb_add_name_and_size_to_tool_files.py | 2 +- ..._10_22_0959-43fa78bc3b7d_add_white_list.py | 2 +- ...c4f75af5e_add_tenant_plugin_permisisons.py | 2 +- ...3f6769a94a3_add_upload_files_source_url.py | 2 +- ...ename_conversation_variables_index_name.py | 2 +- ...ce70a7ca_update_upload_files_source_url.py | 2 +- ...pdate_type_of_custom_disclaimer_to_text.py | 2 +- ...9b_update_workflows_graph_features_and_.py | 2 +- ...832f7_add_created_at_index_for_messages.py | 2 +- ...22_0701-e19037032219_parent_child_index.py | 2 +- ...4fc45278_add_exceptions_count_field_to_.py | 2 +- ...b07f66c737_remove_unused_tool_providers.py | 2 +- ...dd_retry_index_field_to_node_execution_.py | 2 +- ..._remove_workflow_node_executions_retry_.py | 2 +- ...52d42eb6_add_auto_disabled_dataset_logs.py | 2 +- ...e_change_workflow_runs_total_tokens_to_.py | 2 +- ...4_0617-f051706725cc_add_rate_limit_logs.py | 2 +- ...0917-d20049ed0af6_add_metadata_function.py | 2 +- 
...413929e1ec2_extend_provider_name_column.py | 2 +- ..._add_marked_name_and_marked_comment_in_.py | 2 +- ...315-5511c782ee4c_extend_provider_column.py | 2 +- ..._change_documentsegment_and_childchunk_.py | 2 +- ...72_add_index_for_workflow_conversation_.py | 2 +- ...e1f5dfb_add_workflowdraftvariable_model.py | 2 +- ...w_draft_varaibles_add_node_execution_id.py | 2 +- ...a_remove_sequence_number_from_workflow_.py | 2 +- ...9d_add_message_files_into_agent_thought.py | 2 +- .../246ba09cbbdb_add_app_anntation_setting.py | 2 +- .../versions/2a3aebbbf4bb_add_app_tracing.py | 2 +- .../2beac44e5f5f_add_is_universal_in_apps.py | 2 +- .../2c8af9671032_add_qa_document_language.py | 2 +- ...2e9819ca5b28_add_tenant_id_in_api_token.py | 2 +- ...a5a70d_add_tool_labels_to_agent_thought.py | 2 +- .../3b18fea55204_add_tool_label_bings.py | 2 +- ...3c7cac9521c6_add_tags_and_binding_table.py | 2 +- .../3ef9b2b6bee6_add_assistant_app.py | 2 +- .../408176b91ad3_add_max_active_requests.py | 2 +- ...5564d_conversation_columns_set_nullable.py | 2 +- ...76cc39132_add_annotation_histoiry_score.py | 2 +- ...f8c4f3_modify_default_model_name_length.py | 2 +- .../versions/4823da1d26cf_add_tool_file.py | 2 +- ...fee_change_message_chain_id_to_nullable.py | 2 +- ...d64aa4_update_dataset_model_field_null_.py | 2 +- .../4e99a8df00ff_add_load_balancing.py | 2 +- .../4ff534e1eb11_add_workflow_to_site.py | 2 +- ...022897aaceb_add_model_name_in_embedding.py | 2 +- .../versions/53bf8af60645_update_model.py | 2 +- ...nable_tool_file_without_conversation_id.py | 2 +- .../5fda94355fce_custom_disclaimer.py | 2 +- .../614f77cecc48_add_last_active_at.py | 2 +- .../versions/63f9175e515b_merge_branches.py | 2 +- .../64a70a7aab8b_add_workflow_run_index.py | 2 +- api/migrations/versions/64b051264f32_init.py | 2 +- ...21501b_add_node_execution_id_into_node_.py | 2 +- ...43972bdc_add_dataset_retriever_resource.py | 2 +- ...fb077b04_add_dataset_collection_binding.py | 2 +- 
...5b_add_embedding_cache_created_at_index.py | 2 +- ...39_add_anntation_history_match_response.py | 2 +- ...3755c_add_app_config_retriever_resource.py | 2 +- .../7b45942e39bb_add_api_key_auth_binding.py | 2 +- .../7bdef072e63a_add_workflow_tool.py | 2 +- .../7ce5a52e4eee_add_tool_providers.py | 2 +- ...a8693e07a_add_table_dataset_permissions.py | 2 +- .../853f9b9cd3b6_add_message_price_unit.py | 2 +- ...8072f0caa04_add_custom_config_in_tenant.py | 2 +- api/migrations/versions/89c7899ca936_.py | 2 +- ...daa_add_tool_conversation_variables_idx.py | 2 +- .../8d2d099ceb74_add_qa_model_support.py | 2 +- ...e_add_environment_variable_to_workflow_.py | 2 +- ...6f3c800_rename_api_provider_credentails.py | 2 +- .../8fe468ba0ca5_add_gpt4v_supports.py | 2 +- .../968fff4c0ab9_add_api_based_extension.py | 2 +- .../9e98fbaffb88_add_workflow_tool_version.py | 2 +- .../9f4e3427ea84_add_created_by_role.py | 2 +- ...fafbd60eca1_add_message_file_belongs_to.py | 2 +- ...4dfde53b_add_language_to_recommend_apps.py | 2 +- ...56fb053ef_app_config_add_speech_to_text.py | 2 +- ...d7385a7b66_add_embeddings_provider_name.py | 2 +- .../a8f9b3c45e4a_add_tenant_id_db_index.py | 2 +- ...e_add_external_data_tools_in_app_model_.py | 2 +- ...dd_dataset_query_variable_at_app_model_.py | 2 +- .../ad472b61a054_add_api_provider_icon.py | 2 +- api/migrations/versions/b24be59fbb04_.py | 2 +- ...6_add_workflow_run_id_index_for_message.py | 2 +- .../versions/b289e2408ee2_add_workflow.py | 2 +- ...09c049e8e_add_advanced_prompt_templates.py | 2 +- ...29b71023c_messages_columns_set_nullable.py | 2 +- .../b69ca54b9208_add_chatbot_color_theme.py | 2 +- .../bf0aec5ba2cf_add_provider_order.py | 2 +- ...9_remove_app_model_config_trace_config_.py | 2 +- .../versions/c3311b089690_add_tool_meta.py | 2 +- .../c71211c8f604_add_tool_invoke_model_log.py | 2 +- ...998d4d_set_model_config_column_nullable.py | 2 +- ...3a3471c_add_is_deleted_to_conversations.py | 2 +- .../de95f5c77138_migration_serpapi_api_key.py | 2 +- 
.../versions/dfb3b7f477da_add_tool_index.py | 2 +- .../e1901f623fd0_add_annotation_reply.py | 2 +- .../e2eacc9a1b63_add_status_for_message.py | 2 +- ...08af0a69ccefbb59fa80c778efee300bb780980.py | 2 +- ...ed59becda_modify_quota_limit_field_type.py | 2 +- .../e8883b0148c9_add_dataset_model_name.py | 2 +- ...e349e6ac_increase_max_model_name_length.py | 2 +- .../f25003750af4_add_created_updated_at.py | 2 +- ...85e260_add_anntation_history_message_id.py | 2 +- .../f9107f83abab_add_desc_for_apps.py | 2 +- ...fca025d3b60f_add_dataset_retrival_model.py | 2 +- ..._remove_extra_tracing_app_config_table .py | 2 +- api/models/_workflow_exc.py | 1 - api/models/account.py | 8 - api/models/api_based_extension.py | 1 - api/models/dataset.py | 42 -- api/models/engine.py | 2 - api/models/model.py | 87 --- api/models/provider.py | 11 - api/models/source.py | 2 - api/models/task.py | 2 - api/models/tools.py | 13 - api/models/types.py | 1 - api/models/web.py | 2 - api/models/workflow.py | 101 --- api/pyproject.toml | 12 + api/schedule/clean_embedding_cache_task.py | 1 - api/schedule/clean_messages.py | 1 - api/schedule/clean_unused_datasets_task.py | 11 - api/schedule/create_tidb_serverless_task.py | 2 - .../mail_clean_document_notify_task.py | 6 - api/schedule/queue_monitor_task.py | 5 - .../update_tidb_serverless_status_task.py | 2 - api/services/account_service.py | 116 --- .../advanced_prompt_template_service.py | 5 - api/services/agent_service.py | 15 - api/services/annotation_service.py | 37 - api/services/api_based_extension_service.py | 14 - api/services/app_dsl_service.py | 55 -- api/services/app_generate_service.py | 4 - api/services/app_service.py | 39 - api/services/audio_service.py | 17 - api/services/auth/api_key_auth_service.py | 1 - api/services/auth/firecrawl/firecrawl.py | 1 - api/services/auth/jina.py | 1 - api/services/auth/jina/jina.py | 1 - api/services/auth/watercrawl/watercrawl.py | 1 - api/services/billing_service.py | 12 - 
.../clear_free_plan_tenant_expired_logs.py | 36 - api/services/conversation_service.py | 22 - api/services/dataset_service.py | 191 ----- api/services/enterprise/base.py | 1 - api/services/enterprise/enterprise_service.py | 8 - .../entities/model_provider_entities.py | 8 - api/services/external_knowledge_service.py | 23 - api/services/feature_service.py | 36 - api/services/file_service.py | 35 - api/services/hit_testing_service.py | 15 - api/services/message_service.py | 43 -- api/services/metadata_service.py | 4 - api/services/model_load_balancing_service.py | 71 -- api/services/model_provider_service.py | 60 -- api/services/operation_service.py | 2 - api/services/ops_service.py | 24 - api/services/plugin/data_migration.py | 23 - api/services/plugin/dependencies_analysis.py | 7 - api/services/plugin/oauth_service.py | 2 - api/services/plugin/plugin_migration.py | 54 -- .../plugin/plugin_parameter_service.py | 6 - .../plugin/plugin_permission_service.py | 2 - api/services/plugin/plugin_service.py | 28 - .../buildin/buildin_retrieval.py | 2 - .../database/database_retrieval.py | 9 - .../recommend_app/remote/remote_retrieval.py | 3 - api/services/recommended_app_service.py | 1 - api/services/saved_message_service.py | 7 - api/services/tag_service.py | 1 - .../tools/api_tools_manage_service.py | 57 -- .../tools/builtin_tools_manage_service.py | 34 - api/services/tools/tools_manage_service.py | 4 - api/services/tools/tools_transform_service.py | 20 - .../tools/workflow_tools_manage_service.py | 36 - api/services/vector_service.py | 6 - api/services/web_conversation_service.py | 8 - api/services/webapp_auth_service.py | 17 - api/services/website_service.py | 5 - api/services/workflow/workflow_converter.py | 60 -- api/services/workflow_app_service.py | 14 - .../workflow_draft_variable_service.py | 30 - api/services/workflow_run_service.py | 19 - api/services/workflow_service.py | 66 -- api/services/workspace_service.py | 5 - .../core/workflow/nodes/tool/__init__.py | 1 - 
.../nodes/variable_assigner/v2/__init__.py | 1 - .../entities/advanced_prompt_entities.py | 2 - dev/pytest/pytest_config_tests.py | 10 +- docker/docker-compose.middleware.yaml | 2 +- sdks/python-client/dify_client/__init__.py | 2 +- sdks/python-client/dify_client/client.py | 64 +- sdks/python-client/setup.py | 4 +- sdks/python-client/tests/test_client.py | 42 +- ...est_provider_update_deadlock_prevention.py | 52 +- 524 files changed, 713 insertions(+), 6130 deletions(-) create mode 100644 CHANGELOG.md create mode 100644 api/.gitignore diff --git a/.editorconfig b/.editorconfig index 374da0b5d2..7014cc5b45 100644 --- a/.editorconfig +++ b/.editorconfig @@ -13,6 +13,7 @@ trim_trailing_whitespace = true [*.py] indent_size = 4 indent_style = space +insert_final_newline = true [*.{yml,yaml}] indent_style = space diff --git a/.gitignore b/.gitignore index 709c157b7d..da6b5e2d24 100644 --- a/.gitignore +++ b/.gitignore @@ -219,4 +219,3 @@ api/.env.backup # custom untracked files venv312/ web/.env.local.save -core/ diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000000..aec96dda5e --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,71 @@ +# Changelog + +## [3.3.0](https://github.com/editorconfig-checker/editorconfig-checker/compare/v3.2.1...v3.3.0) (2025-05-07) + + +### Features + +* add `.jj` (Jujutsu) directory to default exclude list ([#458](https://github.com/editorconfig-checker/editorconfig-checker/issues/458)) ([ac903a0](https://github.com/editorconfig-checker/editorconfig-checker/commit/ac903a0a7f5506a80b3c5d2e76584b5e277b896a)) +* update default paths to exclude ([#462](https://github.com/editorconfig-checker/editorconfig-checker/issues/462)) ([84c5c55](https://github.com/editorconfig-checker/editorconfig-checker/commit/84c5c5579e96a9601f1b0ce51fec66257ceb0b24)) + + +### Bug Fixes + +* skip correct number of errors when consolidating errors ([#464](https://github.com/editorconfig-checker/editorconfig-checker/issues/464)) 
([8c695f5](https://github.com/editorconfig-checker/editorconfig-checker/commit/8c695f5ef82063d657796dfc0b58e35b022d4b93)) + +## [3.2.1](https://github.com/editorconfig-checker/editorconfig-checker/compare/v3.2.0...v3.2.1) (2025-03-15) + + +### Bug Fixes + +* check for exclusion before MIME type ([#447](https://github.com/editorconfig-checker/editorconfig-checker/issues/447)) ([cd9976b](https://github.com/editorconfig-checker/editorconfig-checker/commit/cd9976ba25738a02a2130a7fc5e729ed9d6b7251)) +* empty format in the config file should be treated as Default ([#448](https://github.com/editorconfig-checker/editorconfig-checker/issues/448)) ([f8799d0](https://github.com/editorconfig-checker/editorconfig-checker/commit/f8799d0915e6c7a3c82941c14b5bafcf472283cf)), closes [#430](https://github.com/editorconfig-checker/editorconfig-checker/issues/430) +* **test:** make TestGetRelativePath work under Darwin ([#445](https://github.com/editorconfig-checker/editorconfig-checker/issues/445)) ([d956561](https://github.com/editorconfig-checker/editorconfig-checker/commit/d95656138c991c47847015902c75f46aeccb8d06)) +* **test:** support running our test suite under `-trimpath`, closes [#397](https://github.com/editorconfig-checker/editorconfig-checker/issues/397) ([#439](https://github.com/editorconfig-checker/editorconfig-checker/issues/439)) ([fc78406](https://github.com/editorconfig-checker/editorconfig-checker/commit/fc78406ae4d64dc63256c5b37db61b770bf5e436)) +* **test:** we no longer need -ldflags at all ([#444](https://github.com/editorconfig-checker/editorconfig-checker/issues/444)) ([9ffcae2](https://github.com/editorconfig-checker/editorconfig-checker/commit/9ffcae2b7d984c6bf48fde83aaf55ab8962a927a)) + +## [3.2.0](https://github.com/editorconfig-checker/editorconfig-checker/compare/v3.1.2...v3.2.0) (2025-01-25) + + +### Features + +* add support for env var NO_COLOR ([#429](https://github.com/editorconfig-checker/editorconfig-checker/issues/429)) 
([9135f53](https://github.com/editorconfig-checker/editorconfig-checker/commit/9135f531e762ad4c02f4bf45f03888771773da56)) +* only output "0 errors found" when verbose output is enabled ([#423](https://github.com/editorconfig-checker/editorconfig-checker/issues/423)) ([1d29a8b](https://github.com/editorconfig-checker/editorconfig-checker/commit/1d29a8b16b4cde8d46f80db29e60330c5bd16095)) + + +### Bug Fixes + +* improve default excludes ([#427](https://github.com/editorconfig-checker/editorconfig-checker/issues/427)) ([d0cbd25](https://github.com/editorconfig-checker/editorconfig-checker/commit/d0cbd250caa46a07994b6161ccf2bb4910571a23)) + +## [3.1.2](https://github.com/editorconfig-checker/editorconfig-checker/compare/v3.1.1...v3.1.2) (2025-01-10) + + +### Bug Fixes + +* provide both .tar.gz as well as .zip archives ([#416](https://github.com/editorconfig-checker/editorconfig-checker/issues/416)) ([00e9890](https://github.com/editorconfig-checker/editorconfig-checker/commit/00e9890847982b2503ec3a11ff539bf2ac4c34c6)), closes [#415](https://github.com/editorconfig-checker/editorconfig-checker/issues/415) + +## [3.1.1](https://github.com/editorconfig-checker/editorconfig-checker/compare/v3.1.0...v3.1.1) (2025-01-08) + + +### Bug Fixes + +* dockerfile expected binary at /, not /usr/bin/ [#410](https://github.com/editorconfig-checker/editorconfig-checker/issues/410) ([#411](https://github.com/editorconfig-checker/editorconfig-checker/issues/411)) ([2c82197](https://github.com/editorconfig-checker/editorconfig-checker/commit/2c821979c0b3ea291f65ec813cae3fa265603528)) + +## [3.1.0](https://github.com/editorconfig-checker/editorconfig-checker/compare/v3.0.3...v3.1.0) (2025-01-06) + + +### Features + +* add zip version when compressing all binaries ([#321](https://github.com/editorconfig-checker/editorconfig-checker/issues/321)) ([#362](https://github.com/editorconfig-checker/editorconfig-checker/issues/362)) 
([f1bb625](https://github.com/editorconfig-checker/editorconfig-checker/commit/f1bb625f2553952d4d8c72e3f97d17417f0c1ef7)) +* consolidate adjacent error messages ([#360](https://github.com/editorconfig-checker/editorconfig-checker/issues/360)) ([cf4ae1c](https://github.com/editorconfig-checker/editorconfig-checker/commit/cf4ae1ccede331b2aa1b115f1de5257737de7eef)) +* editorconfig-checker-disable-next-line ([#363](https://github.com/editorconfig-checker/editorconfig-checker/issues/363)) ([6116ec6](https://github.com/editorconfig-checker/editorconfig-checker/commit/6116ec6685b33652e9e25def9b8897ed4b015c7d)) +* provide Codeclimate compatible report fromat ([#367](https://github.com/editorconfig-checker/editorconfig-checker/issues/367)) ([282c315](https://github.com/editorconfig-checker/editorconfig-checker/commit/282c315bd1c48f49cc1328de36e2ba4433c50249)) +* support `.editorconfig-checker.json` config ([#375](https://github.com/editorconfig-checker/editorconfig-checker/issues/375)) ([cb0039c](https://github.com/editorconfig-checker/editorconfig-checker/commit/cb0039cfe68a11139011bcffe84b8ff62b3209bb)) + + +### Bug Fixes + +* actually use the correct end marker ([#405](https://github.com/editorconfig-checker/editorconfig-checker/issues/405)) ([3c03499](https://github.com/editorconfig-checker/editorconfig-checker/commit/3c034994cba21db7babd33672a0d26184ff88255)) +* add `.ecrc` deprecation warning ([#389](https://github.com/editorconfig-checker/editorconfig-checker/issues/389)) ([d33b81c](https://github.com/editorconfig-checker/editorconfig-checker/commit/d33b81cc71c2eb740dd3e1c00f07dbc430b89087)) +* this release-please marker ([#403](https://github.com/editorconfig-checker/editorconfig-checker/issues/403)) ([617c6d4](https://github.com/editorconfig-checker/editorconfig-checker/commit/617c6d44b5a8668de16bf67038dd5930e01c074e)) +* typo in config, `SpacesAftertabs` => `SpacesAfterTabs` ([#386](https://github.com/editorconfig-checker/editorconfig-checker/issues/386)) 
([25e3542](https://github.com/editorconfig-checker/editorconfig-checker/commit/25e3542ee45b0bd5cbdd450ba8eebee6ad3bba43)) diff --git a/LICENSE b/LICENSE index 329ee30287..5420ac57f6 100644 --- a/LICENSE +++ b/LICENSE @@ -1,22 +1,21 @@ -# Open Source License - -Dify is licensed under a modified version of the Apache License 2.0, with the following additional conditions: - -1. Dify may be utilized commercially, including as a backend service for other applications or as an application development platform for enterprises. Should the conditions below be met, a commercial license must be obtained from the producer: - -a. Multi-tenant service: Unless explicitly authorized by Dify in writing, you may not use the Dify source code to operate a multi-tenant environment. - - Tenant Definition: Within the context of Dify, one tenant corresponds to one workspace. The workspace provides a separated area for each tenant's data and configurations. - -b. LOGO and copyright information: In the process of using Dify's frontend, you may not remove or modify the LOGO or copyright information in the Dify console or applications. This restriction is inapplicable to uses of Dify that do not involve its frontend. - - Frontend Definition: For the purposes of this license, the "frontend" of Dify includes all components located in the `web/` directory when running Dify from the raw source code, or the "web" image when running Dify with Docker. - -2. As a contributor, you should agree that: - -a. The producer can adjust the open-source agreement to be more strict or relaxed as deemed necessary. -b. Your contributed code may be used for commercial purposes, including but not limited to its cloud business operations. - -Apart from the specific conditions mentioned above, all other rights and restrictions follow the Apache License 2.0. Detailed information about the Apache License 2.0 can be found at http://www.apache.org/licenses/LICENSE-2.0. 
- -The interactive design of this product is protected by appearance patent. - -© 2025 LangGenius, Inc. +The MIT License (MIT) + +Copyright (c) 2018 Max Strübing + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md index 1dc7e2dd98..88f04e8ff5 100644 --- a/README.md +++ b/README.md @@ -1,268 +1,475 @@ -![cover-v5-optimized](./images/GitHub_README_if.png) - -

- 📌 Introducing Dify Workflow File Upload: Recreate Google NotebookLM Podcast -

- -

- Dify Cloud · - Self-hosting · - Documentation · - Dify edition overview -

- -

- - Static Badge - - Static Badge - - chat on Discord - - join Reddit - - follow on X(Twitter) - - follow on LinkedIn - - Docker Pulls - - Commits last month - - Issues closed - - Discussion posts -

- -

- README in English - 繁體中文文件 - 简体中文版自述文件 - 日本語のREADME - README en Español - README en Français - README tlhIngan Hol - README in Korean - README بالعربية - Türkçe README - README Tiếng Việt - README in Deutsch - README in বাংলা -

- -Dify is an open-source LLM app development platform. Its intuitive interface combines agentic AI workflow, RAG pipeline, agent capabilities, model management, observability features, and more, allowing you to quickly move from prototype to production. - -## Quick start - -> Before installing Dify, make sure your machine meets the following minimum system requirements: -> -> - CPU >= 2 Core -> - RAM >= 4 GiB - -
- -The easiest way to start the Dify server is through [docker compose](docker/docker-compose.yaml). Before running Dify with the following commands, make sure that [Docker](https://docs.docker.com/get-docker/) and [Docker Compose](https://docs.docker.com/compose/install/) are installed on your machine: - -```bash -cd dify -cd docker -cp .env.example .env -docker compose up -d +# editorconfig-checker + +Buy Me A Coffee + +[![ci](https://github.com/editorconfig-checker/editorconfig-checker/actions/workflows/ci.yml/badge.svg)](https://github.com/editorconfig-checker/editorconfig-checker/actions/workflows/ci.yml) +[![codecov](https://codecov.io/gh/editorconfig-checker/editorconfig-checker/branch/main/graph/badge.svg)](https://codecov.io/gh/editorconfig-checker/editorconfig-checker) +[![Hits-of-Code](https://hitsofcode.com/github/editorconfig-checker/editorconfig-checker?branch=main&label=Hits-of-Code)](https://hitsofcode.com/github/editorconfig-checker/editorconfig-checker/view?branch=main&label=Hits-of-Code) +[![Go Report Card](https://goreportcard.com/badge/github.com/editorconfig-checker/editorconfig-checker/v3)](https://goreportcard.com/report/github.com/editorconfig-checker/editorconfig-checker/v3) + +![Logo](docs/logo.png) + +1. [What?](#what) +2. [Quickstart](#quickstart) +3. [Installation](#installation) +4. [Usage](#usage) +5. [Configuration](#configuration) +6. [Excluding](#excluding) + 1. [Excluding Lines](#excluding-lines) + 2. [Excluding Blocks](#excluding-blocks) + 3. [Excluding Paths](#excluding-paths) + 1. [Inline](#inline) + 2. [Default Excludes](#default-excludes) + 3. [Ignoring Default Excludes](#ignoring-default-excludes) + 4. [Manually Excluding](#manually-excluding) + 1. [via configuration](#via-configuration) + 2. [via arguments](#via-arguments) +7. [Docker](#docker) +8. [Continuous Integration](#continuous-integration) +9. [Support](#support) +10. [Contributing](#contributing) +11. 
[Semantic Versioning Policy](#semantic-versioning-policy) + +## What? + +![Example Screenshot](docs/screenshot.png) + +This is a tool to check if your files consider your `.editorconfig` rules. +Most tools—like linters, for example—only test one filetype and need an extra configuration. +This tool only needs your `.editorconfig` to check all files. + +If you don't know about editorconfig already you can read about it here: [editorconfig.org](https://editorconfig.org/). + +Currently, implemented editorconfig features are: + +- `end_of_line` +- `insert_final_newline` +- `trim_trailing_whitespace` +- `indent_style` +- `indent_size` +- `max_line_length` + +Unsupported features are: + +- `charset` + +## Quickstart + + +```shell +VERSION="v3.3.0" +OS="linux" +ARCH="amd64" +curl -O -L -C - https://github.com/editorconfig-checker/editorconfig-checker/releases/download/$VERSION/ec-$OS-$ARCH.tar.gz && \ +tar xzf ec-$OS-$ARCH.tar.gz && \ +./bin/ec-$OS-$ARCH ``` + -After running, you can access the Dify dashboard in your browser at [http://localhost/install](http://localhost/install) and start the initialization process. - -#### Seeking help - -Please refer to our [FAQ](https://docs.dify.ai/getting-started/install-self-hosted/faqs) if you encounter problems setting up Dify. Reach out to [the community and us](#community--contact) if you are still having issues. - -> If you'd like to contribute to Dify or do additional development, refer to our [guide to deploying from source code](https://docs.dify.ai/getting-started/install-self-hosted/local-source-code) - -## Key features - -**1. Workflow**: -Build and test powerful AI workflows on a visual canvas, leveraging all the following features and beyond. - -**2. Comprehensive model support**: -Seamless integration with hundreds of proprietary / open-source LLMs from dozens of inference providers and self-hosted solutions, covering GPT, Mistral, Llama3, and any OpenAI API-compatible models. 
A full list of supported model providers can be found [here](https://docs.dify.ai/getting-started/readme/model-providers). - -![providers-v5](https://github.com/langgenius/dify/assets/13230914/5a17bdbe-097a-4100-8363-40255b70f6e3) - -**3. Prompt IDE**: -Intuitive interface for crafting prompts, comparing model performance, and adding additional features such as text-to-speech to a chat-based app. - -**4. RAG Pipeline**: -Extensive RAG capabilities that cover everything from document ingestion to retrieval, with out-of-box support for text extraction from PDFs, PPTs, and other common document formats. - -**5. Agent capabilities**: -You can define agents based on LLM Function Calling or ReAct, and add pre-built or custom tools for the agent. Dify provides 50+ built-in tools for AI agents, such as Google Search, DALL·E, Stable Diffusion and WolframAlpha. - -**6. LLMOps**: -Monitor and analyze application logs and performance over time. You could continuously improve prompts, datasets, and models based on production data and annotations. - -**7. Backend-as-a-Service**: -All of Dify's offerings come with corresponding APIs, so you could effortlessly integrate Dify into your own business logic. - -## Feature Comparison - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FeatureDify.AILangChainFlowiseOpenAI Assistants API
Programming ApproachAPI + App-orientedPython CodeApp-orientedAPI-oriented
Supported LLMsRich VarietyRich VarietyRich VarietyOpenAI-only
RAG Engine
Agent
Workflow
Observability
Enterprise Feature (SSO/Access control)
Local Deployment
+## Installation -## Using Dify +Grab a binary from the [release page](https://github.com/editorconfig-checker/editorconfig-checker/releases). -- **Cloud
** - We host a [Dify Cloud](https://dify.ai) service for anyone to try with zero setup. It provides all the capabilities of the self-deployed version, and includes 200 free GPT-4 calls in the sandbox plan. +If you have go installed you can run `go get github.com/editorconfig-checker/editorconfig-checker/v3` +and run `make build` inside the project folder. +This will place a binary called `ec` into the `bin` directory. -- **Self-hosting Dify Community Edition
** - Quickly get Dify running in your environment with this [starter guide](#quick-start). - Use our [documentation](https://docs.dify.ai) for further references and more in-depth instructions. - -- **Dify for enterprise / organizations
** - We provide additional enterprise-centric features. [Log your questions for us through this chatbot](https://udify.app/chat/22L1zSxg6yW1cWQg) or [send us an email](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry) to discuss enterprise needs.
- > For startups and small businesses using AWS, check out [Dify Premium on AWS Marketplace](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6) and deploy it to your own AWS VPC with one click. It's an affordable AMI offering with the option to create apps with custom logo and branding. - -## Staying ahead - -Star Dify on GitHub and be instantly notified of new releases. - -![star-us](https://github.com/langgenius/dify/assets/13230914/b823edc1-6388-4e25-ad45-2f6b187adbb4) - -## Advanced Setup - -If you need to customize the configuration, please refer to the comments in our [.env.example](docker/.env.example) file and update the corresponding values in your `.env` file. Additionally, you might need to make adjustments to the `docker-compose.yaml` file itself, such as changing image versions, port mappings, or volume mounts, based on your specific deployment environment and requirements. After making any changes, please re-run `docker-compose up -d`. You can find the full list of available environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments). - -If you'd like to configure a highly-available setup, there are community-contributed [Helm Charts](https://helm.sh/) and YAML files which allow Dify to be deployed on Kubernetes. 
- -- [Helm Chart by @LeoQuote](https://github.com/douban/charts/tree/master/charts/dify) -- [Helm Chart by @BorisPolonsky](https://github.com/BorisPolonsky/dify-helm) -- [Helm Chart by @magicsong](https://github.com/magicsong/ai-charts) -- [YAML file by @Winson-030](https://github.com/Winson-030/dify-kubernetes) -- [YAML file by @wyy-holding](https://github.com/wyy-holding/dify-k8s) - -#### Using Terraform for Deployment - -Deploy Dify to Cloud Platform with a single click using [terraform](https://www.terraform.io/) - -##### Azure Global - -- [Azure Terraform by @nikawang](https://github.com/nikawang/dify-azure-terraform) - -##### Google Cloud - -- [Google Cloud Terraform by @sotazum](https://github.com/DeNA/dify-google-cloud-terraform) - -#### Using AWS CDK for Deployment - -Deploy Dify to AWS with [CDK](https://aws.amazon.com/cdk/) - -##### AWS - -- [AWS CDK by @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +If you are using Arch Linux, you can use [pacman](https://wiki.archlinux.org/title/Pacman) to install from [extra repository](https://archlinux.org/packages/extra/x86_64/editorconfig-checker/): -#### Using Alibaba Cloud Computing Nest +```shell +pacman -S editorconfig-checker +``` -Quickly deploy Dify to Alibaba cloud with [Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88) +Also, development (VCS) package is available in the [AUR](https://aur.archlinux.org/packages/editorconfig-checker-git): -#### Using Alibaba Cloud Data Management +```shell +# editorconfig-checker-git -One-Click deploy Dify to Alibaba Cloud with [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/) +# i.e. 
+paru -S editorconfig-checker-git +``` +If Go 1.16 or greater is installed, you can also install it globally via `go install`: -## Contributing +```shell +go install github.com/editorconfig-checker/editorconfig-checker/v3/cmd/editorconfig-checker@latest +``` + +## Usage + +```txt +USAGE: + -config string + config + -debug + print debugging information + -disable-end-of-line + disables the trailing whitespace check + -disable-indent-size + disables only the indent-size check + -disable-indentation + disables the indentation check + -disable-insert-final-newline + disables the final newline check + -disable-trim-trailing-whitespace + disables the trailing whitespace check + -dry-run + show which files would be checked + -exclude string + a regex which files should be excluded from checking - needs to be a valid regular expression + -format + specifies the output format, see "Formats" below for more information + -h print the help + -help + print the help + -ignore-defaults + ignore default excludes + -init + creates an initial configuration + -no-color + disables printing color + -color + enables printing color + -v print debugging information + -verbose + print debugging information + -version + print the version number +``` + +If you run this tool from a repository root it will check all files which are added to the git repository and are text files. If the tool isn't able to determine a file type it will be added to be checked too. + +If you run this tool from a normal directory it will check all files which are text files. If the tool isn't able to determine a file type it will be added to be checked too. + +### Formats + +The following output formats are supported: + +- **default**: Plain text, human readable output.
+ ```text + : + -: + ``` +- **gcc**: GCC compatible output. Useful for editors that support compiling and showing syntax errors.
+ `::: : ` +- **github-actions**: The format used by GitHub Actions
+ `::error file=,line=,endLine=::` +- **codeclimate**: The [Code Climate](https://github.com/codeclimate/platform/blob/master/spec/analyzers/SPEC.md#data-types) json format used for [custom quality reports](https://docs.gitlab.com/ee/ci/testing/code_quality.html#implement-a-custom-tool) in GitLab CI + ```json + [ + { + "check_name": "editorconfig-checker", + "description": "Wrong indent style found (tabs instead of spaces)", + "fingerprint": "e87a958a3960d60a11d4b49c563cccd2", + "severity": "minor", + "location": { + "path": ".vscode/extensions.json", + "lines": { + "begin": 2, + "end": 2 + } + } + } + ] + ``` + +## Configuration + +The configuration is done via arguments or it will take the first config file found with the following file names: + +- `.editorconfig-checker.json` +- `.ecrc` (deprecated filename, soon unsupported) + +A sample configuration file can look like this and will be used from your current working directory if not specified via the `--config` argument: + +```json +{ + "Verbose": false, + "Debug": false, + "IgnoreDefaults": false, + "SpacesAfterTabs": false, + "NoColor": false, + "Exclude": [], + "AllowedContentTypes": [], + "PassedFiles": [], + "Disable": { + "EndOfLine": false, + "Indentation": false, + "IndentSize": false, + "InsertFinalNewline": false, + "TrimTrailingWhitespace": false, + "MaxLineLength": false + } +} +``` -For those who'd like to contribute code, see our [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). -At the same time, please consider supporting Dify by sharing it on social media and at events and conferences. +You can set any of the options under the `"Disable"` section to `true` to disable those particular checks. -> We are looking for contributors to help translate Dify into languages other than Mandarin or English. 
If you are interested in helping, please see the [i18n README](https://github.com/langgenius/dify/blob/main/web/i18n/README.md) for more information, and leave us a comment in the `global-users` channel of our [Discord Community Server](https://discord.gg/8Tpq4AcN9c). +You could also specify command line arguments, and they will get merged with the configuration file. The command line arguments have a higher precedence than the configuration. -## Community & contact +You can create a configuration with the `init`-flag. If you specify a `config`-path it will be created there. -- [GitHub Discussion](https://github.com/langgenius/dify/discussions). Best for: sharing feedback and asking questions. -- [GitHub Issues](https://github.com/langgenius/dify/issues). Best for: bugs you encounter using Dify.AI, and feature proposals. See our [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). -- [Discord](https://discord.gg/FngNHpbcY7). Best for: sharing your applications and hanging out with the community. -- [X(Twitter)](https://twitter.com/dify_ai). Best for: sharing your applications and hanging out with the community. +By default, the allowed_content_types are: -**Contributors** +1. `text/` (matches `text/plain`, `text/html`, etc.) +1. `application/ecmascript` +1. `application/json` +1. `application/x-ndjson` +1. `application/xml` +1. `+json` (matches `application/geo+json`, etc.) +1. `+xml` (matches `application/rss+xml`, etc.) +1. `application/octet-stream` - - - +`application/octet-stream` is needed as a fallback when no content type could be determined. You can add additional accepted content types with the `allowed_content_types` key. But the default ones don't get removed. -## Star history +## Excluding + +### Excluding Lines + +You can exclude single lines inline. To do that you need a comment on that line that says: `editorconfig-checker-disable-line`. 
+ +```javascript +const myTemplateString = ` + first line + wrongly indented line because it needs to be` // editorconfig-checker-disable-line +``` + +Alternatively, you can use `editorconfig-checker-disable-next-line` to skip the line that comes after this comment. +This modifier is present to improve readability, or because you sometimes have no other choice because of your own/language constraints. + +```javascript +// editorconfig-checker-disable-next-line used because blah blah blah what ever the reason blah +const myTemplateString = `a line that is (...) longer (...) than ... usual` // or with a very long inline comment +``` -[![Star History Chart](https://api.star-history.com/svg?repos=langgenius/dify&type=Date)](https://star-history.com/#langgenius/dify&Date) +Please note that using `editorconfig-checker-disable-next-line` has only an effect on the next line, so it will report if the line where you added the modifier doesn't comply. -## Security disclosure +### Excluding Blocks -To protect your privacy, please avoid posting security issues on GitHub. Instead, send your questions to security@dify.ai and we will provide you with a more detailed answer. +To temporarily disable all checks, add a comment containing `editorconfig-checker-disable`. Re-enable with a comment containing `editorconfig-checker-enable` -## License +```javascript +// editorconfig-checker-disable +const myTemplateString = ` + first line + wrongly indented line because it needs to be +` +// editorconfig-checker-enable +``` + +### Excluding Paths + +You can exclude paths from being checked in several ways: + +- ignoring a file by documenting it inside the to-be-excluded file +- adding a regex matching the path to the [configuration file](#configuration) +- passing a regex matching the path as argument to `--exclude` + +All these excludes are used in addition to the [default excludes](#default-excludes), unless you [opt out of them](#ignoring-default-excludes).
+ +If you want to see which files would be checked without checking them you can pass the `--dry-run` flag. + +Note that while `--dry-run` might output absolute paths, the regular expression you write must match the filenames using relative paths from where editorconfig-checker is used. This becomes especially relevant if you need to anchor your regular expression in order to only match files in the top level of your checked directory. + +Additionally, paths will be normalized to Unix style before matching against the regex list happens. As a result you don't have to write `[\\/]` to account for Windows and Unix path styles but can just use `/` instead. + +#### Inline + +If you want to exclude a file inline you need a comment on the first line of the file that contains: `editorconfig-checker-disable-file` + +```haskell +-- editorconfig-checker-disable-file +add :: Int -> Int -> Int +add x y = + let result = x + y -- falsy indentation would not report + in result -- falsy indentation would not report +``` + +#### Default Excludes + +Unless you choose to [ignore them](#ignoring-default-excludes), these paths are excluded automatically: + +```txt +// source control related files and folders +"\\.git/", +"\\.jj/", +// package manager, generated, & lock files +// Cargo (Rust) +"Cargo\\.lock$", +// Composer (PHP) +"composer\\.lock$", +// RubyGems (Ruby) +"Gemfile\\.lock$", +// Go Modules (Go) +"go\\.(mod|sum)$", +// Gradle (Java) +"gradle/wrapper/gradle-wrapper\\.properties$", +"gradlew(\\.bat)?$", +"(buildscript-)?gradle\\.lockfile?$", +// Maven (Java) +"\\.mvn/wrapper/maven-wrapper\\.properties$", +"\\.mvn/wrapper/MavenWrapperDownloader\\.java$", +"mvnw(\\.cmd)?$", +// NodeJS +"/node_modules/", +// npm (NodeJS) +"npm-shrinkwrap\\.json$", +"package-lock\\.json$", +// pip (Python) +"Pipfile\\.lock$", +// Poetry (Python) +"poetry\\.lock$", +// pnpm (NodeJS) +"pnpm-lock\\.yaml$", +// Terraform & OpenTofu +"\\.terraform\\.lock\\.hcl$", +// uv (Python) +"uv\\.lock$", +// yarn 
(NodeJS) +"\\.pnp\\.c?js$", +"\\.pnp\\.loader\\.mjs$", +"\\.yarn/", +"yarn\\.lock$", +// font files +"\\.eot$", +"\\.otf$", +"\\.ttf$", +"\\.woff2?$", +// image & video formats +"\\.avif$", +"\\.gif$", +"\\.ico$", +"\\.jpe?g$", +"\\.mp4$", +"\\.p[bgnp]m$", +"\\.png$", +"\\.svg$", +"\\.tiff?$", +"\\.webp$", +"\\.wmv$", +// other binary or container formats +"\\.bak$", +"\\.bin$", +"\\.docx?$", +"\\.exe$", +"\\.pdf$", +"\\.snap$", +"\\.xlsx?$", +// archive formats +"\\.7z$", +"\\.bz2$", +"\\.gz$", +"\\.jar$", +"\\.tar$", +"\\.tgz$", +"\\.war$", +"\\.zip$", +// log & (git) patch files +"\\.log$", +"\\.patch$", +// generated or minified CSS and JavaScript files +"\\.(css|js)\\.map$", +"min\\.(css|js)$", +``` + +#### Ignoring Default Excludes + +If you either set `IgnoreDefaults` to `true` or pass the `-ignore-defaults` commandline switch, the [default excludes](#default-excludes) will be ignored entirely. + +#### Manually Excluding + +##### via configuration + +In your [configuration file](#configuration) you can exclude files with the `"exclude"` key which takes an array of regular expressions. +This will get merged with the default excludes (if not [ignored](#ignoring-default-excludes)). You should remember to escape your regular expressions correctly. + +A [configuration file](#configuration) which would ignore all test files and all Markdown files can look like this: + +```json +{ + "Verbose": false, + "IgnoreDefaults": false, + "Exclude": ["testfiles", "\\.md$"], + "SpacesAfterTabs": false, + "Disable": { + "EndOfLine": false, + "Indentation": false, + "IndentSize": false, + "InsertFinalNewline": false, + "TrimTrailingWhitespace": false, + "MaxLineLength": false + } +} +``` + +##### via arguments + +If you want to play around how the tool would behave you can also pass the `--exclude` argument to the binary. This will accept a regular expression as well. 
The argument given will be added to the excludes as defined by your [configuration file](#configuration) (respecting both its [`Exclude`](#via-configuration) and [`IgnoreDefaults`](#ignoring-default-excludes) settings). + +For example: `ec --exclude node_modules` + +## Docker + +You are able to run this tool inside a Docker container. +To do this you need to have Docker installed and run this command in your repository root which you want to check: +`docker run --rm --volume=$PWD:/check mstruebing/editorconfig-checker` + +Docker Hub: [mstruebing/editorconfig-checker](https://hub.docker.com/r/mstruebing/editorconfig-checker) + +## Continuous Integration + +### Mega-Linter + +Instead of installing and configuring `editorconfig-checker` and all other linters in your project CI workflows (GitHub Actions & others), you can use [Mega-Linter](https://megalinter.io/latest/) which does all that for you with a single [assisted installation](https://megalinter.io/latest/install-assisted/). + +Mega-Linter embeds [editorconfig-checker](https://megalinter.io/latest/descriptors/editorconfig_editorconfig_checker/) by default in all its [flavors](https://megalinter.io/latest/flavors/), meaning that it will be run at each commit or Pull Request to detect any issue related to `.editorconfig`. + +If you want to use only `editorconfig-checker` and not the 70+ other linters, you can use the following `.mega-linter.yml` configuration file: + +```yaml +ENABLE: + - EDITORCONFIG +``` + +### GitLab CI + +The [ss-open/ci/recipes project](https://gitlab.com/ss-open/ci/recipes) offers a ready to use lint job integrating editorconfig-checker. + +- Main documentation: +- Editorconfig job specific documentation: + +## Support + +If you have any questions, suggestions, need a wrapper for a programming language or just want to chat join #editorconfig-checker on freenode(IRC). 
+If you don't have an IRC-client set up you can use the [freenode webchat](https://webchat.freenode.net/?channels=editorconfig-checker). + +## Contributing -This repository is available under the [Dify Open Source License](LICENSE), which is essentially Apache 2.0 with a few additional restrictions. +Anyone can help to improve the project, submit a Feature Request, a bug report or even correct a spelling mistake. + +The steps to contribute can be found in the [CONTRIBUTING.md](./CONTRIBUTING.md) file. + +## Semantic Versioning Policy + +**editorconfig-checker** adheres to [Semantic Versioning](https://semver.org/) for releases. + +However, as it is a code quality tool, it's not always clear when a minor or major version bump occurs. The following rules are used to determine the version bump: + +- Patch release (1.0.x -> 1.0.y) + - Updates to output formats (error messages, logs, ...). + - Performance improvements which don't affect behavior. + - Build process changes (e.g., updating dependencies, updating `Dockerfile`, ...). + - Reverts (reverting a previous commit). + - Bug fixes which result in **editorconfig-checker** reporting fewer linting errors (removing "false-positive" linting errors). +- Minor release (1.x.0 -> 1.y.0) + - Adding new [configuration options](#configuration), including new CLI flags. + - Adding new [path to exclude by default](#default-excludes). + - Adding new [output formats](#formats). + - Supporting a new [editorconfig](https://editorconfig.org/) property (e.g: `insert_final_newline`, `indent_size`, ...). + - Any new feature which doesn't break existing behavior. +- Major release (x.0.0 -> y.0.0) + - Removal of a [configuration](#configuration) option. + - Removal of an [output format](#formats). + - Removal of a [path to exclude by default](#default-excludes). + - Removal of support for an [editorconfig](https://editorconfig.org/) property. 
+ - Bug fixes, which result in **editorconfig-checker** reporting more linting errors, because the previous behavior was incorrect according to the [editorconfig specification](https://editorconfig.org/). diff --git a/api/.gitignore b/api/.gitignore new file mode 100644 index 0000000000..8a7ff53960 --- /dev/null +++ b/api/.gitignore @@ -0,0 +1,11 @@ +# venv +venv*/ +.venv/ + +# ide +.idea/ +.vscode/ + +# custom +web/.env.local.save +venv312/ diff --git a/api/app.py b/api/app.py index 4f393f6c20..3149ccb464 100644 --- a/api/app.py +++ b/api/app.py @@ -22,20 +22,16 @@ else: # gevent monkey.patch_all() - from grpc.experimental import gevent as grpc_gevent # type: ignore # grpc gevent grpc_gevent.init_gevent() - import psycogreen.gevent # type: ignore psycogreen.gevent.patch_psycopg() - from app_factory import create_app app = create_app() celery = app.extensions["celery"] - if __name__ == "__main__": app.run(host="0.0.0.0", port=5001) diff --git a/api/app_factory.py b/api/app_factory.py index 3a258be28f..b2ffb7c14a 100644 --- a/api/app_factory.py +++ b/api/app_factory.py @@ -93,7 +93,6 @@ def initialize_extensions(app: DifyApp): if dify_config.DEBUG: logging.info(f"Skipped {short_name}") continue - start_time = time.perf_counter() ext.init_app(app) end_time = time.perf_counter() @@ -108,5 +107,4 @@ def create_migrations_app(): # Initialize only required extensions ext_database.init_app(app) ext_migrate.init_app(app) - return app diff --git a/api/configs/app_config.py b/api/configs/app_config.py index 20f8c40427..04ccd05be2 100644 --- a/api/configs/app_config.py +++ b/api/configs/app_config.py @@ -33,7 +33,6 @@ class RemoteSettingsSourceFactory(PydanticBaseSettingsSource): remote_source_name = current_state.get("REMOTE_SETTINGS_SOURCE_NAME") if not remote_source_name: return {} - remote_source: RemoteSettingsSource | None = None match remote_source_name: case RemoteSettingsSourceName.APOLLO: @@ -43,15 +42,12 @@ class 
RemoteSettingsSourceFactory(PydanticBaseSettingsSource): case _: logger.warning(f"Unsupported remote source: {remote_source_name}") return {} - d: dict[str, Any] = {} - for field_name, field in self.settings_cls.model_fields.items(): field_value, field_key, value_is_complex = remote_source.get_field_value(field, field_name) field_value = remote_source.prepare_field_value(field_name, field, field_value, value_is_complex) if field_value is not None: d[field_key] = field_value - return d @@ -86,7 +82,6 @@ class DifyConfig( # please consider to arrange it in the proper config group of existed or added # for better readability and maintainability. # Thanks for your concentration and consideration. - @classmethod def settings_customise_sources( cls, diff --git a/api/configs/deploy/__init__.py b/api/configs/deploy/__init__.py index 63f4dfba63..904d6c0570 100644 --- a/api/configs/deploy/__init__.py +++ b/api/configs/deploy/__init__.py @@ -11,23 +11,19 @@ class DeploymentConfig(BaseSettings): description="Name of the application, used for identification and logging purposes", default="langgenius/dify", ) - DEBUG: bool = Field( description="Enable debug mode for additional logging and development features", default=False, ) - # Request logging configuration ENABLE_REQUEST_LOGGING: bool = Field( description="Enable request and response body logging", default=False, ) - EDITION: str = Field( description="Deployment edition of the application (e.g., 'SELF_HOSTED', 'CLOUD')", default="SELF_HOSTED", ) - DEPLOY_ENV: str = Field( description="Deployment environment (e.g., 'PRODUCTION', 'DEVELOPMENT'), default to PRODUCTION", default="PRODUCTION", diff --git a/api/configs/enterprise/__init__.py b/api/configs/enterprise/__init__.py index eda6345e14..0a5061ae70 100644 --- a/api/configs/enterprise/__init__.py +++ b/api/configs/enterprise/__init__.py @@ -13,7 +13,6 @@ class EnterpriseFeatureConfig(BaseSettings): "Before using, please contact business@dify.ai by email to inquire about 
licensing matters.", default=False, ) - CAN_REPLACE_LOGO: bool = Field( description="Allow customization of the enterprise logo.", default=False, diff --git a/api/configs/extra/notion_config.py b/api/configs/extra/notion_config.py index f9c4d73463..46a9eeab95 100644 --- a/api/configs/extra/notion_config.py +++ b/api/configs/extra/notion_config.py @@ -13,23 +13,19 @@ class NotionConfig(BaseSettings): description="Client ID for Notion API authentication. Required for OAuth 2.0 flow.", default=None, ) - NOTION_CLIENT_SECRET: Optional[str] = Field( description="Client secret for Notion API authentication. Required for OAuth 2.0 flow.", default=None, ) - NOTION_INTEGRATION_TYPE: Optional[str] = Field( description="Type of Notion integration." " Set to 'internal' for internal integrations, or None for public integrations.", default=None, ) - NOTION_INTERNAL_SECRET: Optional[str] = Field( description="Secret key for internal Notion integrations. Required when NOTION_INTEGRATION_TYPE is 'internal'.", default=None, ) - NOTION_INTEGRATION_TOKEN: Optional[str] = Field( description="Integration token for Notion API access. Used for direct API calls without OAuth flow.", default=None, diff --git a/api/configs/extra/sentry_config.py b/api/configs/extra/sentry_config.py index f76a6bdb95..46cf24bfef 100644 --- a/api/configs/extra/sentry_config.py +++ b/api/configs/extra/sentry_config.py @@ -14,13 +14,11 @@ class SentryConfig(BaseSettings): " This is the unique identifier of your Sentry project, used to send events to the correct project.", default=None, ) - SENTRY_TRACES_SAMPLE_RATE: NonNegativeFloat = Field( description="Sample rate for Sentry performance monitoring traces." " Value between 0.0 and 1.0, where 1.0 means 100% of traces are sent to Sentry.", default=1.0, ) - SENTRY_PROFILES_SAMPLE_RATE: NonNegativeFloat = Field( description="Sample rate for Sentry profiling." 
" Value between 0.0 and 1.0, where 1.0 means 100% of profiles are sent to Sentry.", diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index df15b92c35..aa9a08299b 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -26,22 +26,18 @@ class SecurityConfig(BaseSettings): "Generate a strong key using `openssl rand -base64 42` or set via the `SECRET_KEY` environment variable.", default="", ) - RESET_PASSWORD_TOKEN_EXPIRY_MINUTES: PositiveInt = Field( description="Duration in minutes for which a password reset token remains valid", default=5, ) - LOGIN_DISABLED: bool = Field( description="Whether to disable login checks", default=False, ) - ADMIN_API_KEY_ENABLE: bool = Field( description="Whether to enable admin api key for authentication", default=False, ) - ADMIN_API_KEY: Optional[str] = Field( description="admin api key for authentication", default=None, @@ -76,62 +72,50 @@ class CodeExecutionSandboxConfig(BaseSettings): description="URL endpoint for the code execution service", default=HttpUrl("http://sandbox:8194"), ) - CODE_EXECUTION_API_KEY: str = Field( description="API key for accessing the code execution service", default="dify-sandbox", ) - CODE_EXECUTION_CONNECT_TIMEOUT: Optional[float] = Field( description="Connection timeout in seconds for code execution requests", default=10.0, ) - CODE_EXECUTION_READ_TIMEOUT: Optional[float] = Field( description="Read timeout in seconds for code execution requests", default=60.0, ) - CODE_EXECUTION_WRITE_TIMEOUT: Optional[float] = Field( description="Write timeout in seconds for code execution request", default=10.0, ) - CODE_MAX_NUMBER: PositiveInt = Field( description="Maximum allowed numeric value in code execution", default=9223372036854775807, ) - CODE_MIN_NUMBER: NegativeInt = Field( description="Minimum allowed numeric value in code execution", default=-9223372036854775807, ) - CODE_MAX_DEPTH: PositiveInt = Field( description="Maximum allowed depth for 
nested structures in code execution", default=5, ) - CODE_MAX_PRECISION: PositiveInt = Field( description="Maximum number of decimal places for floating-point numbers in code execution", default=20, ) - CODE_MAX_STRING_LENGTH: PositiveInt = Field( description="Maximum allowed length for strings in code execution", default=80000, ) - CODE_MAX_STRING_ARRAY_LENGTH: PositiveInt = Field( description="Maximum allowed length for string arrays in code execution", default=30, ) - CODE_MAX_OBJECT_ARRAY_LENGTH: PositiveInt = Field( description="Maximum allowed length for object arrays in code execution", default=30, ) - CODE_MAX_NUMBER_ARRAY_LENGTH: PositiveInt = Field( description="Maximum allowed length for numeric arrays in code execution", default=1000, @@ -147,29 +131,23 @@ class PluginConfig(BaseSettings): description="Plugin API URL", default=HttpUrl("http://localhost:5002"), ) - PLUGIN_DAEMON_KEY: str = Field( description="Plugin API key", default="plugin-api-key", ) - INNER_API_KEY_FOR_PLUGIN: str = Field(description="Inner api key for plugin", default="inner-api-key") - PLUGIN_REMOTE_INSTALL_HOST: str = Field( description="Plugin Remote Install Host", default="localhost", ) - PLUGIN_REMOTE_INSTALL_PORT: PositiveInt = Field( description="Plugin Remote Install Port", default=5003, ) - PLUGIN_MAX_PACKAGE_SIZE: PositiveInt = Field( description="Maximum allowed size for plugin packages in bytes", default=15728640, ) - PLUGIN_MAX_BUNDLE_SIZE: PositiveInt = Field( description="Maximum allowed size for plugin bundles in bytes", default=15728640 * 12, @@ -185,7 +163,6 @@ class MarketplaceConfig(BaseSettings): description="Enable or disable marketplace", default=True, ) - MARKETPLACE_API_URL: HttpUrl = Field( description="Marketplace API URL", default=HttpUrl("https://marketplace.dify.ai"), @@ -202,22 +179,18 @@ class EndpointConfig(BaseSettings): "used for login authentication callback or notion integration callbacks", default="", ) - CONSOLE_WEB_URL: str = Field( 
description="Base URL for the console web interface,used for frontend references and CORS configuration", default="", ) - SERVICE_API_URL: str = Field( description="Base URL for the service API, displayed to users for API access", default="", ) - APP_WEB_URL: str = Field( description="Base URL for the web application, used for frontend references", default="", ) - ENDPOINT_URL_TEMPLATE: str = Field( description="Template url for endpoint plugin", default="http://localhost:5002/e/{hook_id}" ) @@ -236,7 +209,6 @@ class FileAccessConfig(BaseSettings): alias_priority=1, default="", ) - FILES_ACCESS_TIMEOUT: int = Field( description="Expiration time in seconds for file access URLs", default=300, @@ -252,32 +224,26 @@ class FileUploadConfig(BaseSettings): description="Maximum allowed file size for uploads in megabytes", default=15, ) - UPLOAD_FILE_BATCH_LIMIT: NonNegativeInt = Field( description="Maximum number of files allowed in a single upload batch", default=5, ) - UPLOAD_IMAGE_FILE_SIZE_LIMIT: NonNegativeInt = Field( description="Maximum allowed image file size for uploads in megabytes", default=10, ) - UPLOAD_VIDEO_FILE_SIZE_LIMIT: NonNegativeInt = Field( description="video file size limit in Megabytes for uploading files", default=100, ) - UPLOAD_AUDIO_FILE_SIZE_LIMIT: NonNegativeInt = Field( description="audio file size limit in Megabytes for uploading files", default=50, ) - BATCH_UPLOAD_LIMIT: NonNegativeInt = Field( description="Maximum number of files allowed in a batch upload operation", default=20, ) - WORKFLOW_FILE_UPLOAD_LIMIT: PositiveInt = Field( description="Maximum number of files allowed in a workflow upload operation", default=10, @@ -293,7 +259,6 @@ class HttpConfig(BaseSettings): description="Enable or disable gzip compression for HTTP responses", default=False, ) - inner_CONSOLE_CORS_ALLOW_ORIGINS: str = Field( description="Comma-separated list of allowed origins for CORS in the console", 
validation_alias=AliasChoices("CONSOLE_CORS_ALLOW_ORIGINS", "CONSOLE_WEB_URL"), @@ -317,70 +282,56 @@ class HttpConfig(BaseSettings): HTTP_REQUEST_MAX_CONNECT_TIMEOUT: Annotated[ PositiveInt, Field(ge=10, description="Maximum connection timeout in seconds for HTTP requests") ] = 10 - HTTP_REQUEST_MAX_READ_TIMEOUT: Annotated[ PositiveInt, Field(ge=60, description="Maximum read timeout in seconds for HTTP requests") ] = 60 - HTTP_REQUEST_MAX_WRITE_TIMEOUT: Annotated[ PositiveInt, Field(ge=10, description="Maximum write timeout in seconds for HTTP requests") ] = 20 - HTTP_REQUEST_NODE_MAX_BINARY_SIZE: PositiveInt = Field( description="Maximum allowed size in bytes for binary data in HTTP requests", default=10 * 1024 * 1024, ) - HTTP_REQUEST_NODE_MAX_TEXT_SIZE: PositiveInt = Field( description="Maximum allowed size in bytes for text data in HTTP requests", default=1 * 1024 * 1024, ) - HTTP_REQUEST_NODE_SSL_VERIFY: bool = Field( description="Enable or disable SSL verification for HTTP requests", default=True, ) - SSRF_DEFAULT_MAX_RETRIES: PositiveInt = Field( description="Maximum number of retries for network requests (SSRF)", default=3, ) - SSRF_PROXY_ALL_URL: Optional[str] = Field( description="Proxy URL for HTTP or HTTPS requests to prevent Server-Side Request Forgery (SSRF)", default=None, ) - SSRF_PROXY_HTTP_URL: Optional[str] = Field( description="Proxy URL for HTTP requests to prevent Server-Side Request Forgery (SSRF)", default=None, ) - SSRF_PROXY_HTTPS_URL: Optional[str] = Field( description="Proxy URL for HTTPS requests to prevent Server-Side Request Forgery (SSRF)", default=None, ) - SSRF_DEFAULT_TIME_OUT: PositiveFloat = Field( description="The default timeout period used for network requests (SSRF)", default=5, ) - SSRF_DEFAULT_CONNECT_TIME_OUT: PositiveFloat = Field( description="The default connect timeout period used for network requests (SSRF)", default=5, ) - SSRF_DEFAULT_READ_TIME_OUT: PositiveFloat = Field( description="The default read timeout 
period used for network requests (SSRF)", default=5, ) - SSRF_DEFAULT_WRITE_TIME_OUT: PositiveFloat = Field( description="The default write timeout period used for network requests (SSRF)", default=5, ) - RESPECT_XFORWARD_HEADERS_ENABLED: bool = Field( description="Enable handling of X-Forwarded-For, X-Forwarded-Proto, and X-Forwarded-Port headers" " when the app is behind a single trusted reverse proxy.", @@ -397,7 +348,6 @@ class InnerAPIConfig(BaseSettings): description="Enable or disable the internal API", default=False, ) - INNER_API_KEY: Optional[str] = Field( description="API key for accessing the internal API", default=None, @@ -413,32 +363,26 @@ class LoggingConfig(BaseSettings): description="Logging level, default to INFO. Set to ERROR for production environments.", default="INFO", ) - LOG_FILE: Optional[str] = Field( description="File path for log output.", default=None, ) - LOG_FILE_MAX_SIZE: PositiveInt = Field( description="Maximum file size for file rotation retention, the unit is megabytes (MB)", default=20, ) - LOG_FILE_BACKUP_COUNT: PositiveInt = Field( description="Maximum file backup count file rotation retention", default=5, ) - LOG_FORMAT: str = Field( description="Format string for log messages", default="%(asctime)s.%(msecs)03d %(levelname)s [%(threadName)s] [%(filename)s:%(lineno)d] - %(message)s", ) - LOG_DATEFORMAT: Optional[str] = Field( description="Date format string for log timestamps", default=None, ) - LOG_TZ: Optional[str] = Field( description="Timezone for log timestamps (e.g., 'America/New_York')", default="UTC", @@ -454,7 +398,6 @@ class ModelLoadBalanceConfig(BaseSettings): description="Enable or disable load balancing for models", default=False, ) - PLUGIN_BASED_TOKEN_COUNTING_ENABLED: bool = Field( description="Enable or disable plugin based token counting. 
If disabled, token counting will return 0.", default=False, @@ -492,22 +435,18 @@ class WorkflowConfig(BaseSettings): description="Maximum number of steps allowed in a single workflow execution", default=500, ) - WORKFLOW_MAX_EXECUTION_TIME: PositiveInt = Field( description="Maximum execution time in seconds for a single workflow", default=1200, ) - WORKFLOW_CALL_MAX_DEPTH: PositiveInt = Field( description="Maximum allowed depth for nested workflow calls", default=5, ) - WORKFLOW_PARALLEL_DEPTH_LIMIT: PositiveInt = Field( description="Maximum allowed depth for nested parallel executions", default=3, ) - MAX_VARIABLE_SIZE: PositiveInt = Field( description="Maximum size in bytes for a single variable in workflows. Default to 200 KB.", default=200 * 1024, @@ -523,7 +462,6 @@ class WorkflowNodeExecutionConfig(BaseSettings): description="Maximum number of submitted thread count in a ThreadPool for parallel node execution", default=100, ) - WORKFLOW_NODE_EXECUTION_STORAGE: str = Field( default="rdbms", description="Storage backend for WorkflowNodeExecution. 
Options: 'rdbms', 'hybrid'", @@ -539,42 +477,34 @@ class AuthConfig(BaseSettings): description="Redirect path for OAuth authentication callbacks", default="/console/api/oauth/authorize", ) - GITHUB_CLIENT_ID: Optional[str] = Field( description="GitHub OAuth client ID", default=None, ) - GITHUB_CLIENT_SECRET: Optional[str] = Field( description="GitHub OAuth client secret", default=None, ) - GOOGLE_CLIENT_ID: Optional[str] = Field( description="Google OAuth client ID", default=None, ) - GOOGLE_CLIENT_SECRET: Optional[str] = Field( description="Google OAuth client secret", default=None, ) - ACCESS_TOKEN_EXPIRE_MINUTES: PositiveInt = Field( description="Expiration time for access tokens in minutes", default=60, ) - REFRESH_TOKEN_EXPIRE_DAYS: PositiveFloat = Field( description="Expiration time for refresh tokens in days", default=30, ) - LOGIN_LOCKOUT_DURATION: PositiveInt = Field( description="Time (in seconds) a user must wait before retrying login after exceeding the rate limit.", default=86400, ) - FORGOT_PASSWORD_LOCKOUT_DURATION: PositiveInt = Field( description="Time (in seconds) a user must wait before retrying password reset after exceeding the rate limit.", default=86400, @@ -612,57 +542,46 @@ class MailConfig(BaseSettings): description="Email service provider type ('smtp' or 'resend' or 'sendGrid), default to None.", default=None, ) - MAIL_DEFAULT_SEND_FROM: Optional[str] = Field( description="Default email address to use as the sender", default=None, ) - RESEND_API_KEY: Optional[str] = Field( description="API key for Resend email service", default=None, ) - RESEND_API_URL: Optional[str] = Field( description="API URL for Resend email service", default=None, ) - SMTP_SERVER: Optional[str] = Field( description="SMTP server hostname", default=None, ) - SMTP_PORT: Optional[int] = Field( description="SMTP server port number", default=465, ) - SMTP_USERNAME: Optional[str] = Field( description="Username for SMTP authentication", default=None, ) - SMTP_PASSWORD: 
Optional[str] = Field( description="Password for SMTP authentication", default=None, ) - SMTP_USE_TLS: bool = Field( description="Enable TLS encryption for SMTP connections", default=False, ) - SMTP_OPPORTUNISTIC_TLS: bool = Field( description="Enable opportunistic TLS for SMTP connections", default=False, ) - EMAIL_SEND_IP_LIMIT_PER_MINUTE: PositiveInt = Field( description="Maximum number of emails allowed to be sent from the same IP address in a minute", default=50, ) - SENDGRID_API_KEY: Optional[str] = Field( description="API key for SendGrid service", default=None, @@ -679,23 +598,19 @@ class RagEtlConfig(BaseSettings): description="RAG ETL type ('dify' or 'Unstructured'), default to 'dify'", default="dify", ) - KEYWORD_DATA_SOURCE_TYPE: str = Field( description="Data source type for keyword extraction" " ('database' or other supported types), default to 'database'", default="database", ) - UNSTRUCTURED_API_URL: Optional[str] = Field( description="API URL for Unstructured.io service", default=None, ) - UNSTRUCTURED_API_KEY: Optional[str] = Field( description="API key for Unstructured.io service", default="", ) - SCARF_NO_ANALYTICS: Optional[str] = Field( description="This is about whether to disable Scarf analytics in Unstructured library.", default="false", @@ -711,27 +626,22 @@ class DataSetConfig(BaseSettings): description="Interval in days for dataset cleanup operations - plan: sandbox", default=30, ) - PLAN_PRO_CLEAN_DAY_SETTING: PositiveInt = Field( description="Interval in days for dataset cleanup operations - plan: pro and team", default=7, ) - DATASET_OPERATOR_ENABLED: bool = Field( description="Enable or disable dataset operator functionality", default=False, ) - TIDB_SERVERLESS_NUMBER: PositiveInt = Field( description="number of tidb serverless cluster", default=500, ) - CREATE_TIDB_SERVICE_JOB_ENABLED: bool = Field( description="Enable or disable create tidb service job", default=False, ) - PLAN_SANDBOX_CLEAN_MESSAGE_DAY_SETTING: PositiveInt = 
Field( description="Interval in days for message cleanup operations - plan: sandbox", default=30, @@ -758,7 +668,6 @@ class IndexingConfig(BaseSettings): description="Maximum token length for text segmentation during indexing", default=4000, ) - CHILD_CHUNKS_PREVIEW_NUMBER: PositiveInt = Field( description="Maximum number of child chunks to preview", default=50, @@ -784,27 +693,22 @@ class PositionConfig(BaseSettings): description="Comma-separated list of pinned model providers", default="", ) - POSITION_PROVIDER_INCLUDES: str = Field( description="Comma-separated list of included model providers", default="", ) - POSITION_PROVIDER_EXCLUDES: str = Field( description="Comma-separated list of excluded model providers", default="", ) - POSITION_TOOL_PINS: str = Field( description="Comma-separated list of pinned tools", default="", ) - POSITION_TOOL_INCLUDES: str = Field( description="Comma-separated list of included tools", default="", ) - POSITION_TOOL_EXCLUDES: str = Field( description="Comma-separated list of excluded tools", default="", @@ -867,7 +771,6 @@ class AccountConfig(BaseSettings): description="Duration in minutes for which a account deletion token remains valid", default=5, ) - EDUCATION_ENABLED: bool = Field( description="whether to enable education identity", default=False, diff --git a/api/configs/feature/hosted_service/__init__.py b/api/configs/feature/hosted_service/__init__.py index 18ef1ed45b..5dce285d70 100644 --- a/api/configs/feature/hosted_service/__init__.py +++ b/api/configs/feature/hosted_service/__init__.py @@ -14,18 +14,15 @@ class HostedCreditConfig(BaseSettings): """ Get credit value for a specific model name. Returns 1 if model is not found in configuration (default credit). 
- :param model_name: The name of the model to search for :return: The credit value for the model """ if not self.HOSTED_MODEL_CREDIT_CONFIG: return 1 - try: credit_map = dict( item.strip().split(":", 1) for item in self.HOSTED_MODEL_CREDIT_CONFIG.split(",") if ":" in item ) - # Search for matching model pattern for pattern, credit in credit_map.items(): if pattern.strip() == model_name: @@ -44,22 +41,18 @@ class HostedOpenAiConfig(BaseSettings): description="API key for hosted OpenAI service", default=None, ) - HOSTED_OPENAI_API_BASE: Optional[str] = Field( description="Base URL for hosted OpenAI API", default=None, ) - HOSTED_OPENAI_API_ORGANIZATION: Optional[str] = Field( description="Organization ID for hosted OpenAI service", default=None, ) - HOSTED_OPENAI_TRIAL_ENABLED: bool = Field( description="Enable trial access to hosted OpenAI service", default=False, ) - HOSTED_OPENAI_TRIAL_MODELS: str = Field( description="Comma-separated list of available models for trial access", default="gpt-3.5-turbo," @@ -71,17 +64,14 @@ class HostedOpenAiConfig(BaseSettings): "gpt-3.5-turbo-0125," "text-davinci-003", ) - HOSTED_OPENAI_QUOTA_LIMIT: NonNegativeInt = Field( description="Quota limit for hosted OpenAI service usage", default=200, ) - HOSTED_OPENAI_PAID_ENABLED: bool = Field( description="Enable paid access to hosted OpenAI service", default=False, ) - HOSTED_OPENAI_PAID_MODELS: str = Field( description="Comma-separated list of available models for paid access", default="gpt-4," @@ -109,17 +99,14 @@ class HostedAzureOpenAiConfig(BaseSettings): description="Enable hosted Azure OpenAI service", default=False, ) - HOSTED_AZURE_OPENAI_API_KEY: Optional[str] = Field( description="API key for hosted Azure OpenAI service", default=None, ) - HOSTED_AZURE_OPENAI_API_BASE: Optional[str] = Field( description="Base URL for hosted Azure OpenAI API", default=None, ) - HOSTED_AZURE_OPENAI_QUOTA_LIMIT: NonNegativeInt = Field( description="Quota limit for hosted Azure OpenAI service 
usage", default=200, @@ -135,22 +122,18 @@ class HostedAnthropicConfig(BaseSettings): description="Base URL for hosted Anthropic API", default=None, ) - HOSTED_ANTHROPIC_API_KEY: Optional[str] = Field( description="API key for hosted Anthropic service", default=None, ) - HOSTED_ANTHROPIC_TRIAL_ENABLED: bool = Field( description="Enable trial access to hosted Anthropic service", default=False, ) - HOSTED_ANTHROPIC_QUOTA_LIMIT: NonNegativeInt = Field( description="Quota limit for hosted Anthropic service usage", default=600000, ) - HOSTED_ANTHROPIC_PAID_ENABLED: bool = Field( description="Enable paid access to hosted Anthropic service", default=False, @@ -199,7 +182,6 @@ class HostedModerationConfig(BaseSettings): description="Enable hosted Moderation service", default=False, ) - HOSTED_MODERATION_PROVIDERS: str = Field( description="Comma-separated list of moderation providers", default="", @@ -215,7 +197,6 @@ class HostedFetchAppTemplateConfig(BaseSettings): description="Mode for fetching app templates: remote, db, or builtin default to remote,", default="remote", ) - HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN: str = Field( description="Domain for fetching remote app templates", default="https://tmpl.dify.ai", diff --git a/api/configs/middleware/__init__.py b/api/configs/middleware/__init__.py index 427602676f..83948b46b2 100644 --- a/api/configs/middleware/__init__.py +++ b/api/configs/middleware/__init__.py @@ -65,7 +65,6 @@ class StorageConfig(BaseSettings): "'huawei-obs', 'oci-storage', 'tencent-cos', 'volcengine-tos', 'supabase'. 
Default is 'opendal'.", default="opendal", ) - STORAGE_LOCAL_PATH: str = Field( description="Path for local storage when STORAGE_TYPE is set to 'local'.", default="storage", @@ -79,7 +78,6 @@ class VectorStoreConfig(BaseSettings): " Set to None if not using a vector store.", default=None, ) - VECTOR_STORE_WHITELIST_ENABLE: Optional[bool] = Field( description="Enable whitelist for vector store.", default=False, @@ -99,37 +97,30 @@ class DatabaseConfig(BaseSettings): description="Hostname or IP address of the database server.", default="localhost", ) - DB_PORT: PositiveInt = Field( description="Port number for database connection.", default=5432, ) - DB_USERNAME: str = Field( description="Username for database authentication.", default="postgres", ) - DB_PASSWORD: str = Field( description="Password for database authentication.", default="", ) - DB_DATABASE: str = Field( description="Name of the database to connect to.", default="dify", ) - DB_CHARSET: str = Field( description="Character set for database connection.", default="", ) - DB_EXTRAS: str = Field( description="Additional database connection parameters. 
Example: 'keepalives_idle=60&keepalives=1'", default="", ) - SQLALCHEMY_DATABASE_URI_SCHEME: str = Field( description="Database URI scheme for SQLAlchemy connection.", default="postgresql", @@ -151,27 +142,22 @@ class DatabaseConfig(BaseSettings): description="Maximum number of database connections in the pool.", default=30, ) - SQLALCHEMY_MAX_OVERFLOW: NonNegativeInt = Field( description="Maximum number of connections that can be created beyond the pool_size.", default=10, ) - SQLALCHEMY_POOL_RECYCLE: NonNegativeInt = Field( description="Number of seconds after which a connection is automatically recycled.", default=3600, ) - SQLALCHEMY_POOL_PRE_PING: bool = Field( description="If True, enables connection pool pre-ping feature to check connections.", default=False, ) - SQLALCHEMY_ECHO: bool | str = Field( description="If True, SQLAlchemy will log all SQL statements.", default=False, ) - RETRIEVAL_SERVICE_EXECUTORS: NonNegativeInt = Field( description="Number of processes for the retrieval service, default to CPU cores.", default=os.cpu_count() or 1, @@ -190,9 +176,7 @@ class DatabaseConfig(BaseSettings): merged_options = f"{options} {timezone_opt}" else: merged_options = timezone_opt - connect_args = {"options": merged_options} - return { "pool_size": self.SQLALCHEMY_POOL_SIZE, "max_overflow": self.SQLALCHEMY_MAX_OVERFLOW, @@ -207,22 +191,18 @@ class CeleryConfig(DatabaseConfig): description="Backend for Celery task results. 
Options: 'database', 'redis'.", default="database", ) - CELERY_BROKER_URL: Optional[str] = Field( description="URL of the message broker for Celery tasks.", default=None, ) - CELERY_USE_SENTINEL: Optional[bool] = Field( description="Whether to use Redis Sentinel for high availability.", default=False, ) - CELERY_SENTINEL_MASTER_NAME: Optional[str] = Field( description="Name of the Redis Sentinel master.", default=None, ) - CELERY_SENTINEL_PASSWORD: Optional[str] = Field( description="Password of the Redis Sentinel master.", default=None, @@ -254,7 +234,6 @@ class InternalTestConfig(BaseSettings): description="Internal test AWS secret access key", default=None, ) - AWS_ACCESS_KEY_ID: Optional[str] = Field( description="Internal test AWS access key ID", default=None, diff --git a/api/configs/middleware/cache/redis_config.py b/api/configs/middleware/cache/redis_config.py index 916f52e165..8e0af10366 100644 --- a/api/configs/middleware/cache/redis_config.py +++ b/api/configs/middleware/cache/redis_config.py @@ -13,82 +13,66 @@ class RedisConfig(BaseSettings): description="Hostname or IP address of the Redis server", default="localhost", ) - REDIS_PORT: PositiveInt = Field( description="Port number on which the Redis server is listening", default=6379, ) - REDIS_USERNAME: Optional[str] = Field( description="Username for Redis authentication (if required)", default=None, ) - REDIS_PASSWORD: Optional[str] = Field( description="Password for Redis authentication (if required)", default=None, ) - REDIS_DB: NonNegativeInt = Field( description="Redis database number to use (0-15)", default=0, ) - REDIS_USE_SSL: bool = Field( description="Enable SSL/TLS for the Redis connection", default=False, ) - REDIS_USE_SENTINEL: Optional[bool] = Field( description="Enable Redis Sentinel mode for high availability", default=False, ) - REDIS_SENTINELS: Optional[str] = Field( description="Comma-separated list of Redis Sentinel nodes (host:port)", default=None, ) - 
REDIS_SENTINEL_SERVICE_NAME: Optional[str] = Field( description="Name of the Redis Sentinel service to monitor", default=None, ) - REDIS_SENTINEL_USERNAME: Optional[str] = Field( description="Username for Redis Sentinel authentication (if required)", default=None, ) - REDIS_SENTINEL_PASSWORD: Optional[str] = Field( description="Password for Redis Sentinel authentication (if required)", default=None, ) - REDIS_SENTINEL_SOCKET_TIMEOUT: Optional[PositiveFloat] = Field( description="Socket timeout in seconds for Redis Sentinel connections", default=0.1, ) - REDIS_USE_CLUSTERS: bool = Field( description="Enable Redis Clusters mode for high availability", default=False, ) - REDIS_CLUSTERS: Optional[str] = Field( description="Comma-separated list of Redis Clusters nodes (host:port)", default=None, ) - REDIS_CLUSTERS_PASSWORD: Optional[str] = Field( description="Password for Redis Clusters authentication (if required)", default=None, ) - REDIS_SERIALIZATION_PROTOCOL: int = Field( description="Redis serialization protocol (RESP) version", default=3, ) - REDIS_ENABLE_CLIENT_SIDE_CACHE: bool = Field( description="Enable client side cache in redis", default=False, diff --git a/api/configs/middleware/storage/aliyun_oss_storage_config.py b/api/configs/middleware/storage/aliyun_oss_storage_config.py index 07eb527170..f67fd6bf69 100644 --- a/api/configs/middleware/storage/aliyun_oss_storage_config.py +++ b/api/configs/middleware/storage/aliyun_oss_storage_config.py @@ -13,32 +13,26 @@ class AliyunOSSStorageConfig(BaseSettings): description="Name of the Aliyun OSS bucket to store and retrieve objects", default=None, ) - ALIYUN_OSS_ACCESS_KEY: Optional[str] = Field( description="Access key ID for authenticating with Aliyun OSS", default=None, ) - ALIYUN_OSS_SECRET_KEY: Optional[str] = Field( description="Secret access key for authenticating with Aliyun OSS", default=None, ) - ALIYUN_OSS_ENDPOINT: Optional[str] = Field( description="URL of the Aliyun OSS endpoint for your chosen 
region", default=None, ) - ALIYUN_OSS_REGION: Optional[str] = Field( description="Aliyun OSS region where your bucket is located (e.g., 'oss-cn-hangzhou')", default=None, ) - ALIYUN_OSS_AUTH_VERSION: Optional[str] = Field( description="Version of the authentication protocol to use with Aliyun OSS (e.g., 'v4')", default=None, ) - ALIYUN_OSS_PATH: Optional[str] = Field( description="Base path within the bucket to store objects (e.g., 'my-app-data/')", default=None, diff --git a/api/configs/middleware/storage/amazon_s3_storage_config.py b/api/configs/middleware/storage/amazon_s3_storage_config.py index e14c210718..49cd35188b 100644 --- a/api/configs/middleware/storage/amazon_s3_storage_config.py +++ b/api/configs/middleware/storage/amazon_s3_storage_config.py @@ -13,32 +13,26 @@ class S3StorageConfig(BaseSettings): description="URL of the S3-compatible storage endpoint (e.g., 'https://s3.amazonaws.com')", default=None, ) - S3_REGION: Optional[str] = Field( description="Region where the S3 bucket is located (e.g., 'us-east-1')", default=None, ) - S3_BUCKET_NAME: Optional[str] = Field( description="Name of the S3 bucket to store and retrieve objects", default=None, ) - S3_ACCESS_KEY: Optional[str] = Field( description="Access key ID for authenticating with the S3 service", default=None, ) - S3_SECRET_KEY: Optional[str] = Field( description="Secret access key for authenticating with the S3 service", default=None, ) - S3_ADDRESS_STYLE: Literal["auto", "virtual", "path"] = Field( description="S3 addressing style: 'auto', 'path', or 'virtual'", default="auto", ) - S3_USE_AWS_MANAGED_IAM: bool = Field( description="Use AWS managed IAM roles for authentication instead of access/secret keys", default=False, diff --git a/api/configs/middleware/storage/azure_blob_storage_config.py b/api/configs/middleware/storage/azure_blob_storage_config.py index b7ab5247a9..29350946e2 100644 --- a/api/configs/middleware/storage/azure_blob_storage_config.py +++ 
b/api/configs/middleware/storage/azure_blob_storage_config.py @@ -13,17 +13,14 @@ class AzureBlobStorageConfig(BaseSettings): description="Name of the Azure Storage account (e.g., 'mystorageaccount')", default=None, ) - AZURE_BLOB_ACCOUNT_KEY: Optional[str] = Field( description="Access key for authenticating with the Azure Storage account", default=None, ) - AZURE_BLOB_CONTAINER_NAME: Optional[str] = Field( description="Name of the Azure Blob container to store and retrieve objects", default=None, ) - AZURE_BLOB_ACCOUNT_URL: Optional[str] = Field( description="URL of the Azure Blob storage endpoint (e.g., 'https://mystorageaccount.blob.core.windows.net')", default=None, diff --git a/api/configs/middleware/storage/baidu_obs_storage_config.py b/api/configs/middleware/storage/baidu_obs_storage_config.py index e7913b0acc..89b20ad553 100644 --- a/api/configs/middleware/storage/baidu_obs_storage_config.py +++ b/api/configs/middleware/storage/baidu_obs_storage_config.py @@ -13,17 +13,14 @@ class BaiduOBSStorageConfig(BaseSettings): description="Name of the Baidu OBS bucket to store and retrieve objects (e.g., 'my-obs-bucket')", default=None, ) - BAIDU_OBS_ACCESS_KEY: Optional[str] = Field( description="Access Key ID for authenticating with Baidu OBS", default=None, ) - BAIDU_OBS_SECRET_KEY: Optional[str] = Field( description="Secret Access Key for authenticating with Baidu OBS", default=None, ) - BAIDU_OBS_ENDPOINT: Optional[str] = Field( description="URL of the Baidu OSS endpoint for your chosen region (e.g., 'https://.bj.bcebos.com')", default=None, diff --git a/api/configs/middleware/storage/google_cloud_storage_config.py b/api/configs/middleware/storage/google_cloud_storage_config.py index e5d763d7f5..679ac939aa 100644 --- a/api/configs/middleware/storage/google_cloud_storage_config.py +++ b/api/configs/middleware/storage/google_cloud_storage_config.py @@ -13,7 +13,6 @@ class GoogleCloudStorageConfig(BaseSettings): description="Name of the Google Cloud Storage bucket 
to store and retrieve objects (e.g., 'my-gcs-bucket')", default=None, ) - GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64: Optional[str] = Field( description="Base64-encoded JSON key file for Google Cloud service account authentication", default=None, diff --git a/api/configs/middleware/storage/huawei_obs_storage_config.py b/api/configs/middleware/storage/huawei_obs_storage_config.py index be983b5187..4ecfccfd5d 100644 --- a/api/configs/middleware/storage/huawei_obs_storage_config.py +++ b/api/configs/middleware/storage/huawei_obs_storage_config.py @@ -13,17 +13,14 @@ class HuaweiCloudOBSStorageConfig(BaseSettings): description="Name of the Huawei Cloud OBS bucket to store and retrieve objects (e.g., 'my-obs-bucket')", default=None, ) - HUAWEI_OBS_ACCESS_KEY: Optional[str] = Field( description="Access Key ID for authenticating with Huawei Cloud OBS", default=None, ) - HUAWEI_OBS_SECRET_KEY: Optional[str] = Field( description="Secret Access Key for authenticating with Huawei Cloud OBS", default=None, ) - HUAWEI_OBS_SERVER: Optional[str] = Field( description="Endpoint URL for Huawei Cloud OBS (e.g., 'https://obs.cn-north-4.myhuaweicloud.com')", default=None, diff --git a/api/configs/middleware/storage/oci_storage_config.py b/api/configs/middleware/storage/oci_storage_config.py index edc245bcac..88de5ecd88 100644 --- a/api/configs/middleware/storage/oci_storage_config.py +++ b/api/configs/middleware/storage/oci_storage_config.py @@ -13,22 +13,18 @@ class OCIStorageConfig(BaseSettings): description="URL of the OCI Object Storage endpoint (e.g., 'https://objectstorage.us-phoenix-1.oraclecloud.com')", default=None, ) - OCI_REGION: Optional[str] = Field( description="OCI region where the bucket is located (e.g., 'us-phoenix-1')", default=None, ) - OCI_BUCKET_NAME: Optional[str] = Field( description="Name of the OCI Object Storage bucket to store and retrieve objects (e.g., 'my-oci-bucket')", default=None, ) - OCI_ACCESS_KEY: Optional[str] = Field( description="Access key (also 
known as API key) for authenticating with OCI Object Storage", default=None, ) - OCI_SECRET_KEY: Optional[str] = Field( description="Secret key associated with the access key for authenticating with OCI Object Storage", default=None, diff --git a/api/configs/middleware/storage/supabase_storage_config.py b/api/configs/middleware/storage/supabase_storage_config.py index dcf7c20cf9..0f6b70b39c 100644 --- a/api/configs/middleware/storage/supabase_storage_config.py +++ b/api/configs/middleware/storage/supabase_storage_config.py @@ -13,12 +13,10 @@ class SupabaseStorageConfig(BaseSettings): description="Name of the Supabase bucket to store and retrieve objects (e.g., 'dify-bucket')", default=None, ) - SUPABASE_API_KEY: Optional[str] = Field( description="API KEY for authenticating with Supabase", default=None, ) - SUPABASE_URL: Optional[str] = Field( description="URL of the Supabase", default=None, diff --git a/api/configs/middleware/storage/tencent_cos_storage_config.py b/api/configs/middleware/storage/tencent_cos_storage_config.py index 255c4e8938..13033a6cd3 100644 --- a/api/configs/middleware/storage/tencent_cos_storage_config.py +++ b/api/configs/middleware/storage/tencent_cos_storage_config.py @@ -13,22 +13,18 @@ class TencentCloudCOSStorageConfig(BaseSettings): description="Name of the Tencent Cloud COS bucket to store and retrieve objects", default=None, ) - TENCENT_COS_REGION: Optional[str] = Field( description="Tencent Cloud region where the COS bucket is located (e.g., 'ap-guangzhou')", default=None, ) - TENCENT_COS_SECRET_ID: Optional[str] = Field( description="SecretId for authenticating with Tencent Cloud COS (part of API credentials)", default=None, ) - TENCENT_COS_SECRET_KEY: Optional[str] = Field( description="SecretKey for authenticating with Tencent Cloud COS (part of API credentials)", default=None, ) - TENCENT_COS_SCHEME: Optional[str] = Field( description="Protocol scheme for COS requests: 'https' (recommended) or 'http'", default=None, diff --git 
a/api/configs/middleware/storage/volcengine_tos_storage_config.py b/api/configs/middleware/storage/volcengine_tos_storage_config.py index 06c3ae4d3e..cc9e93cb8b 100644 --- a/api/configs/middleware/storage/volcengine_tos_storage_config.py +++ b/api/configs/middleware/storage/volcengine_tos_storage_config.py @@ -13,22 +13,18 @@ class VolcengineTOSStorageConfig(BaseSettings): description="Name of the Volcengine TOS bucket to store and retrieve objects (e.g., 'my-tos-bucket')", default=None, ) - VOLCENGINE_TOS_ACCESS_KEY: Optional[str] = Field( description="Access Key ID for authenticating with Volcengine TOS", default=None, ) - VOLCENGINE_TOS_SECRET_KEY: Optional[str] = Field( description="Secret Access Key for authenticating with Volcengine TOS", default=None, ) - VOLCENGINE_TOS_ENDPOINT: Optional[str] = Field( description="URL of the Volcengine TOS endpoint (e.g., 'https://tos-cn-beijing.volces.com')", default=None, ) - VOLCENGINE_TOS_REGION: Optional[str] = Field( description="Volcengine region where the TOS bucket is located (e.g., 'cn-beijing')", default=None, diff --git a/api/configs/middleware/vdb/baidu_vector_config.py b/api/configs/middleware/vdb/baidu_vector_config.py index 44742c2e2f..6e37f0aed4 100644 --- a/api/configs/middleware/vdb/baidu_vector_config.py +++ b/api/configs/middleware/vdb/baidu_vector_config.py @@ -13,32 +13,26 @@ class BaiduVectorDBConfig(BaseSettings): description="URL of the Baidu Vector Database service (e.g., 'http://vdb.bj.baidubce.com')", default=None, ) - BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS: PositiveInt = Field( description="Timeout in milliseconds for Baidu Vector Database operations (default is 30000 milliseconds)", default=30000, ) - BAIDU_VECTOR_DB_ACCOUNT: Optional[str] = Field( description="Account for authenticating with the Baidu Vector Database", default=None, ) - BAIDU_VECTOR_DB_API_KEY: Optional[str] = Field( description="API key for authenticating with the Baidu Vector Database service", default=None, ) - 
BAIDU_VECTOR_DB_DATABASE: Optional[str] = Field( description="Name of the specific Baidu Vector Database to connect to", default=None, ) - BAIDU_VECTOR_DB_SHARD: PositiveInt = Field( description="Number of shards for the Baidu Vector Database (default is 1)", default=1, ) - BAIDU_VECTOR_DB_REPLICAS: NonNegativeInt = Field( description="Number of replicas for the Baidu Vector Database (default is 3)", default=3, diff --git a/api/configs/middleware/vdb/chroma_config.py b/api/configs/middleware/vdb/chroma_config.py index e83a9902de..498cb25026 100644 --- a/api/configs/middleware/vdb/chroma_config.py +++ b/api/configs/middleware/vdb/chroma_config.py @@ -13,27 +13,22 @@ class ChromaConfig(BaseSettings): description="Hostname or IP address of the Chroma server (e.g., 'localhost' or '192.168.1.100')", default=None, ) - CHROMA_PORT: PositiveInt = Field( description="Port number on which the Chroma server is listening (default is 8000)", default=8000, ) - CHROMA_TENANT: Optional[str] = Field( description="Tenant identifier for multi-tenancy support in Chroma", default=None, ) - CHROMA_DATABASE: Optional[str] = Field( description="Name of the Chroma database to connect to", default=None, ) - CHROMA_AUTH_PROVIDER: Optional[str] = Field( description="Authentication provider for Chroma (e.g., 'basic', 'token', or a custom provider)", default=None, ) - CHROMA_AUTH_CREDENTIALS: Optional[str] = Field( description="Authentication credentials for Chroma (format depends on the auth provider)", default=None, diff --git a/api/configs/middleware/vdb/couchbase_config.py b/api/configs/middleware/vdb/couchbase_config.py index b81cbf8959..4a34d2eb03 100644 --- a/api/configs/middleware/vdb/couchbase_config.py +++ b/api/configs/middleware/vdb/couchbase_config.py @@ -13,22 +13,18 @@ class CouchbaseConfig(BaseSettings): description="COUCHBASE connection string", default=None, ) - COUCHBASE_USER: Optional[str] = Field( description="COUCHBASE user", default=None, ) - COUCHBASE_PASSWORD: 
Optional[str] = Field( description="COUCHBASE password", default=None, ) - COUCHBASE_BUCKET_NAME: Optional[str] = Field( description="COUCHBASE bucket name", default=None, ) - COUCHBASE_SCOPE_NAME: Optional[str] = Field( description="COUCHBASE scope name", default=None, diff --git a/api/configs/middleware/vdb/elasticsearch_config.py b/api/configs/middleware/vdb/elasticsearch_config.py index df8182985d..44c6b0da19 100644 --- a/api/configs/middleware/vdb/elasticsearch_config.py +++ b/api/configs/middleware/vdb/elasticsearch_config.py @@ -13,17 +13,14 @@ class ElasticsearchConfig(BaseSettings): description="Hostname or IP address of the Elasticsearch server (e.g., 'localhost' or '192.168.1.100')", default="127.0.0.1", ) - ELASTICSEARCH_PORT: PositiveInt = Field( description="Port number on which the Elasticsearch server is listening (default is 9200)", default=9200, ) - ELASTICSEARCH_USERNAME: Optional[str] = Field( description="Username for authenticating with Elasticsearch (default is 'elastic')", default="elastic", ) - ELASTICSEARCH_PASSWORD: Optional[str] = Field( description="Password for authenticating with Elasticsearch (default is 'elastic')", default="elastic", diff --git a/api/configs/middleware/vdb/huawei_cloud_config.py b/api/configs/middleware/vdb/huawei_cloud_config.py index 2290c60499..bb9782ad0d 100644 --- a/api/configs/middleware/vdb/huawei_cloud_config.py +++ b/api/configs/middleware/vdb/huawei_cloud_config.py @@ -13,12 +13,10 @@ class HuaweiCloudConfig(BaseSettings): description="Hostname or IP address of the Huawei cloud search service instance", default=None, ) - HUAWEI_CLOUD_USER: Optional[str] = Field( description="Username for authenticating with Huawei cloud search service", default=None, ) - HUAWEI_CLOUD_PASSWORD: Optional[str] = Field( description="Password for authenticating with Huawei cloud search service", default=None, diff --git a/api/configs/middleware/vdb/milvus_config.py b/api/configs/middleware/vdb/milvus_config.py index 
d398ef5bd8..bad06d2cbe 100644 --- a/api/configs/middleware/vdb/milvus_config.py +++ b/api/configs/middleware/vdb/milvus_config.py @@ -13,33 +13,27 @@ class MilvusConfig(BaseSettings): description="URI for connecting to the Milvus server (e.g., 'http://localhost:19530' or 'https://milvus-instance.example.com:19530')", default="http://127.0.0.1:19530", ) - MILVUS_TOKEN: Optional[str] = Field( description="Authentication token for Milvus, if token-based authentication is enabled", default=None, ) - MILVUS_USER: Optional[str] = Field( description="Username for authenticating with Milvus, if username/password authentication is enabled", default=None, ) - MILVUS_PASSWORD: Optional[str] = Field( description="Password for authenticating with Milvus, if username/password authentication is enabled", default=None, ) - MILVUS_DATABASE: str = Field( description="Name of the Milvus database to connect to (default is 'default')", default="default", ) - MILVUS_ENABLE_HYBRID_SEARCH: bool = Field( description="Enable hybrid search features (requires Milvus >= 2.5.0). 
Set to false for compatibility with " "older versions", default=True, ) - MILVUS_ANALYZER_PARAMS: Optional[str] = Field( description='Milvus text analyzer parameters, e.g., {"type": "chinese"} for Chinese segmentation support.', default=None, diff --git a/api/configs/middleware/vdb/myscale_config.py b/api/configs/middleware/vdb/myscale_config.py index b5bf98b3aa..b63f991b4b 100644 --- a/api/configs/middleware/vdb/myscale_config.py +++ b/api/configs/middleware/vdb/myscale_config.py @@ -11,27 +11,22 @@ class MyScaleConfig(BaseSettings): description="Hostname or IP address of the MyScale server (e.g., 'localhost' or 'myscale.example.com')", default="localhost", ) - MYSCALE_PORT: PositiveInt = Field( description="Port number on which the MyScale server is listening (default is 8123)", default=8123, ) - MYSCALE_USER: str = Field( description="Username for authenticating with MyScale (default is 'default')", default="default", ) - MYSCALE_PASSWORD: str = Field( description="Password for authenticating with MyScale (default is an empty string)", default="", ) - MYSCALE_DATABASE: str = Field( description="Name of the MyScale database to connect to (default is 'default')", default="default", ) - MYSCALE_FTS_PARAMS: str = Field( description="Additional parameters for MyScale Full Text Search index)", default="", diff --git a/api/configs/middleware/vdb/oceanbase_config.py b/api/configs/middleware/vdb/oceanbase_config.py index 9b11a22732..dca856bb8a 100644 --- a/api/configs/middleware/vdb/oceanbase_config.py +++ b/api/configs/middleware/vdb/oceanbase_config.py @@ -13,27 +13,22 @@ class OceanBaseVectorConfig(BaseSettings): description="Hostname or IP address of the OceanBase Vector server (e.g. 
'localhost')", default=None, ) - OCEANBASE_VECTOR_PORT: Optional[PositiveInt] = Field( description="Port number on which the OceanBase Vector server is listening (default is 2881)", default=2881, ) - OCEANBASE_VECTOR_USER: Optional[str] = Field( description="Username for authenticating with the OceanBase Vector database", default=None, ) - OCEANBASE_VECTOR_PASSWORD: Optional[str] = Field( description="Password for authenticating with the OceanBase Vector database", default=None, ) - OCEANBASE_VECTOR_DATABASE: Optional[str] = Field( description="Name of the OceanBase Vector database to connect to", default=None, ) - OCEANBASE_ENABLE_HYBRID_SEARCH: bool = Field( description="Enable hybrid search features (requires OceanBase >= 4.3.5.1). Set to false for compatibility " "with older versions", diff --git a/api/configs/middleware/vdb/opengauss_config.py b/api/configs/middleware/vdb/opengauss_config.py index 87ea292ab4..f06410b89b 100644 --- a/api/configs/middleware/vdb/opengauss_config.py +++ b/api/configs/middleware/vdb/opengauss_config.py @@ -13,37 +13,30 @@ class OpenGaussConfig(BaseSettings): description="Hostname or IP address of the OpenGauss server(e.g., 'localhost')", default=None, ) - OPENGAUSS_PORT: PositiveInt = Field( description="Port number on which the OpenGauss server is listening (default is 6600)", default=6600, ) - OPENGAUSS_USER: Optional[str] = Field( description="Username for authenticating with the OpenGauss database", default=None, ) - OPENGAUSS_PASSWORD: Optional[str] = Field( description="Password for authenticating with the OpenGauss database", default=None, ) - OPENGAUSS_DATABASE: Optional[str] = Field( description="Name of the OpenGauss database to connect to", default=None, ) - OPENGAUSS_MIN_CONNECTION: PositiveInt = Field( description="Min connection of the OpenGauss database", default=1, ) - OPENGAUSS_MAX_CONNECTION: PositiveInt = Field( description="Max connection of the OpenGauss database", default=5, ) - OPENGAUSS_ENABLE_PQ: bool = 
Field( description="Enable openGauss PQ acceleration feature", default=False, diff --git a/api/configs/middleware/vdb/opensearch_config.py b/api/configs/middleware/vdb/opensearch_config.py index 9fd9b60194..2510fd17df 100644 --- a/api/configs/middleware/vdb/opensearch_config.py +++ b/api/configs/middleware/vdb/opensearch_config.py @@ -22,42 +22,34 @@ class OpenSearchConfig(BaseSettings): description="Hostname or IP address of the OpenSearch server (e.g., 'localhost' or 'opensearch.example.com')", default=None, ) - OPENSEARCH_PORT: PositiveInt = Field( description="Port number on which the OpenSearch server is listening (default is 9200)", default=9200, ) - OPENSEARCH_SECURE: bool = Field( description="Whether to use SSL/TLS encrypted connection for OpenSearch (True for HTTPS, False for HTTP)", default=False, ) - OPENSEARCH_VERIFY_CERTS: bool = Field( description="Whether to verify SSL certificates for HTTPS connections (recommended to set True in production)", default=True, ) - OPENSEARCH_AUTH_METHOD: AuthMethod = Field( description="Authentication method for OpenSearch connection (default is 'basic')", default=AuthMethod.BASIC, ) - OPENSEARCH_USER: Optional[str] = Field( description="Username for authenticating with OpenSearch", default=None, ) - OPENSEARCH_PASSWORD: Optional[str] = Field( description="Password for authenticating with OpenSearch", default=None, ) - OPENSEARCH_AWS_REGION: Optional[str] = Field( description="AWS region for OpenSearch (e.g. 'us-west-2')", default=None, ) - OPENSEARCH_AWS_SERVICE: Optional[Literal["es", "aoss"]] = Field( description="AWS service for OpenSearch (e.g. 
'aoss' for OpenSearch Serverless)", default=None ) diff --git a/api/configs/middleware/vdb/oracle_config.py b/api/configs/middleware/vdb/oracle_config.py index ea39909ef4..7ab8fbb201 100644 --- a/api/configs/middleware/vdb/oracle_config.py +++ b/api/configs/middleware/vdb/oracle_config.py @@ -13,33 +13,27 @@ class OracleConfig(BaseSettings): description="Username for authenticating with the Oracle database", default=None, ) - ORACLE_PASSWORD: Optional[str] = Field( description="Password for authenticating with the Oracle database", default=None, ) - ORACLE_DSN: Optional[str] = Field( description="Oracle database connection string. For traditional database, use format 'host:port/service_name'. " "For autonomous database, use the service name from tnsnames.ora in the wallet", default=None, ) - ORACLE_CONFIG_DIR: Optional[str] = Field( description="Directory containing the tnsnames.ora configuration file. Only used in thin mode connection", default=None, ) - ORACLE_WALLET_LOCATION: Optional[str] = Field( description="Oracle wallet directory path containing the wallet files for secure connection", default=None, ) - ORACLE_WALLET_PASSWORD: Optional[str] = Field( description="Password to decrypt the Oracle wallet, if it is encrypted", default=None, ) - ORACLE_IS_AUTONOMOUS: bool = Field( description="Flag indicating whether connecting to Oracle Autonomous Database", default=False, diff --git a/api/configs/middleware/vdb/pgvector_config.py b/api/configs/middleware/vdb/pgvector_config.py index 9f5f7284d7..af50d4a5d1 100644 --- a/api/configs/middleware/vdb/pgvector_config.py +++ b/api/configs/middleware/vdb/pgvector_config.py @@ -13,37 +13,30 @@ class PGVectorConfig(BaseSettings): description="Hostname or IP address of the PostgreSQL server with PGVector extension (e.g., 'localhost')", default=None, ) - PGVECTOR_PORT: PositiveInt = Field( description="Port number on which the PostgreSQL server is listening (default is 5433)", default=5433, ) - PGVECTOR_USER: Optional[str] = 
Field( description="Username for authenticating with the PostgreSQL database", default=None, ) - PGVECTOR_PASSWORD: Optional[str] = Field( description="Password for authenticating with the PostgreSQL database", default=None, ) - PGVECTOR_DATABASE: Optional[str] = Field( description="Name of the PostgreSQL database to connect to", default=None, ) - PGVECTOR_MIN_CONNECTION: PositiveInt = Field( description="Min connection of the PostgreSQL database", default=1, ) - PGVECTOR_MAX_CONNECTION: PositiveInt = Field( description="Max connection of the PostgreSQL database", default=5, ) - PGVECTOR_PG_BIGM: bool = Field( description="Whether to use pg_bigm module for full text search", default=False, diff --git a/api/configs/middleware/vdb/pgvectors_config.py b/api/configs/middleware/vdb/pgvectors_config.py index fa3bca5bb7..76386698df 100644 --- a/api/configs/middleware/vdb/pgvectors_config.py +++ b/api/configs/middleware/vdb/pgvectors_config.py @@ -13,22 +13,18 @@ class PGVectoRSConfig(BaseSettings): description="Hostname or IP address of the PostgreSQL server with PGVecto.RS extension (e.g., 'localhost')", default=None, ) - PGVECTO_RS_PORT: PositiveInt = Field( description="Port number on which the PostgreSQL server with PGVecto.RS is listening (default is 5431)", default=5431, ) - PGVECTO_RS_USER: Optional[str] = Field( description="Username for authenticating with the PostgreSQL database using PGVecto.RS", default=None, ) - PGVECTO_RS_PASSWORD: Optional[str] = Field( description="Password for authenticating with the PostgreSQL database using PGVecto.RS", default=None, ) - PGVECTO_RS_DATABASE: Optional[str] = Field( description="Name of the PostgreSQL database with PGVecto.RS extension to connect to", default=None, diff --git a/api/configs/middleware/vdb/qdrant_config.py b/api/configs/middleware/vdb/qdrant_config.py index 0a753eddec..2a10f909ee 100644 --- a/api/configs/middleware/vdb/qdrant_config.py +++ b/api/configs/middleware/vdb/qdrant_config.py @@ -13,27 +13,22 @@ 
class QdrantConfig(BaseSettings): description="URL of the Qdrant server (e.g., 'http://localhost:6333' or 'https://qdrant.example.com')", default=None, ) - QDRANT_API_KEY: Optional[str] = Field( description="API key for authenticating with the Qdrant server", default=None, ) - QDRANT_CLIENT_TIMEOUT: NonNegativeInt = Field( description="Timeout in seconds for Qdrant client operations (default is 20 seconds)", default=20, ) - QDRANT_GRPC_ENABLED: bool = Field( description="Whether to enable gRPC support for Qdrant connection (True for gRPC, False for HTTP)", default=False, ) - QDRANT_GRPC_PORT: PositiveInt = Field( description="Port number for gRPC connection to Qdrant server (default is 6334)", default=6334, ) - QDRANT_REPLICATION_FACTOR: PositiveInt = Field( description="Replication factor for Qdrant collections (default is 1)", default=1, diff --git a/api/configs/middleware/vdb/relyt_config.py b/api/configs/middleware/vdb/relyt_config.py index 5ffbea7b19..5e24e066a5 100644 --- a/api/configs/middleware/vdb/relyt_config.py +++ b/api/configs/middleware/vdb/relyt_config.py @@ -13,22 +13,18 @@ class RelytConfig(BaseSettings): description="Hostname or IP address of the Relyt server (e.g., 'localhost' or 'relyt.example.com')", default=None, ) - RELYT_PORT: PositiveInt = Field( description="Port number on which the Relyt server is listening (default is 9200)", default=9200, ) - RELYT_USER: Optional[str] = Field( description="Username for authenticating with the Relyt database", default=None, ) - RELYT_PASSWORD: Optional[str] = Field( description="Password for authenticating with the Relyt database", default=None, ) - RELYT_DATABASE: Optional[str] = Field( description="Name of the Relyt database to connect to (default is 'default')", default="default", diff --git a/api/configs/middleware/vdb/tablestore_config.py b/api/configs/middleware/vdb/tablestore_config.py index c4dcc0d465..ce31dd7b54 100644 --- a/api/configs/middleware/vdb/tablestore_config.py +++ 
b/api/configs/middleware/vdb/tablestore_config.py @@ -13,17 +13,14 @@ class TableStoreConfig(BaseSettings): description="Endpoint address of the TableStore server (e.g. 'https://instance-name.cn-hangzhou.ots.aliyuncs.com')", default=None, ) - TABLESTORE_INSTANCE_NAME: Optional[str] = Field( description="Instance name to access TableStore server (eg. 'instance-name')", default=None, ) - TABLESTORE_ACCESS_KEY_ID: Optional[str] = Field( description="AccessKey id for the instance name", default=None, ) - TABLESTORE_ACCESS_KEY_SECRET: Optional[str] = Field( description="AccessKey secret for the instance name", default=None, diff --git a/api/configs/middleware/vdb/tencent_vector_config.py b/api/configs/middleware/vdb/tencent_vector_config.py index a51823c3f3..14bb023358 100644 --- a/api/configs/middleware/vdb/tencent_vector_config.py +++ b/api/configs/middleware/vdb/tencent_vector_config.py @@ -13,42 +13,34 @@ class TencentVectorDBConfig(BaseSettings): description="URL of the Tencent Vector Database service (e.g., 'https://vectordb.tencentcloudapi.com')", default=None, ) - TENCENT_VECTOR_DB_API_KEY: Optional[str] = Field( description="API key for authenticating with the Tencent Vector Database service", default=None, ) - TENCENT_VECTOR_DB_TIMEOUT: PositiveInt = Field( description="Timeout in seconds for Tencent Vector Database operations (default is 30 seconds)", default=30, ) - TENCENT_VECTOR_DB_USERNAME: Optional[str] = Field( description="Username for authenticating with the Tencent Vector Database (if required)", default=None, ) - TENCENT_VECTOR_DB_PASSWORD: Optional[str] = Field( description="Password for authenticating with the Tencent Vector Database (if required)", default=None, ) - TENCENT_VECTOR_DB_SHARD: PositiveInt = Field( description="Number of shards for the Tencent Vector Database (default is 1)", default=1, ) - TENCENT_VECTOR_DB_REPLICAS: NonNegativeInt = Field( description="Number of replicas for the Tencent Vector Database (default is 2)", default=2, ) 
- TENCENT_VECTOR_DB_DATABASE: Optional[str] = Field( description="Name of the specific Tencent Vector Database to connect to", default=None, ) - TENCENT_VECTOR_DB_ENABLE_HYBRID_SEARCH: bool = Field( description="Enable hybrid search features", default=False, diff --git a/api/configs/middleware/vdb/tidb_on_qdrant_config.py b/api/configs/middleware/vdb/tidb_on_qdrant_config.py index d2625af264..84acf16444 100644 --- a/api/configs/middleware/vdb/tidb_on_qdrant_config.py +++ b/api/configs/middleware/vdb/tidb_on_qdrant_config.py @@ -13,57 +13,46 @@ class TidbOnQdrantConfig(BaseSettings): description="Tidb on Qdrant url", default=None, ) - TIDB_ON_QDRANT_API_KEY: Optional[str] = Field( description="Tidb on Qdrant api key", default=None, ) - TIDB_ON_QDRANT_CLIENT_TIMEOUT: NonNegativeInt = Field( description="Tidb on Qdrant client timeout in seconds", default=20, ) - TIDB_ON_QDRANT_GRPC_ENABLED: bool = Field( description="whether enable grpc support for Tidb on Qdrant connection", default=False, ) - TIDB_ON_QDRANT_GRPC_PORT: PositiveInt = Field( description="Tidb on Qdrant grpc port", default=6334, ) - TIDB_PUBLIC_KEY: Optional[str] = Field( description="Tidb account public key", default=None, ) - TIDB_PRIVATE_KEY: Optional[str] = Field( description="Tidb account private key", default=None, ) - TIDB_API_URL: Optional[str] = Field( description="Tidb API url", default=None, ) - TIDB_IAM_API_URL: Optional[str] = Field( description="Tidb IAM API url", default=None, ) - TIDB_REGION: Optional[str] = Field( description="Tidb serverless region", default="regions/aws-us-east-1", ) - TIDB_PROJECT_ID: Optional[str] = Field( description="Tidb project id", default=None, ) - TIDB_SPEND_LIMIT: Optional[int] = Field( description="Tidb spend limit", default=100, diff --git a/api/configs/middleware/vdb/tidb_vector_config.py b/api/configs/middleware/vdb/tidb_vector_config.py index bc68be69d8..8e4b088a5a 100644 --- a/api/configs/middleware/vdb/tidb_vector_config.py +++ 
b/api/configs/middleware/vdb/tidb_vector_config.py @@ -13,22 +13,18 @@ class TiDBVectorConfig(BaseSettings): description="Hostname or IP address of the TiDB Vector server (e.g., 'localhost' or 'tidb.example.com')", default=None, ) - TIDB_VECTOR_PORT: Optional[PositiveInt] = Field( description="Port number on which the TiDB Vector server is listening (default is 4000)", default=4000, ) - TIDB_VECTOR_USER: Optional[str] = Field( description="Username for authenticating with the TiDB Vector database", default=None, ) - TIDB_VECTOR_PASSWORD: Optional[str] = Field( description="Password for authenticating with the TiDB Vector database", default=None, ) - TIDB_VECTOR_DATABASE: Optional[str] = Field( description="Name of the TiDB Vector database to connect to", default=None, diff --git a/api/configs/middleware/vdb/upstash_config.py b/api/configs/middleware/vdb/upstash_config.py index 412c56374a..57aef3b0e8 100644 --- a/api/configs/middleware/vdb/upstash_config.py +++ b/api/configs/middleware/vdb/upstash_config.py @@ -13,7 +13,6 @@ class UpstashConfig(BaseSettings): description="URL of the upstash server (e.g., 'https://vector.upstash.io')", default=None, ) - UPSTASH_VECTOR_TOKEN: Optional[str] = Field( description="Token for authenticating with the upstash server", default=None, diff --git a/api/configs/middleware/vdb/vastbase_vector_config.py b/api/configs/middleware/vdb/vastbase_vector_config.py index 816d6df90a..a0c63f98d5 100644 --- a/api/configs/middleware/vdb/vastbase_vector_config.py +++ b/api/configs/middleware/vdb/vastbase_vector_config.py @@ -13,32 +13,26 @@ class VastbaseVectorConfig(BaseSettings): description="Hostname or IP address of the Vastbase server with Vector extension (e.g., 'localhost')", default=None, ) - VASTBASE_PORT: PositiveInt = Field( description="Port number on which the Vastbase server is listening (default is 5432)", default=5432, ) - VASTBASE_USER: Optional[str] = Field( description="Username for authenticating with the Vastbase database", 
default=None, ) - VASTBASE_PASSWORD: Optional[str] = Field( description="Password for authenticating with the Vastbase database", default=None, ) - VASTBASE_DATABASE: Optional[str] = Field( description="Name of the Vastbase database to connect to", default=None, ) - VASTBASE_MIN_CONNECTION: PositiveInt = Field( description="Min connection of the Vastbase database", default=1, ) - VASTBASE_MAX_CONNECTION: PositiveInt = Field( description="Max connection of the Vastbase database", default=5, diff --git a/api/configs/middleware/vdb/vikingdb_config.py b/api/configs/middleware/vdb/vikingdb_config.py index aba49ff670..ee4cd6e2a8 100644 --- a/api/configs/middleware/vdb/vikingdb_config.py +++ b/api/configs/middleware/vdb/vikingdb_config.py @@ -17,33 +17,27 @@ class VikingDBConfig(BaseSettings): "https://www.volcengine.com/docs/6291/65568", default=None, ) - VIKINGDB_SECRET_KEY: Optional[str] = Field( description="The Secret Key provided by Volcengine VikingDB for API authentication.", default=None, ) - VIKINGDB_REGION: str = Field( description="The region of the Volcengine VikingDB service.(e.g., 'cn-shanghai', 'cn-beijing').", default="cn-shanghai", ) - VIKINGDB_HOST: str = Field( description="The host of the Volcengine VikingDB service.(e.g., 'api-vikingdb.volces.com', \ 'api-vikingdb.mlp.cn-shanghai.volces.com')", default="api-vikingdb.mlp.cn-shanghai.volces.com", ) - VIKINGDB_SCHEME: str = Field( description="The scheme of the Volcengine VikingDB service.(e.g., 'http', 'https').", default="http", ) - VIKINGDB_CONNECTION_TIMEOUT: int = Field( description="The connection timeout of the Volcengine VikingDB service.", default=30, ) - VIKINGDB_SOCKET_TIMEOUT: int = Field( description="The socket timeout of the Volcengine VikingDB service.", default=30, diff --git a/api/configs/middleware/vdb/weaviate_config.py b/api/configs/middleware/vdb/weaviate_config.py index 25000e8bde..f135956d97 100644 --- a/api/configs/middleware/vdb/weaviate_config.py +++ 
b/api/configs/middleware/vdb/weaviate_config.py @@ -13,17 +13,14 @@ class WeaviateConfig(BaseSettings): description="URL of the Weaviate server (e.g., 'http://localhost:8080' or 'https://weaviate.example.com')", default=None, ) - WEAVIATE_API_KEY: Optional[str] = Field( description="API key for authenticating with the Weaviate server", default=None, ) - WEAVIATE_GRPC_ENABLED: bool = Field( description="Whether to enable gRPC for Weaviate connection (True for gRPC, False for HTTP)", default=True, ) - WEAVIATE_BATCH_SIZE: PositiveInt = Field( description="Number of objects to be processed in a single batch operation (default is 100)", default=100, diff --git a/api/configs/observability/otel/otel_config.py b/api/configs/observability/otel/otel_config.py index 1b88ddcfe6..24e36cd7c8 100644 --- a/api/configs/observability/otel/otel_config.py +++ b/api/configs/observability/otel/otel_config.py @@ -11,39 +11,28 @@ class OTelConfig(BaseSettings): description="Whether to enable OpenTelemetry", default=False, ) - OTLP_BASE_ENDPOINT: str = Field( description="OTLP base endpoint", default="http://localhost:4318", ) - OTLP_API_KEY: str = Field( description="OTLP API key", default="", ) - OTEL_EXPORTER_TYPE: str = Field( description="OTEL exporter type", default="otlp", ) - OTEL_EXPORTER_OTLP_PROTOCOL: str = Field( description="OTLP exporter protocol ('grpc' or 'http')", default="http", ) - OTEL_SAMPLING_RATE: float = Field(default=0.1, description="Sampling rate for traces (0.0 to 1.0)") - OTEL_BATCH_EXPORT_SCHEDULE_DELAY: int = Field( default=5000, description="Batch export schedule delay in milliseconds" ) - OTEL_MAX_QUEUE_SIZE: int = Field(default=2048, description="Maximum queue size for the batch span processor") - OTEL_MAX_EXPORT_BATCH_SIZE: int = Field(default=512, description="Maximum export batch size") - OTEL_METRIC_EXPORT_INTERVAL: int = Field(default=60000, description="Metric export interval in milliseconds") - OTEL_BATCH_EXPORT_TIMEOUT: int = Field(default=10000, 
description="Batch export timeout in milliseconds") - OTEL_METRIC_EXPORT_TIMEOUT: int = Field(default=30000, description="Metric export timeout in milliseconds") diff --git a/api/configs/remote_settings_sources/apollo/__init__.py b/api/configs/remote_settings_sources/apollo/__init__.py index f02f7dc9ff..559726a220 100644 --- a/api/configs/remote_settings_sources/apollo/__init__.py +++ b/api/configs/remote_settings_sources/apollo/__init__.py @@ -19,17 +19,14 @@ class ApolloSettingsSourceInfo(BaseSettings): description="apollo app_id", default=None, ) - APOLLO_CLUSTER: Optional[str] = Field( description="apollo cluster", default=None, ) - APOLLO_CONFIG_URL: Optional[str] = Field( description="apollo config url", default=None, ) - APOLLO_NAMESPACE: Optional[str] = Field( description="apollo namespace", default=None, diff --git a/api/configs/remote_settings_sources/apollo/client.py b/api/configs/remote_settings_sources/apollo/client.py index 88b30d3987..0e9a7451f5 100644 --- a/api/configs/remote_settings_sources/apollo/client.py +++ b/api/configs/remote_settings_sources/apollo/client.py @@ -37,13 +37,10 @@ class ApolloClient: self.config_url = config_url self.cluster = cluster self.app_id = app_id - # Non-core parameters self.ip = init_ip() self.secret = secret - # Check the parameter variables - # Private control variables self._cycle_time = 5 self._stopping = False @@ -62,7 +59,6 @@ class ApolloClient: self._path_checker() if start_hot_update: self._start_hot_update() - # start the heartbeat thread heartbeat = threading.Thread(target=self._heart_beat) heartbeat.daemon = True @@ -95,25 +91,21 @@ class ApolloClient: val = get_value_from_dict(namespace_cache, key) if val is not None: return val - no_key = no_key_cache_key(namespace, key) if no_key in self._no_key: return default_val - # read the network configuration namespace_data = self.get_json_from_net(namespace) val = get_value_from_dict(namespace_data, key) if val is not None: 
self._update_cache_and_file(namespace_data, namespace) return val - # read the file configuration namespace_cache = self._get_local_cache(namespace) val = get_value_from_dict(namespace_cache, key) if val is not None: self._update_cache_and_file(namespace_cache, namespace) return val - # If all of them are not obtained, the default value is returned # and the local cache is set to None self._set_local_cache_none(namespace, key) diff --git a/api/configs/remote_settings_sources/apollo/python_3x.py b/api/configs/remote_settings_sources/apollo/python_3x.py index 6a5f381991..cd22bc1176 100644 --- a/api/configs/remote_settings_sources/apollo/python_3x.py +++ b/api/configs/remote_settings_sources/apollo/python_3x.py @@ -10,12 +10,9 @@ ssl_context = ssl.create_default_context() ssl_context.set_ciphers("HIGH:!DH:!aNULL") ssl_context.check_hostname = False ssl_context.verify_mode = ssl.CERT_NONE - # Create an opener object and pass in a custom SSL context opener = urllib.request.build_opener(urllib.request.HTTPSHandler(context=ssl_context)) - urllib.request.install_opener(opener) - logger = logging.getLogger(__name__) diff --git a/api/configs/remote_settings_sources/nacos/__init__.py b/api/configs/remote_settings_sources/nacos/__init__.py index b1ce8e87bc..18826aa923 100644 --- a/api/configs/remote_settings_sources/nacos/__init__.py +++ b/api/configs/remote_settings_sources/nacos/__init__.py @@ -8,7 +8,6 @@ from pydantic.fields import FieldInfo from .http_request import NacosHttpClient logger = logging.getLogger(__name__) - from configs.remote_settings_sources.base import RemoteSettingsSource from .utils import _parse_config @@ -24,7 +23,6 @@ class NacosSettingsSource(RemoteSettingsSource): data_id = os.getenv("DIFY_ENV_NACOS_DATA_ID", "dify-api-env.properties") group = os.getenv("DIFY_ENV_NACOS_GROUP", "nacos-dify") tenant = os.getenv("DIFY_ENV_NACOS_NAMESPACE", "") - params = {"dataId": data_id, "group": group, "tenant": tenant} try: content = 
NacosHttpClient().http_request("/nacos/v1/cs/configs", method="GET", headers={}, params=params) @@ -44,9 +42,7 @@ class NacosSettingsSource(RemoteSettingsSource): def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]: if not isinstance(self.remote_configs, dict): raise ValueError(f"remote configs is not dict, but {type(self.remote_configs)}") - field_value = self.remote_configs.get(field_name) if field_value is None: return None, field_name, False - return field_value, field_name, False diff --git a/api/configs/remote_settings_sources/nacos/http_request.py b/api/configs/remote_settings_sources/nacos/http_request.py index 9b3359c6ad..09977cf9cc 100644 --- a/api/configs/remote_settings_sources/nacos/http_request.py +++ b/api/configs/remote_settings_sources/nacos/http_request.py @@ -32,12 +32,9 @@ class NacosHttpClient: def _inject_auth_info(self, headers, params, module="config"): headers.update({"User-Agent": "Nacos-Http-Client-In-Dify:v0.0.1"}) - if module == "login": return - ts = str(int(time.time() * 1000)) - if self.ak and self.sk: sign_str = self.get_sign_str(params["group"], params["tenant"], ts) headers["Spas-AccessKey"] = self.ak @@ -67,7 +64,6 @@ class NacosHttpClient: current_time = time.time() if self.token and not force_refresh and self.token_expire_time > current_time: return self.token - params = {"username": self.username, "password": self.password} url = "http://" + self.server + "/nacos/v1/auth/login" try: diff --git a/api/configs/remote_settings_sources/nacos/utils.py b/api/configs/remote_settings_sources/nacos/utils.py index f3372563b1..12061906a4 100644 --- a/api/configs/remote_settings_sources/nacos/utils.py +++ b/api/configs/remote_settings_sources/nacos/utils.py @@ -2,30 +2,23 @@ def _parse_config(self, content: str) -> dict[str, str]: config: dict[str, str] = {} if not content: return config - for line in content.splitlines(): cleaned_line = line.strip() if not cleaned_line or cleaned_line.startswith(("#", 
"!")): continue - separator_index = -1 for i, c in enumerate(cleaned_line): if c in ("=", ":") and (i == 0 or cleaned_line[i - 1] != "\\"): separator_index = i break - if separator_index == -1: continue - key = cleaned_line[:separator_index].strip() raw_value = cleaned_line[separator_index + 1 :].strip() - try: decoded_value = bytes(raw_value, "utf-8").decode("unicode_escape") decoded_value = decoded_value.replace(r"\=", "=").replace(r"\:", ":") except UnicodeDecodeError: decoded_value = raw_value - config[key] = decoded_value - return config diff --git a/api/constants/__init__.py b/api/constants/__init__.py index a84de0a451..933556f359 100644 --- a/api/constants/__init__.py +++ b/api/constants/__init__.py @@ -2,19 +2,13 @@ from configs import dify_config HIDDEN_VALUE = "[__HIDDEN__]" UUID_NIL = "00000000-0000-0000-0000-000000000000" - DEFAULT_FILE_NUMBER_LIMITS = 3 - IMAGE_EXTENSIONS = ["jpg", "jpeg", "png", "webp", "gif", "svg"] IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS]) - VIDEO_EXTENSIONS = ["mp4", "mov", "mpeg", "mpga"] VIDEO_EXTENSIONS.extend([ext.upper() for ext in VIDEO_EXTENSIONS]) - AUDIO_EXTENSIONS = ["mp3", "m4a", "wav", "webm", "amr"] AUDIO_EXTENSIONS.extend([ext.upper() for ext in AUDIO_EXTENSIONS]) - - if dify_config.ETL_TYPE == "Unstructured": DOCUMENT_EXTENSIONS = ["txt", "markdown", "md", "mdx", "pdf", "html", "htm", "xlsx", "xls", "vtt", "properties"] DOCUMENT_EXTENSIONS.extend(("doc", "docx", "csv", "eml", "msg", "pptx", "xml", "epub")) diff --git a/api/constants/languages.py b/api/constants/languages.py index 1157ec4307..78c5dbb31d 100644 --- a/api/constants/languages.py +++ b/api/constants/languages.py @@ -20,13 +20,11 @@ language_timezone_mapping = { "sl-SI": "Europe/Ljubljana", "th-TH": "Asia/Bangkok", } - languages = list(language_timezone_mapping.keys()) def supported_language(lang): if lang in languages: return lang - error = "{lang} is not a valid language.".format(lang=lang) raise ValueError(error) diff --git 
a/api/constants/mimetypes.py b/api/constants/mimetypes.py index 38988cdd24..934968e57e 100644 --- a/api/constants/mimetypes.py +++ b/api/constants/mimetypes.py @@ -1,6 +1,5 @@ # The two constants below should keep in sync. # Default content type for files which have no explicit content type. - DEFAULT_MIME_TYPE = "application/octet-stream" # Default file extension for files which have no explicit content type, should # correspond to the `DEFAULT_MIME_TYPE` above. diff --git a/api/constants/tts_auto_play_timeout.py b/api/constants/tts_auto_play_timeout.py index d5ed30830a..e02c282efb 100644 --- a/api/constants/tts_auto_play_timeout.py +++ b/api/constants/tts_auto_play_timeout.py @@ -1,4 +1,3 @@ TTS_AUTO_PLAY_TIMEOUT = 5 - # sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file) TTS_AUTO_PLAY_YIELD_CPU_TIME = 0.02 diff --git a/api/contexts/__init__.py b/api/contexts/__init__.py index ae41a2c03a..1936ff061f 100644 --- a/api/contexts/__init__.py +++ b/api/contexts/__init__.py @@ -9,27 +9,20 @@ if TYPE_CHECKING: from core.plugin.entities.plugin_daemon import PluginModelProviderEntity from core.tools.plugin_tool.provider import PluginToolProviderController from core.workflow.entities.variable_pool import VariablePool - - """ To avoid race-conditions caused by gunicorn thread recycling, using RecyclableContextVar to replace with """ plugin_tool_providers: RecyclableContextVar[dict[str, "PluginToolProviderController"]] = RecyclableContextVar( ContextVar("plugin_tool_providers") ) - plugin_tool_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(ContextVar("plugin_tool_providers_lock")) - plugin_model_providers: RecyclableContextVar[list["PluginModelProviderEntity"] | None] = RecyclableContextVar( ContextVar("plugin_model_providers") ) - plugin_model_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar( ContextVar("plugin_model_providers_lock") ) - plugin_model_schema_lock: RecyclableContextVar[Lock] = 
RecyclableContextVar(ContextVar("plugin_model_schema_lock")) - plugin_model_schemas: RecyclableContextVar[dict[str, "AIModelEntity"]] = RecyclableContextVar( ContextVar("plugin_model_schemas") ) diff --git a/api/contexts/wrapper.py b/api/contexts/wrapper.py index 8cd53487ef..d1403b2dbc 100644 --- a/api/contexts/wrapper.py +++ b/api/contexts/wrapper.py @@ -15,7 +15,6 @@ class RecyclableContextVar(Generic[T]): """ RecyclableContextVar is a wrapper around ContextVar It's safe to use in gunicorn with thread recycling, but features like `reset` are not available for now - NOTE: you need to call `increment_thread_recycles` before requests """ @@ -38,7 +37,6 @@ class RecyclableContextVar(Generic[T]): self_updates = self._updates.get() if thread_recycles > self_updates: self._updates.set(thread_recycles) - # check if thread is recycled and should be updated if thread_recycles < self_updates: return self._context_var.get() @@ -56,10 +54,8 @@ class RecyclableContextVar(Generic[T]): self_updates = self._updates.get() if thread_recycles > self_updates: self._updates.set(thread_recycles) - if self._updates.get() == self._thread_recycles.get(0): # after increment, self._updates.set(self._updates.get() + 1) - # set the context self._context_var.set(value) diff --git a/api/controllers/__init__.py b/api/controllers/__init__.py index 8b13789179..e69de29bb2 100644 --- a/api/controllers/__init__.py +++ b/api/controllers/__init__.py @@ -1 +0,0 @@ - diff --git a/api/controllers/common/fields.py b/api/controllers/common/fields.py index 3466eea1f6..80200d2108 100644 --- a/api/controllers/common/fields.py +++ b/api/controllers/common/fields.py @@ -9,7 +9,6 @@ parameters__system_parameters = { "file_size_limit": fields.Integer, "workflow_file_upload_limit": fields.Integer, } - parameters_fields = { "opening_statement": fields.String, "suggested_questions": fields.Raw, @@ -24,7 +23,6 @@ parameters_fields = { "file_upload": fields.Raw, "system_parameters": 
fields.Nested(parameters__system_parameters), } - site_fields = { "title": fields.String, "chat_color_theme": fields.String, diff --git a/api/controllers/common/helpers.py b/api/controllers/common/helpers.py index 008f1f0f7a..e5701759b4 100644 --- a/api/controllers/common/helpers.py +++ b/api/controllers/common/helpers.py @@ -24,7 +24,6 @@ except ImportError: else: warnings.warn("To use python-magic guess MIMETYPE, you need to install `libmagic`", stacklevel=2) magic = None # type: ignore - from pydantic import BaseModel @@ -41,7 +40,6 @@ def guess_file_info_from_response(response: httpx.Response): parsed_url = urllib.parse.urlparse(url) url_path = parsed_url.path filename = os.path.basename(url_path) - # If filename couldn't be extracted, use Content-Disposition header if not filename: content_disposition = response.headers.get("Content-Disposition") @@ -49,12 +47,10 @@ def guess_file_info_from_response(response: httpx.Response): filename_match = re.search(r'filename="?(.+)"?', content_disposition) if filename_match: filename = filename_match.group(1) - # If still no filename, generate a unique one if not filename: unique_name = str(uuid4()) filename = f"{unique_name}" - # Guess MIME type from filename first, then URL mimetype, _ = mimetypes.guess_type(filename) if mimetype is None: @@ -62,21 +58,17 @@ def guess_file_info_from_response(response: httpx.Response): if mimetype is None: # If guessing fails, use Content-Type from response headers mimetype = response.headers.get("Content-Type", "application/octet-stream") - # Use python-magic to guess MIME type if still unknown or generic if mimetype == "application/octet-stream" and magic is not None: try: mimetype = magic.from_buffer(response.content[:1024], mime=True) except magic.MagicException: pass - extension = os.path.splitext(filename)[1] - # Ensure filename has an extension if not extension: extension = mimetypes.guess_extension(mimetype) or ".bin" filename = f"{filename}{extension}" - return FileInfo( 
filename=filename, extension=extension, diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index dbdcdc46ce..1db91a7057 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -27,21 +27,17 @@ from .remote_files import RemoteFileInfoApi, RemoteFileUploadApi bp = Blueprint("console", __name__, url_prefix="/console/api") api = ExternalApi(bp) - # File api.add_resource(FileApi, "/files/upload") api.add_resource(FilePreviewApi, "/files//preview") api.add_resource(FileSupportTypeApi, "/files/support-type") - # Remote files api.add_resource(RemoteFileInfoApi, "/remote-files/") api.add_resource(RemoteFileUploadApi, "/remote-files/upload") - # Import App api.add_resource(AppImportApi, "/apps/imports") api.add_resource(AppImportConfirmApi, "/apps/imports//confirm") api.add_resource(AppImportCheckDependenciesApi, "/apps/imports//check-dependencies") - # Import other controllers from . import admin, apikey, extension, feature, ping, setup, version @@ -97,7 +93,6 @@ from .explore import ( # Explore Audio api.add_resource(ChatAudioApi, "/installed-apps//audio-to-text", endpoint="installed_app_audio") api.add_resource(ChatTextApi, "/installed-apps//text-to-audio", endpoint="installed_app_text") - # Explore Completion api.add_resource( CompletionApi, "/installed-apps//completion-messages", endpoint="installed_app_completion" @@ -115,7 +110,6 @@ api.add_resource( "/installed-apps//chat-messages//stop", endpoint="installed_app_stop_chat_completion", ) - # Explore Conversation api.add_resource( ConversationRenameApi, @@ -140,8 +134,6 @@ api.add_resource( "/installed-apps//conversations//unpin", endpoint="installed_app_conversation_unpin", ) - - # Explore Message api.add_resource(MessageListApi, "/installed-apps//messages", endpoint="installed_app_messages") api.add_resource( @@ -164,7 +156,6 @@ api.add_resource(InstalledAppWorkflowRunApi, "/installed-apps//workflows/tasks//stop" ) - # Import tag controllers 
from .tag import tags diff --git a/api/controllers/console/admin.py b/api/controllers/console/admin.py index f5257fae79..7f8e3b8d43 100644 --- a/api/controllers/console/admin.py +++ b/api/controllers/console/admin.py @@ -19,23 +19,17 @@ def admin_required(view): def decorated(*args, **kwargs): if not dify_config.ADMIN_API_KEY: raise Unauthorized("API key is invalid.") - auth_header = request.headers.get("Authorization") if auth_header is None: raise Unauthorized("Authorization header is missing.") - if " " not in auth_header: raise Unauthorized("Invalid Authorization header format. Expected 'Bearer ' format.") - auth_scheme, auth_token = auth_header.split(None, 1) auth_scheme = auth_scheme.lower() - if auth_scheme != "bearer": raise Unauthorized("Invalid Authorization header format. Expected 'Bearer ' format.") - if auth_token != dify_config.ADMIN_API_KEY: raise Unauthorized("API key is invalid.") - return view(*args, **kwargs) return decorated @@ -55,11 +49,9 @@ class InsertExploreAppListApi(Resource): parser.add_argument("category", type=str, required=True, nullable=False, location="json") parser.add_argument("position", type=int, required=True, nullable=False, location="json") args = parser.parse_args() - app = db.session.execute(select(App).filter(App.id == args["app_id"])).scalar_one_or_none() if not app: raise NotFound(f"App '{args['app_id']}' is not found") - site = app.site if not site: desc = args["desc"] or "" @@ -71,12 +63,10 @@ class InsertExploreAppListApi(Resource): copy_right = site.copyright or args["copyright"] or "" privacy_policy = site.privacy_policy or args["privacy_policy"] or "" custom_disclaimer = site.custom_disclaimer or args["custom_disclaimer"] or "" - with Session(db.engine) as session: recommended_app = session.execute( select(RecommendedApp).filter(RecommendedApp.app_id == args["app_id"]) ).scalar_one_or_none() - if not recommended_app: recommended_app = RecommendedApp( app_id=app.id, @@ -88,12 +78,9 @@ class 
InsertExploreAppListApi(Resource): category=args["category"], position=args["position"], ) - db.session.add(recommended_app) - app.is_public = True db.session.commit() - return {"result": "success"}, 201 else: recommended_app.description = desc @@ -103,11 +90,8 @@ class InsertExploreAppListApi(Resource): recommended_app.language = args["language"] recommended_app.category = args["category"] recommended_app.position = args["position"] - app.is_public = True - db.session.commit() - return {"result": "success"}, 200 @@ -119,16 +103,12 @@ class InsertExploreAppApi(Resource): recommended_app = session.execute( select(RecommendedApp).filter(RecommendedApp.app_id == str(app_id)) ).scalar_one_or_none() - if not recommended_app: return {"result": "success"}, 204 - with Session(db.engine) as session: app = session.execute(select(App).filter(App.id == recommended_app.app_id)).scalar_one_or_none() - if app: app.is_public = False - with Session(db.engine) as session: installed_apps = session.execute( select(InstalledApp).filter( @@ -136,13 +116,10 @@ class InsertExploreAppApi(Resource): InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id, ) ).all() - for installed_app in installed_apps: db.session.delete(installed_app) - db.session.delete(recommended_app) db.session.commit() - return {"result": "success"}, 204 diff --git a/api/controllers/console/apikey.py b/api/controllers/console/apikey.py index 47c93a15c6..85811b9ac2 100644 --- a/api/controllers/console/apikey.py +++ b/api/controllers/console/apikey.py @@ -23,7 +23,6 @@ api_key_fields = { "last_used_at": TimestampField, "created_at": TimestampField, } - api_key_list = {"data": fields.List(fields.Nested(api_key_fields), attribute="items")} @@ -38,16 +37,13 @@ def _get_resource(resource_id, tenant_id, resource_model): resource = session.execute( select(resource_model).filter_by(id=resource_id, tenant_id=tenant_id) ).scalar_one_or_none() - if resource is None: flask_restful.abort(404, message=f"{resource_model.__name__} 
not found.") - return resource class BaseApiKeyListResource(Resource): method_decorators = [account_initialization_required, login_required, setup_required] - resource_type: str | None = None resource_model: Any = None resource_id_field: str | None = None @@ -73,20 +69,17 @@ class BaseApiKeyListResource(Resource): _get_resource(resource_id, current_user.current_tenant_id, self.resource_model) if not current_user.is_editor: raise Forbidden() - current_key_count = ( db.session.query(ApiToken) .filter(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id) .count() ) - if current_key_count >= self.max_keys: flask_restful.abort( 400, message=f"Cannot create more than {self.max_keys} API keys for this resource type.", code="max_keys_exceeded", ) - key = ApiToken.generate_api_key(self.token_prefix, 24) api_token = ApiToken() setattr(api_token, self.resource_id_field, resource_id) @@ -100,7 +93,6 @@ class BaseApiKeyListResource(Resource): class BaseApiKeyResource(Resource): method_decorators = [account_initialization_required, login_required, setup_required] - resource_type: str | None = None resource_model: Any = None resource_id_field: str | None = None @@ -110,11 +102,9 @@ class BaseApiKeyResource(Resource): resource_id = str(resource_id) api_key_id = str(api_key_id) _get_resource(resource_id, current_user.current_tenant_id, self.resource_model) - # The role of the current user in the ta table must be admin or owner if not current_user.is_admin_or_owner: raise Forbidden() - key = ( db.session.query(ApiToken) .filter( @@ -124,13 +114,10 @@ class BaseApiKeyResource(Resource): ) .first() ) - if key is None: flask_restful.abort(404, message="API key not found") - db.session.query(ApiToken).filter(ApiToken.id == api_key_id).delete() db.session.commit() - return {"result": "success"}, 204 diff --git a/api/controllers/console/app/advanced_prompt_template.py b/api/controllers/console/app/advanced_prompt_template.py index 
c228743fa5..26cc6f2a61 100644 --- a/api/controllers/console/app/advanced_prompt_template.py +++ b/api/controllers/console/app/advanced_prompt_template.py @@ -17,7 +17,6 @@ class AdvancedPromptTemplateList(Resource): parser.add_argument("has_context", type=str, required=False, default="true", location="args") parser.add_argument("model_name", type=str, required=True, location="args") args = parser.parse_args() - return AdvancedPromptTemplateService.get_prompt(args) diff --git a/api/controllers/console/app/agent.py b/api/controllers/console/app/agent.py index d433415894..e4d5edbcfe 100644 --- a/api/controllers/console/app/agent.py +++ b/api/controllers/console/app/agent.py @@ -19,9 +19,7 @@ class AgentLogApi(Resource): parser = reqparse.RequestParser() parser.add_argument("message_id", type=uuid_value, required=True, location="args") parser.add_argument("conversation_id", type=uuid_value, required=True, location="args") - args = parser.parse_args() - return AgentService.get_agent_logs(app_model, args["conversation_id"], args["message_id"]) diff --git a/api/controllers/console/app/annotation.py b/api/controllers/console/app/annotation.py index 2b48afd550..3030979c85 100644 --- a/api/controllers/console/app/annotation.py +++ b/api/controllers/console/app/annotation.py @@ -28,7 +28,6 @@ class AnnotationReplyActionApi(Resource): def post(self, app_id, action): if not current_user.is_editor: raise Forbidden() - app_id = str(app_id) parser = reqparse.RequestParser() parser.add_argument("score_threshold", required=True, type=float, location="json") @@ -51,7 +50,6 @@ class AppAnnotationSettingDetailApi(Resource): def get(self, app_id): if not current_user.is_editor: raise Forbidden() - app_id = str(app_id) result = AppAnnotationService.get_app_annotation_setting_by_app_id(app_id) return result, 200 @@ -64,14 +62,11 @@ class AppAnnotationSettingUpdateApi(Resource): def post(self, app_id, annotation_setting_id): if not current_user.is_editor: raise Forbidden() - app_id = 
str(app_id) annotation_setting_id = str(annotation_setting_id) - parser = reqparse.RequestParser() parser.add_argument("score_threshold", required=True, type=float, location="json") args = parser.parse_args() - result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, args) return result, 200 @@ -84,19 +79,16 @@ class AnnotationReplyActionStatusApi(Resource): def get(self, app_id, job_id, action): if not current_user.is_editor: raise Forbidden() - job_id = str(job_id) app_annotation_job_key = "{}_app_annotation_job_{}".format(action, str(job_id)) cache_result = redis_client.get(app_annotation_job_key) if cache_result is None: raise ValueError("The job does not exist.") - job_status = cache_result.decode() error_msg = "" if job_status == "error": app_annotation_error_key = "{}_app_annotation_error_{}".format(action, str(job_id)) error_msg = redis_client.get(app_annotation_error_key).decode() - return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200 @@ -107,11 +99,9 @@ class AnnotationListApi(Resource): def get(self, app_id): if not current_user.is_editor: raise Forbidden() - page = request.args.get("page", default=1, type=int) limit = request.args.get("limit", default=20, type=int) keyword = request.args.get("keyword", default="", type=str) - app_id = str(app_id) annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_id, page, limit, keyword) response = { @@ -131,7 +121,6 @@ class AnnotationExportApi(Resource): def get(self, app_id): if not current_user.is_editor: raise Forbidden() - app_id = str(app_id) annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id) response = {"data": marshal(annotation_list, annotation_fields)} @@ -147,7 +136,6 @@ class AnnotationCreateApi(Resource): def post(self, app_id): if not current_user.is_editor: raise Forbidden() - app_id = str(app_id) parser = reqparse.RequestParser() parser.add_argument("question", required=True, 
type=str, location="json") @@ -166,7 +154,6 @@ class AnnotationUpdateDeleteApi(Resource): def post(self, app_id, annotation_id): if not current_user.is_editor: raise Forbidden() - app_id = str(app_id) annotation_id = str(annotation_id) parser = reqparse.RequestParser() @@ -182,7 +169,6 @@ class AnnotationUpdateDeleteApi(Resource): def delete(self, app_id, annotation_id): if not current_user.is_editor: raise Forbidden() - app_id = str(app_id) annotation_id = str(annotation_id) AppAnnotationService.delete_app_annotation(app_id, annotation_id) @@ -197,14 +183,12 @@ class AnnotationBatchImportApi(Resource): def post(self, app_id): if not current_user.is_editor: raise Forbidden() - app_id = str(app_id) # get file from request file = request.files["file"] # check file if "file" not in request.files: raise NoFileUploadedError() - if len(request.files) > 1: raise TooManyFilesError() # check file type @@ -221,7 +205,6 @@ class AnnotationBatchImportStatusApi(Resource): def get(self, app_id, job_id): if not current_user.is_editor: raise Forbidden() - job_id = str(job_id) indexing_cache_key = "app_annotation_batch_import_{}".format(str(job_id)) cache_result = redis_client.get(indexing_cache_key) @@ -232,7 +215,6 @@ class AnnotationBatchImportStatusApi(Resource): if job_status == "error": indexing_error_msg_key = "app_annotation_batch_import_error_msg_{}".format(str(job_id)) error_msg = redis_client.get(indexing_error_msg_key).decode() - return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200 @@ -243,7 +225,6 @@ class AnnotationHitHistoryListApi(Resource): def get(self, app_id, annotation_id): if not current_user.is_editor: raise Forbidden() - page = request.args.get("page", default=1, type=int) limit = request.args.get("limit", default=20, type=int) app_id = str(app_id) diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 860166a61a..71bf2a0f0c 100644 --- a/api/controllers/console/app/app.py +++ 
b/api/controllers/console/app/app.py @@ -64,25 +64,20 @@ class AppListApi(Resource): parser.add_argument("name", type=str, location="args", required=False) parser.add_argument("tag_ids", type=uuid_list, location="args", required=False) parser.add_argument("is_created_by_me", type=inputs.boolean, location="args", required=False) - args = parser.parse_args() - # get app list app_service = AppService() app_pagination = app_service.get_paginate_apps(current_user.id, current_user.current_tenant_id, args) if not app_pagination: return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False} - if FeatureService.get_system_features().webapp_auth.enabled: app_ids = [str(app.id) for app in app_pagination.items] res = EnterpriseService.WebAppAuth.batch_get_app_access_mode_by_id(app_ids=app_ids) if len(res) != len(app_ids): raise BadRequest("Invalid app id in webapp auth") - for app in app_pagination.items: if str(app.id) in res: app.access_mode = res[str(app.id)].access_mode - return marshal(app_pagination, app_pagination_fields), 200 @setup_required @@ -100,17 +95,13 @@ class AppListApi(Resource): parser.add_argument("icon", type=str, location="json") parser.add_argument("icon_background", type=str, location="json") args = parser.parse_args() - # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - if "mode" not in args or args["mode"] is None: raise BadRequest("mode is required") - app_service = AppService() app = app_service.create_app(current_user.current_tenant_id, args, current_user) - return app, 201 @@ -124,13 +115,10 @@ class AppApi(Resource): def get(self, app_model): """Get app detail""" app_service = AppService() - app_model = app_service.get_app(app_model) - if FeatureService.get_system_features().webapp_auth.enabled: app_setting = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=str(app_model.id)) app_model.access_mode = app_setting.access_mode - return app_model 
@setup_required @@ -143,7 +131,6 @@ class AppApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - parser = reqparse.RequestParser() parser.add_argument("name", type=str, required=True, nullable=False, location="json") parser.add_argument("description", type=str, location="json") @@ -152,10 +139,8 @@ class AppApi(Resource): parser.add_argument("icon_background", type=str, location="json") parser.add_argument("use_icon_as_answer_icon", type=bool, location="json") args = parser.parse_args() - app_service = AppService() app_model = app_service.update_app(app_model, args) - return app_model @setup_required @@ -167,10 +152,8 @@ class AppApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - app_service = AppService() app_service.delete_app(app_model) - return {"result": "success"}, 204 @@ -185,7 +168,6 @@ class AppCopyApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - parser = reqparse.RequestParser() parser.add_argument("name", type=str, location="json") parser.add_argument("description", type=str, location="json") @@ -193,7 +175,6 @@ class AppCopyApi(Resource): parser.add_argument("icon", type=str, location="json") parser.add_argument("icon_background", type=str, location="json") args = parser.parse_args() - with Session(db.engine) as session: import_service = AppDslService(session) yaml_content = import_service.export_dsl(app_model=app_model, include_secret=True) @@ -209,10 +190,8 @@ class AppCopyApi(Resource): icon_background=args.get("icon_background"), ) session.commit() - stmt = select(App).where(App.id == result.app_id) app = session.scalar(stmt) - return app, 201 @@ -226,12 +205,10 @@ class AppExportApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor 
if not current_user.is_editor: raise Forbidden() - # Add include_secret params parser = reqparse.RequestParser() parser.add_argument("include_secret", type=inputs.boolean, default=False, location="args") args = parser.parse_args() - return {"data": AppDslService.export_dsl(app_model=app_model, include_secret=args["include_secret"])} @@ -245,14 +222,11 @@ class AppNameApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - parser = reqparse.RequestParser() parser.add_argument("name", type=str, required=True, location="json") args = parser.parse_args() - app_service = AppService() app_model = app_service.update_app_name(app_model, args.get("name")) - return app_model @@ -266,15 +240,12 @@ class AppIconApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - parser = reqparse.RequestParser() parser.add_argument("icon", type=str, location="json") parser.add_argument("icon_background", type=str, location="json") args = parser.parse_args() - app_service = AppService() app_model = app_service.update_app_icon(app_model, args.get("icon"), args.get("icon_background")) - return app_model @@ -288,14 +259,11 @@ class AppSiteStatus(Resource): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - parser = reqparse.RequestParser() parser.add_argument("enable_site", type=bool, required=True, location="json") args = parser.parse_args() - app_service = AppService() app_model = app_service.update_app_site_status(app_model, args.get("enable_site")) - return app_model @@ -309,14 +277,11 @@ class AppApiStatus(Resource): # The role of the current user in the ta table must be admin or owner if not current_user.is_admin_or_owner: raise Forbidden() - parser = reqparse.RequestParser() parser.add_argument("enable_api", type=bool, required=True, 
location="json") args = parser.parse_args() - app_service = AppService() app_model = app_service.update_app_api_status(app_model, args.get("enable_api")) - return app_model @@ -327,7 +292,6 @@ class AppTraceApi(Resource): def get(self, app_id): """Get app trace""" app_trace_config = OpsTraceManager.get_app_tracing_config(app_id=app_id) - return app_trace_config @setup_required @@ -341,13 +305,11 @@ class AppTraceApi(Resource): parser.add_argument("enabled", type=bool, required=True, location="json") parser.add_argument("tracing_provider", type=str, required=True, location="json") args = parser.parse_args() - OpsTraceManager.update_app_tracing_config( app_id=app_id, enabled=args["enabled"], tracing_provider=args["tracing_provider"], ) - return {"result": "success"} diff --git a/api/controllers/console/app/app_import.py b/api/controllers/console/app/app_import.py index 9ffb94e9f9..b6055adfa5 100644 --- a/api/controllers/console/app/app_import.py +++ b/api/controllers/console/app/app_import.py @@ -31,7 +31,6 @@ class AppImportApi(Resource): # Check user role first if not current_user.is_editor: raise Forbidden() - parser = reqparse.RequestParser() parser.add_argument("mode", type=str, required=True, location="json") parser.add_argument("yaml_content", type=str, location="json") @@ -43,7 +42,6 @@ class AppImportApi(Resource): parser.add_argument("icon_background", type=str, location="json") parser.add_argument("app_id", type=str, location="json") args = parser.parse_args() - # Create service with session with Session(db.engine) as session: import_service = AppDslService(session) @@ -83,7 +81,6 @@ class AppImportConfirmApi(Resource): # Check user role first if not current_user.is_editor: raise Forbidden() - # Create service with session with Session(db.engine) as session: import_service = AppDslService(session) @@ -91,7 +88,6 @@ class AppImportConfirmApi(Resource): account = cast(Account, current_user) result = import_service.confirm_import(import_id=import_id, 
account=account) session.commit() - # Return appropriate status code based on result if result.status == ImportStatus.FAILED.value: return result.model_dump(mode="json"), 400 @@ -107,9 +103,7 @@ class AppImportCheckDependenciesApi(Resource): def get(self, app_model: App): if not current_user.is_editor: raise Forbidden() - with Session(db.engine) as session: import_service = AppDslService(session) result = import_service.check_dependencies(app_model=app_model) - return result.model_dump(mode="json"), 200 diff --git a/api/controllers/console/app/audio.py b/api/controllers/console/app/audio.py index 665cf1aede..57ff34e5ba 100644 --- a/api/controllers/console/app/audio.py +++ b/api/controllers/console/app/audio.py @@ -39,14 +39,12 @@ class ChatMessageAudioApi(Resource): @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) def post(self, app_model): file = request.files["file"] - try: response = AudioService.transcript_asr( app_model=app_model, file=file, end_user=None, ) - return response except services.errors.app_model_config.AppModelConfigBrokenError: logging.exception("App model config broken.") @@ -87,11 +85,9 @@ class ChatMessageTextApi(Resource): parser.add_argument("voice", type=str, location="json") parser.add_argument("streaming", type=bool, location="json") args = parser.parse_args() - message_id = args.get("message_id", None) text = args.get("text", None) voice = args.get("voice", None) - response = AudioService.transcript_tts( app_model=app_model, text=text, voice=voice, message_id=message_id, is_draft=True ) @@ -132,12 +128,10 @@ class TextModesApi(Resource): parser = reqparse.RequestParser() parser.add_argument("language", type=str, required=True, location="args") args = parser.parse_args() - response = AudioService.transcript_tts_voices( tenant_id=app_model.tenant_id, language=args["language"], ) - return response except services.errors.audio.ProviderNotSupportTextToSpeechLanageServiceError: raise AppUnavailableError("Text to 
audio voices language parameter loss.") diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py index 732f5b799a..d6ce091f4b 100644 --- a/api/controllers/console/app/completion.py +++ b/api/controllers/console/app/completion.py @@ -48,17 +48,13 @@ class CompletionMessageApi(Resource): parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json") args = parser.parse_args() - streaming = args["response_mode"] != "blocking" args["auto_generate_name"] = False - account = flask_login.current_user - try: response = AppGenerateService.generate( app_model=app_model, user=account, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=streaming ) - return helper.compact_generate_response(response) except services.errors.conversation.ConversationNotExistsError: raise NotFound("Conversation Not Exists.") @@ -89,9 +85,7 @@ class CompletionMessageStopApi(Resource): @get_app_model(mode=AppMode.COMPLETION) def post(self, app_model, task_id): account = flask_login.current_user - AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id) - return {"result": "success"}, 200 @@ -111,17 +105,13 @@ class ChatMessageApi(Resource): parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json") args = parser.parse_args() - streaming = args["response_mode"] != "blocking" args["auto_generate_name"] = False - account = flask_login.current_user - try: response = AppGenerateService.generate( app_model=app_model, user=account, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=streaming ) - return helper.compact_generate_response(response) except services.errors.conversation.ConversationNotExistsError: raise NotFound("Conversation Not Exists.") @@ -154,9 +144,7 @@ class 
ChatMessageStopApi(Resource): @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) def post(self, app_model, task_id): account = flask_login.current_user - AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id) - return {"result": "success"}, 200 diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py index 70d6216497..2c8a58c453 100644 --- a/api/controllers/console/app/conversation.py +++ b/api/controllers/console/app/conversation.py @@ -44,9 +44,7 @@ class CompletionConversationApi(Resource): parser.add_argument("page", type=int_range(1, 99999), default=1, location="args") parser.add_argument("limit", type=int_range(1, 100), default=20, location="args") args = parser.parse_args() - query = db.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.mode == "completion") - if args["keyword"]: query = query.join(Message, Message.conversation_id == Conversation.id).filter( or_( @@ -54,29 +52,21 @@ class CompletionConversationApi(Resource): Message.answer.ilike("%{}%".format(args["keyword"])), ) ) - account = current_user timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc - if args["start"]: start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = start_datetime.replace(second=0) - start_datetime_timezone = timezone.localize(start_datetime) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) - query = query.where(Conversation.created_at >= start_datetime_utc) - if args["end"]: end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = end_datetime.replace(second=59) - end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) - query = query.where(Conversation.created_at < end_datetime_utc) - # FIXME, the type ignore in this file if args["annotation_status"] == "annotated": query = 
query.options(joinedload(Conversation.message_annotations)).join( # type: ignore @@ -88,11 +78,8 @@ class CompletionConversationApi(Resource): .group_by(Conversation.id) .having(func.count(MessageAnnotation.id) == 0) ) - query = query.order_by(Conversation.created_at.desc()) - conversations = db.paginate(query, page=args["page"], per_page=args["limit"], error_out=False) - return conversations @@ -106,7 +93,6 @@ class CompletionConversationDetailApi(Resource): if not current_user.is_editor: raise Forbidden() conversation_id = str(conversation_id) - return _get_conversation(app_model, conversation_id) @setup_required @@ -117,19 +103,15 @@ class CompletionConversationDetailApi(Resource): if not current_user.is_editor: raise Forbidden() conversation_id = str(conversation_id) - conversation = ( db.session.query(Conversation) .filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id) .first() ) - if not conversation: raise NotFound("Conversation Not Exists.") - conversation.is_deleted = True db.session.commit() - return {"result": "success"}, 204 @@ -161,7 +143,6 @@ class ChatConversationApi(Resource): location="args", ) args = parser.parse_args() - subquery = ( db.session.query( Conversation.id.label("conversation_id"), EndUser.session_id.label("from_end_user_session_id") @@ -169,9 +150,7 @@ class ChatConversationApi(Resource): .outerjoin(EndUser, Conversation.from_end_user_id == EndUser.id) .subquery() ) - query = db.select(Conversation).where(Conversation.app_id == app_model.id) - if args["keyword"]: keyword_filter = "%{}%".format(args["keyword"]) query = ( @@ -191,37 +170,29 @@ class ChatConversationApi(Resource): ) .group_by(Conversation.id) ) - account = current_user timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc - if args["start"]: start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = start_datetime.replace(second=0) - start_datetime_timezone = timezone.localize(start_datetime) 
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) - match args["sort_by"]: case "updated_at" | "-updated_at": query = query.where(Conversation.updated_at >= start_datetime_utc) case "created_at" | "-created_at" | _: query = query.where(Conversation.created_at >= start_datetime_utc) - if args["end"]: end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = end_datetime.replace(second=59) - end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) - match args["sort_by"]: case "updated_at" | "-updated_at": query = query.where(Conversation.updated_at <= end_datetime_utc) case "created_at" | "-created_at" | _: query = query.where(Conversation.created_at <= end_datetime_utc) - if args["annotation_status"] == "annotated": query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id @@ -232,7 +203,6 @@ class ChatConversationApi(Resource): .group_by(Conversation.id) .having(func.count(MessageAnnotation.id) == 0) ) - if args["message_count_gte"] and args["message_count_gte"] >= 1: query = ( query.options(joinedload(Conversation.messages)) # type: ignore @@ -240,10 +210,8 @@ class ChatConversationApi(Resource): .group_by(Conversation.id) .having(func.count(Message.id) >= args["message_count_gte"]) ) - if app_model.mode == AppMode.ADVANCED_CHAT.value: query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER.value) - match args["sort_by"]: case "created_at": query = query.order_by(Conversation.created_at.asc()) @@ -255,9 +223,7 @@ class ChatConversationApi(Resource): query = query.order_by(Conversation.updated_at.desc()) case _: query = query.order_by(Conversation.created_at.desc()) - conversations = db.paginate(query, page=args["page"], per_page=args["limit"], error_out=False) - return conversations @@ -271,7 +237,6 @@ class ChatConversationDetailApi(Resource): 
if not current_user.is_editor: raise Forbidden() conversation_id = str(conversation_id) - return _get_conversation(app_model, conversation_id) @setup_required @@ -282,19 +247,15 @@ class ChatConversationDetailApi(Resource): if not current_user.is_editor: raise Forbidden() conversation_id = str(conversation_id) - conversation = ( db.session.query(Conversation) .filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id) .first() ) - if not conversation: raise NotFound("Conversation Not Exists.") - conversation.is_deleted = True db.session.commit() - return {"result": "success"}, 204 @@ -310,13 +271,10 @@ def _get_conversation(app_model, conversation_id): .filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id) .first() ) - if not conversation: raise NotFound("Conversation Not Exists.") - if not conversation.read_at: conversation.read_at = datetime.now(UTC).replace(tzinfo=None) conversation.read_account_id = current_user.id db.session.commit() - return conversation diff --git a/api/controllers/console/app/conversation_variables.py b/api/controllers/console/app/conversation_variables.py index d49f433ba1..719ee4ab02 100644 --- a/api/controllers/console/app/conversation_variables.py +++ b/api/controllers/console/app/conversation_variables.py @@ -22,7 +22,6 @@ class ConversationVariablesApi(Resource): parser = reqparse.RequestParser() parser.add_argument("conversation_id", type=str, location="args") args = parser.parse_args() - stmt = ( select(ConversationVariable) .where(ConversationVariable.app_id == app_model.id) @@ -32,15 +31,12 @@ class ConversationVariablesApi(Resource): stmt = stmt.where(ConversationVariable.conversation_id == args["conversation_id"]) else: raise ValueError("conversation_id is required") - # NOTE: This is a temporary solution to avoid performance issues. 
page = 1 page_size = 100 stmt = stmt.limit(page_size).offset((page - 1) * page_size) - with Session(db.engine) as session: rows = session.scalars(stmt).all() - return { "page": page, "limit": page_size, diff --git a/api/controllers/console/app/generator.py b/api/controllers/console/app/generator.py index 790369c052..2d6ae9281d 100644 --- a/api/controllers/console/app/generator.py +++ b/api/controllers/console/app/generator.py @@ -27,10 +27,8 @@ class RuleGenerateApi(Resource): parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json") parser.add_argument("no_variable", type=bool, required=True, default=False, location="json") args = parser.parse_args() - account = current_user PROMPT_GENERATION_MAX_TOKENS = int(os.getenv("PROMPT_GENERATION_MAX_TOKENS", "512")) - try: rules = LLMGenerator.generate_rule_config( tenant_id=account.current_tenant_id, @@ -47,7 +45,6 @@ class RuleGenerateApi(Resource): raise ProviderModelCurrentlyNotSupportError() except InvokeError as e: raise CompletionRequestError(e.description) - return rules @@ -62,7 +59,6 @@ class RuleCodeGenerateApi(Resource): parser.add_argument("no_variable", type=bool, required=True, default=False, location="json") parser.add_argument("code_language", type=str, required=False, default="javascript", location="json") args = parser.parse_args() - account = current_user CODE_GENERATION_MAX_TOKENS = int(os.getenv("CODE_GENERATION_MAX_TOKENS", "1024")) try: @@ -81,7 +77,6 @@ class RuleCodeGenerateApi(Resource): raise ProviderModelCurrentlyNotSupportError() except InvokeError as e: raise CompletionRequestError(e.description) - return code_result @@ -94,7 +89,6 @@ class RuleStructuredOutputGenerateApi(Resource): parser.add_argument("instruction", type=str, required=True, nullable=False, location="json") parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json") args = parser.parse_args() - account = current_user try: structured_output = 
LLMGenerator.generate_structured_output( @@ -110,7 +104,6 @@ class RuleStructuredOutputGenerateApi(Resource): raise ProviderModelCurrentlyNotSupportError() except InvokeError as e: raise CompletionRequestError(e.description) - return structured_output diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py index b7a4c31a15..29ca8efb4e 100644 --- a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -52,26 +52,21 @@ class ChatMessageListApi(Resource): parser.add_argument("first_id", type=uuid_value, location="args") parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") args = parser.parse_args() - conversation = ( db.session.query(Conversation) .filter(Conversation.id == args["conversation_id"], Conversation.app_id == app_model.id) .first() ) - if not conversation: raise NotFound("Conversation Not Exists.") - if args["first_id"]: first_message = ( db.session.query(Message) .filter(Message.conversation_id == conversation.id, Message.id == args["first_id"]) .first() ) - if not first_message: raise NotFound("First message not found") - history_messages = ( db.session.query(Message) .filter( @@ -91,7 +86,6 @@ class ChatMessageListApi(Resource): .limit(args["limit"]) .all() ) - has_more = False if len(history_messages) == args["limit"]: current_page_first_message = history_messages[-1] @@ -104,12 +98,9 @@ class ChatMessageListApi(Resource): ) .count() ) - if rest_count > 0: has_more = True - history_messages = list(reversed(history_messages)) - return InfiniteScrollPagination(data=history_messages, limit=args["limit"], has_more=has_more) @@ -123,16 +114,11 @@ class MessageFeedbackApi(Resource): parser.add_argument("message_id", required=True, type=uuid_value, location="json") parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json") args = parser.parse_args() - message_id = str(args["message_id"]) - message = 
db.session.query(Message).filter(Message.id == message_id, Message.app_id == app_model.id).first() - if not message: raise NotFound("Message Not Exists.") - feedback = message.admin_feedback - if not args["rating"] and feedback: db.session.delete(feedback) elif args["rating"] and feedback: @@ -149,9 +135,7 @@ class MessageFeedbackApi(Resource): from_account_id=current_user.id, ) db.session.add(feedback) - db.session.commit() - return {"result": "success"} @@ -165,7 +149,6 @@ class MessageAnnotationApi(Resource): def post(self, app_model): if not current_user.is_editor: raise Forbidden() - parser = reqparse.RequestParser() parser.add_argument("message_id", required=False, type=uuid_value, location="json") parser.add_argument("question", required=True, type=str, location="json") @@ -173,7 +156,6 @@ class MessageAnnotationApi(Resource): parser.add_argument("annotation_reply", required=False, type=dict, location="json") args = parser.parse_args() annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_model.id) - return annotation @@ -184,7 +166,6 @@ class MessageAnnotationCountApi(Resource): @get_app_model def get(self, app_model): count = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app_model.id).count() - return {"count": count} @@ -195,7 +176,6 @@ class MessageSuggestedQuestionApi(Resource): @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) def get(self, app_model, message_id): message_id = str(message_id) - try: questions = MessageService.get_suggested_questions_after_answer( app_model=app_model, message_id=message_id, user=current_user, invoke_from=InvokeFrom.DEBUGGER @@ -217,7 +197,6 @@ class MessageSuggestedQuestionApi(Resource): except Exception: logging.exception("internal server error.") raise InternalServerError() - return {"data": questions} @@ -229,12 +208,9 @@ class MessageApi(Resource): @marshal_with(message_detail_fields) def get(self, app_model, message_id): message_id = 
str(message_id) - message = db.session.query(Message).filter(Message.id == message_id, Message.app_id == app_model.id).first() - if not message: raise NotFound("Message Not Exists.") - return message diff --git a/api/controllers/console/app/model_config.py b/api/controllers/console/app/model_config.py index f30e3e893c..dc3cb14eed 100644 --- a/api/controllers/console/app/model_config.py +++ b/api/controllers/console/app/model_config.py @@ -31,14 +31,12 @@ class ModelConfigResource(Resource): config=cast(dict, request.json), app_mode=AppMode.value_of(app_model.mode), ) - new_app_model_config = AppModelConfig( app_id=app_model.id, created_by=current_user.id, updated_by=current_user.id, ) new_app_model_config = new_app_model_config.from_model_config_dict(model_configuration) - if app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent: # get original app model config original_app_model_config = ( @@ -54,7 +52,6 @@ class ModelConfigResource(Resource): for tool in agent_mode.get("tools") or []: if not isinstance(tool, dict) or len(tool.keys()) <= 3: continue - agent_tool_entity = AgentToolEntity(**tool) # get tool try: @@ -72,7 +69,6 @@ class ModelConfigResource(Resource): ) except Exception: continue - # get decrypted parameters if agent_tool_entity.tool_parameters: parameters = manager.decrypt_tool_parameters(agent_tool_entity.tool_parameters or {}) @@ -80,17 +76,14 @@ class ModelConfigResource(Resource): else: parameters = {} masked_parameter = {} - key = f"{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}" masked_parameter_map[key] = masked_parameter parameter_map[key] = parameters tool_map[key] = tool_runtime - # encrypt agent tool parameters if it's secret-input agent_mode = new_app_model_config.agent_mode_dict for tool in agent_mode.get("tools") or []: agent_tool_entity = AgentToolEntity(**tool) - # get tool key = 
f"{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}" if key in tool_map: @@ -104,7 +97,6 @@ class ModelConfigResource(Resource): ) except Exception: continue - manager = ToolParameterConfigurationManager( tenant_id=current_user.current_tenant_id, tool_runtime=tool_runtime, @@ -113,34 +105,26 @@ class ModelConfigResource(Resource): identity_id=f"AGENT.{app_model.id}", ) manager.delete_tool_parameters_cache() - # override parameters if it equals to masked parameters if agent_tool_entity.tool_parameters: if key not in masked_parameter_map: continue - for masked_key, masked_value in masked_parameter_map[key].items(): if ( masked_key in agent_tool_entity.tool_parameters and agent_tool_entity.tool_parameters[masked_key] == masked_value ): agent_tool_entity.tool_parameters[masked_key] = parameter_map[key].get(masked_key) - # encrypt parameters if agent_tool_entity.tool_parameters: tool["tool_parameters"] = manager.encrypt_tool_parameters(agent_tool_entity.tool_parameters or {}) - # update app model config new_app_model_config.agent_mode = json.dumps(agent_mode) - db.session.add(new_app_model_config) db.session.flush() - app_model.app_model_config_id = new_app_model_config.id db.session.commit() - app_model_config_was_updated.send(app_model, app_model_config=new_app_model_config) - return {"result": "success"} diff --git a/api/controllers/console/app/ops_trace.py b/api/controllers/console/app/ops_trace.py index 978c02412c..9301382392 100644 --- a/api/controllers/console/app/ops_trace.py +++ b/api/controllers/console/app/ops_trace.py @@ -20,7 +20,6 @@ class TraceAppConfigApi(Resource): parser = reqparse.RequestParser() parser.add_argument("tracing_provider", type=str, required=True, location="args") args = parser.parse_args() - try: trace_config = OpsService.get_tracing_app_config(app_id=app_id, tracing_provider=args["tracing_provider"]) if not trace_config: @@ -38,7 +37,6 @@ class TraceAppConfigApi(Resource): 
parser.add_argument("tracing_provider", type=str, required=True, location="json") parser.add_argument("tracing_config", type=dict, required=True, location="json") args = parser.parse_args() - try: result = OpsService.create_tracing_app_config( app_id=app_id, tracing_provider=args["tracing_provider"], tracing_config=args["tracing_config"] @@ -60,7 +58,6 @@ class TraceAppConfigApi(Resource): parser.add_argument("tracing_provider", type=str, required=True, location="json") parser.add_argument("tracing_config", type=dict, required=True, location="json") args = parser.parse_args() - try: result = OpsService.update_tracing_app_config( app_id=app_id, tracing_provider=args["tracing_provider"], tracing_config=args["tracing_config"] @@ -79,7 +76,6 @@ class TraceAppConfigApi(Resource): parser = reqparse.RequestParser() parser.add_argument("tracing_provider", type=str, required=True, location="args") args = parser.parse_args() - try: result = OpsService.delete_tracing_app_config(app_id=app_id, tracing_provider=args["tracing_provider"]) if not result: diff --git a/api/controllers/console/app/site.py b/api/controllers/console/app/site.py index 3c3a359eeb..09f3423488 100644 --- a/api/controllers/console/app/site.py +++ b/api/controllers/console/app/site.py @@ -45,15 +45,12 @@ class AppSite(Resource): @marshal_with(app_site_fields) def post(self, app_model): args = parse_app_site_args() - # The role of the current user in the ta table must be editor, admin, or owner if not current_user.is_editor: raise Forbidden() - site = db.session.query(Site).filter(Site.app_id == app_model.id).first() if not site: raise NotFound - for attr_name in [ "title", "icon_type", @@ -75,11 +72,9 @@ class AppSite(Resource): value = args.get(attr_name) if value is not None: setattr(site, attr_name, value) - site.updated_by = current_user.id site.updated_at = datetime.now(UTC).replace(tzinfo=None) db.session.commit() - return site @@ -93,17 +88,13 @@ class AppSiteAccessTokenReset(Resource): # The role of 
the current user in the ta table must be admin or owner if not current_user.is_admin_or_owner: raise Forbidden() - site = db.session.query(Site).filter(Site.app_id == app_model.id).first() - if not site: raise NotFound - site.code = Site.generate_code(16) site.updated_by = current_user.id site.updated_at = datetime.now(UTC).replace(tzinfo=None) db.session.commit() - return site diff --git a/api/controllers/console/app/statistic.py b/api/controllers/console/app/statistic.py index 86aed77412..649e5ee04e 100644 --- a/api/controllers/console/app/statistic.py +++ b/api/controllers/console/app/statistic.py @@ -22,12 +22,10 @@ class DailyMessageStatistic(Resource): @get_app_model def get(self, app_model): account = current_user - parser = reqparse.RequestParser() parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() - sql_query = """SELECT DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, COUNT(*) AS message_count @@ -36,39 +34,28 @@ FROM WHERE app_id = :app_id""" arg_dict = {"tz": account.timezone, "app_id": app_model.id} - timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc - if args["start"]: start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = start_datetime.replace(second=0) - start_datetime_timezone = timezone.localize(start_datetime) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) - sql_query += " AND created_at >= :start" arg_dict["start"] = start_datetime_utc - if args["end"]: end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = end_datetime.replace(second=0) - end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) - sql_query += " AND created_at < :end" arg_dict["end"] = end_datetime_utc - sql_query += " GROUP BY date ORDER BY date" 
- response_data = [] - with db.engine.begin() as conn: rs = conn.execute(db.text(sql_query), arg_dict) for i in rs: response_data.append({"date": str(i.date), "message_count": i.message_count}) - return jsonify({"data": response_data}) @@ -79,12 +66,10 @@ class DailyConversationStatistic(Resource): @get_app_model def get(self, app_model): account = current_user - parser = reqparse.RequestParser() parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() - sql_query = """SELECT DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, COUNT(DISTINCT messages.conversation_id) AS conversation_count @@ -93,39 +78,28 @@ FROM WHERE app_id = :app_id""" arg_dict = {"tz": account.timezone, "app_id": app_model.id} - timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc - if args["start"]: start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = start_datetime.replace(second=0) - start_datetime_timezone = timezone.localize(start_datetime) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) - sql_query += " AND created_at >= :start" arg_dict["start"] = start_datetime_utc - if args["end"]: end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = end_datetime.replace(second=0) - end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) - sql_query += " AND created_at < :end" arg_dict["end"] = end_datetime_utc - sql_query += " GROUP BY date ORDER BY date" - response_data = [] - with db.engine.begin() as conn: rs = conn.execute(db.text(sql_query), arg_dict) for i in rs: response_data.append({"date": str(i.date), "conversation_count": i.conversation_count}) - return jsonify({"data": response_data}) @@ -136,12 +110,10 @@ class DailyTerminalsStatistic(Resource): @get_app_model def 
get(self, app_model): account = current_user - parser = reqparse.RequestParser() parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() - sql_query = """SELECT DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, COUNT(DISTINCT messages.from_end_user_id) AS terminal_count @@ -150,39 +122,28 @@ FROM WHERE app_id = :app_id""" arg_dict = {"tz": account.timezone, "app_id": app_model.id} - timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc - if args["start"]: start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = start_datetime.replace(second=0) - start_datetime_timezone = timezone.localize(start_datetime) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) - sql_query += " AND created_at >= :start" arg_dict["start"] = start_datetime_utc - if args["end"]: end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = end_datetime.replace(second=0) - end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) - sql_query += " AND created_at < :end" arg_dict["end"] = end_datetime_utc - sql_query += " GROUP BY date ORDER BY date" - response_data = [] - with db.engine.begin() as conn: rs = conn.execute(db.text(sql_query), arg_dict) for i in rs: response_data.append({"date": str(i.date), "terminal_count": i.terminal_count}) - return jsonify({"data": response_data}) @@ -193,12 +154,10 @@ class DailyTokenCostStatistic(Resource): @get_app_model def get(self, app_model): account = current_user - parser = reqparse.RequestParser() parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() - sql_query = """SELECT DATE(DATE_TRUNC('day', 
created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, (SUM(messages.message_tokens) + SUM(messages.answer_tokens)) AS token_count, @@ -208,41 +167,30 @@ FROM WHERE app_id = :app_id""" arg_dict = {"tz": account.timezone, "app_id": app_model.id} - timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc - if args["start"]: start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = start_datetime.replace(second=0) - start_datetime_timezone = timezone.localize(start_datetime) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) - sql_query += " AND created_at >= :start" arg_dict["start"] = start_datetime_utc - if args["end"]: end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = end_datetime.replace(second=0) - end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) - sql_query += " AND created_at < :end" arg_dict["end"] = end_datetime_utc - sql_query += " GROUP BY date ORDER BY date" - response_data = [] - with db.engine.begin() as conn: rs = conn.execute(db.text(sql_query), arg_dict) for i in rs: response_data.append( {"date": str(i.date), "token_count": i.token_count, "total_price": i.total_price, "currency": "USD"} ) - return jsonify({"data": response_data}) @@ -253,12 +201,10 @@ class AverageSessionInteractionStatistic(Resource): @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) def get(self, app_model): account = current_user - parser = reqparse.RequestParser() parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() - sql_query = """SELECT DATE(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, AVG(subquery.message_count) AS interactions @@ -275,30 +221,22 @@ FROM WHERE c.app_id = :app_id""" arg_dict = {"tz": 
account.timezone, "app_id": app_model.id} - timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc - if args["start"]: start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = start_datetime.replace(second=0) - start_datetime_timezone = timezone.localize(start_datetime) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) - sql_query += " AND c.created_at >= :start" arg_dict["start"] = start_datetime_utc - if args["end"]: end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = end_datetime.replace(second=0) - end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) - sql_query += " AND c.created_at < :end" arg_dict["end"] = end_datetime_utc - sql_query += """ GROUP BY m.conversation_id ) subquery @@ -309,16 +247,13 @@ GROUP BY date ORDER BY date""" - response_data = [] - with db.engine.begin() as conn: rs = conn.execute(db.text(sql_query), arg_dict) for i in rs: response_data.append( {"date": str(i.date), "interactions": float(i.interactions.quantize(Decimal("0.01")))} ) - return jsonify({"data": response_data}) @@ -329,12 +264,10 @@ class UserSatisfactionRateStatistic(Resource): @get_app_model def get(self, app_model): account = current_user - parser = reqparse.RequestParser() parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() - sql_query = """SELECT DATE(DATE_TRUNC('day', m.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, COUNT(m.id) AS message_count, @@ -347,34 +280,24 @@ LEFT JOIN WHERE m.app_id = :app_id""" arg_dict = {"tz": account.timezone, "app_id": app_model.id} - timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc - if args["start"]: start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = 
start_datetime.replace(second=0) - start_datetime_timezone = timezone.localize(start_datetime) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) - sql_query += " AND m.created_at >= :start" arg_dict["start"] = start_datetime_utc - if args["end"]: end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = end_datetime.replace(second=0) - end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) - sql_query += " AND m.created_at < :end" arg_dict["end"] = end_datetime_utc - sql_query += " GROUP BY date ORDER BY date" - response_data = [] - with db.engine.begin() as conn: rs = conn.execute(db.text(sql_query), arg_dict) for i in rs: @@ -384,7 +307,6 @@ WHERE "rate": round((i.feedback_count * 1000 / i.message_count) if i.message_count > 0 else 0, 2), } ) - return jsonify({"data": response_data}) @@ -395,12 +317,10 @@ class AverageResponseTimeStatistic(Resource): @get_app_model(mode=AppMode.COMPLETION) def get(self, app_model): account = current_user - parser = reqparse.RequestParser() parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() - sql_query = """SELECT DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, AVG(provider_response_latency) AS latency @@ -409,39 +329,28 @@ FROM WHERE app_id = :app_id""" arg_dict = {"tz": account.timezone, "app_id": app_model.id} - timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc - if args["start"]: start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = start_datetime.replace(second=0) - start_datetime_timezone = timezone.localize(start_datetime) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) - sql_query += " AND created_at >= :start" arg_dict["start"] = start_datetime_utc - if args["end"]: 
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = end_datetime.replace(second=0) - end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) - sql_query += " AND created_at < :end" arg_dict["end"] = end_datetime_utc - sql_query += " GROUP BY date ORDER BY date" - response_data = [] - with db.engine.begin() as conn: rs = conn.execute(db.text(sql_query), arg_dict) for i in rs: response_data.append({"date": str(i.date), "latency": round(i.latency * 1000, 4)}) - return jsonify({"data": response_data}) @@ -452,12 +361,10 @@ class TokensPerSecondStatistic(Resource): @get_app_model def get(self, app_model): account = current_user - parser = reqparse.RequestParser() parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() - sql_query = """SELECT DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, CASE @@ -469,39 +376,28 @@ FROM WHERE app_id = :app_id""" arg_dict = {"tz": account.timezone, "app_id": app_model.id} - timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc - if args["start"]: start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = start_datetime.replace(second=0) - start_datetime_timezone = timezone.localize(start_datetime) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) - sql_query += " AND created_at >= :start" arg_dict["start"] = start_datetime_utc - if args["end"]: end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = end_datetime.replace(second=0) - end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) - sql_query += " AND created_at < :end" arg_dict["end"] = end_datetime_utc - sql_query += " GROUP BY date ORDER BY date" - response_data = [] - 
with db.engine.begin() as conn: rs = conn.execute(db.text(sql_query), arg_dict) for i in rs: response_data.append({"date": str(i.date), "tps": round(i.tokens_per_second, 4)}) - return jsonify({"data": response_data}) diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index a9f088a276..dbe9dfcb67 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -47,7 +47,6 @@ logger = logging.getLogger(__name__) # of concerns and make the code more maintainable. def _parse_file(workflow: Workflow, files: list[dict] | None = None) -> Sequence[File]: files = files or [] - file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False) file_objs: Sequence[File] = [] if file_extra_config is None: @@ -73,14 +72,11 @@ class DraftWorkflowApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - # fetch draft workflow by app_model workflow_service = WorkflowService() workflow = workflow_service.get_draft_workflow(app_model=app_model) - if not workflow: raise DraftWorkflowNotExist() - # return workflow, if not found, return None (initiate graph by frontend) return workflow @@ -95,9 +91,7 @@ class DraftWorkflowApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - content_type = request.headers.get("Content-Type", "") - if "application/json" in content_type: parser = reqparse.RequestParser() parser.add_argument("graph", type=dict, required=True, nullable=False, location="json") @@ -111,10 +105,8 @@ class DraftWorkflowApi(Resource): data = json.loads(request.data.decode("utf-8")) if "graph" not in data or "features" not in data: raise ValueError("graph or features not found in data") - if not isinstance(data.get("graph"), dict) or not isinstance(data.get("features"), dict): raise 
ValueError("graph or features is not a dict") - args = { "graph": data.get("graph"), "features": data.get("features"), @@ -126,12 +118,9 @@ class DraftWorkflowApi(Resource): return {"message": "Invalid JSON data"}, 400 else: abort(415) - if not isinstance(current_user, Account): raise Forbidden() - workflow_service = WorkflowService() - try: environment_variables_list = args.get("environment_variables") or [] environment_variables = [ @@ -152,7 +141,6 @@ class DraftWorkflowApi(Resource): ) except WorkflowHashNotEqualError: raise DraftWorkflowNotSync() - return { "result": "success", "hash": workflow.unique_hash, @@ -172,24 +160,19 @@ class AdvancedChatDraftWorkflowRunApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - if not isinstance(current_user, Account): raise Forbidden() - parser = reqparse.RequestParser() parser.add_argument("inputs", type=dict, location="json") parser.add_argument("query", type=str, required=True, location="json", default="") parser.add_argument("files", type=list, location="json") parser.add_argument("conversation_id", type=uuid_value, location="json") parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json") - args = parser.parse_args() - try: response = AppGenerateService.generate( app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=True ) - return helper.compact_generate_response(response) except services.errors.conversation.ConversationNotExistsError: raise NotFound("Conversation Not Exists.") @@ -216,19 +199,15 @@ class AdvancedChatDraftRunIterationNodeApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - if not isinstance(current_user, Account): raise Forbidden() - parser = reqparse.RequestParser() parser.add_argument("inputs", type=dict, location="json") args = 
parser.parse_args() - try: response = AppGenerateService.generate_single_iteration( app_model=app_model, user=current_user, node_id=node_id, args=args, streaming=True ) - return helper.compact_generate_response(response) except services.errors.conversation.ConversationNotExistsError: raise NotFound("Conversation Not Exists.") @@ -253,19 +232,15 @@ class WorkflowDraftRunIterationNodeApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - if not isinstance(current_user, Account): raise Forbidden() - parser = reqparse.RequestParser() parser.add_argument("inputs", type=dict, location="json") args = parser.parse_args() - try: response = AppGenerateService.generate_single_iteration( app_model=app_model, user=current_user, node_id=node_id, args=args, streaming=True ) - return helper.compact_generate_response(response) except services.errors.conversation.ConversationNotExistsError: raise NotFound("Conversation Not Exists.") @@ -290,19 +265,15 @@ class AdvancedChatDraftRunLoopNodeApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - if not isinstance(current_user, Account): raise Forbidden() - parser = reqparse.RequestParser() parser.add_argument("inputs", type=dict, location="json") args = parser.parse_args() - try: response = AppGenerateService.generate_single_loop( app_model=app_model, user=current_user, node_id=node_id, args=args, streaming=True ) - return helper.compact_generate_response(response) except services.errors.conversation.ConversationNotExistsError: raise NotFound("Conversation Not Exists.") @@ -327,19 +298,15 @@ class WorkflowDraftRunLoopNodeApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - if not isinstance(current_user, Account): raise Forbidden() - parser = reqparse.RequestParser() 
parser.add_argument("inputs", type=dict, location="json") args = parser.parse_args() - try: response = AppGenerateService.generate_single_loop( app_model=app_model, user=current_user, node_id=node_id, args=args, streaming=True ) - return helper.compact_generate_response(response) except services.errors.conversation.ConversationNotExistsError: raise NotFound("Conversation Not Exists.") @@ -364,15 +331,12 @@ class DraftWorkflowRunApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - if not isinstance(current_user, Account): raise Forbidden() - parser = reqparse.RequestParser() parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") parser.add_argument("files", type=list, required=False, location="json") args = parser.parse_args() - try: response = AppGenerateService.generate( app_model=app_model, @@ -381,7 +345,6 @@ class DraftWorkflowRunApi(Resource): invoke_from=InvokeFrom.DEBUGGER, streaming=True, ) - return helper.compact_generate_response(response) except InvokeRateLimitError as ex: raise InvokeRateLimitHttpError(ex.description) @@ -399,9 +362,7 @@ class WorkflowTaskStopApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id) - return {"result": "success"} @@ -418,20 +379,16 @@ class DraftWorkflowNodeRunApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - if not isinstance(current_user, Account): raise Forbidden() - parser = reqparse.RequestParser() parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") parser.add_argument("query", type=str, required=False, location="json", default="") parser.add_argument("files", type=list, location="json", default=[]) 
args = parser.parse_args() - user_inputs = args.get("inputs") if user_inputs is None: raise ValueError("missing inputs") - workflow_srv = WorkflowService() # fetch draft workflow by app_model draft_workflow = workflow_srv.get_draft_workflow(app_model=app_model) @@ -439,7 +396,6 @@ class DraftWorkflowNodeRunApi(Resource): raise ValueError("Workflow not initialized") files = _parse_file(draft_workflow, args.get("files")) workflow_service = WorkflowService() - workflow_node_execution = workflow_service.run_draft_workflow_node( app_model=app_model, draft_workflow=draft_workflow, @@ -449,7 +405,6 @@ class DraftWorkflowNodeRunApi(Resource): query=args.get("query", ""), files=files, ) - return workflow_node_execution @@ -466,11 +421,9 @@ class PublishedWorkflowApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - # fetch published workflow by app_model workflow_service = WorkflowService() workflow = workflow_service.get_published_workflow(app_model=app_model) - # return workflow, if not found, return None return workflow @@ -485,21 +438,17 @@ class PublishedWorkflowApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - if not isinstance(current_user, Account): raise Forbidden() - parser = reqparse.RequestParser() parser.add_argument("marked_name", type=str, required=False, default="", location="json") parser.add_argument("marked_comment", type=str, required=False, default="", location="json") args = parser.parse_args() - # Validate name and comment length if args.marked_name and len(args.marked_name) > 20: raise ValueError("Marked name cannot exceed 20 characters") if args.marked_comment and len(args.marked_comment) > 100: raise ValueError("Marked comment cannot exceed 100 characters") - workflow_service = WorkflowService() with Session(db.engine) as session: workflow = 
workflow_service.publish_workflow( @@ -509,14 +458,10 @@ class PublishedWorkflowApi(Resource): marked_name=args.marked_name or "", marked_comment=args.marked_comment or "", ) - app_model.workflow_id = workflow.id db.session.commit() - workflow_created_at = TimestampField().format(workflow.created_at) - session.commit() - return { "result": "success", "created_at": workflow_created_at, @@ -535,7 +480,6 @@ class DefaultBlockConfigsApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - # Get default block configs workflow_service = WorkflowService() return workflow_service.get_default_block_configs() @@ -553,23 +497,18 @@ class DefaultBlockConfigApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - if not isinstance(current_user, Account): raise Forbidden() - parser = reqparse.RequestParser() parser.add_argument("q", type=str, location="args") args = parser.parse_args() - q = args.get("q") - filters = None if q: try: filters = json.loads(args.get("q", "")) except json.JSONDecodeError: raise ValueError("Invalid filters") - # Get default block configs workflow_service = WorkflowService() return workflow_service.get_default_block_config(node_type=block_type, filters=filters) @@ -589,10 +528,8 @@ class ConvertToWorkflowApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - if not isinstance(current_user, Account): raise Forbidden() - if request.data: parser = reqparse.RequestParser() parser.add_argument("name", type=str, required=False, nullable=True, location="json") @@ -602,11 +539,9 @@ class ConvertToWorkflowApi(Resource): args = parser.parse_args() else: args = {} - # convert to workflow mode workflow_service = WorkflowService() new_app_model = workflow_service.convert_to_workflow(app_model=app_model, 
account=current_user, args=args) - # return app id return { "new_app_id": new_app_model.id, @@ -638,7 +573,6 @@ class PublishedAllWorkflowApi(Resource): """ if not current_user.is_editor: raise Forbidden() - parser = reqparse.RequestParser() parser.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args") parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args") @@ -649,12 +583,10 @@ class PublishedAllWorkflowApi(Resource): limit = int(args.get("limit", 10)) user_id = args.get("user_id") named_only = args.get("named_only", False) - if user_id: if user_id != current_user.id: raise Forbidden() user_id = cast(str, user_id) - workflow_service = WorkflowService() with Session(db.engine) as session: workflows, has_more = workflow_service.get_all_published_workflow( @@ -665,7 +597,6 @@ class PublishedAllWorkflowApi(Resource): user_id=user_id, named_only=named_only, ) - return { "items": workflows, "page": page, @@ -687,34 +618,27 @@ class WorkflowByIdApi(Resource): # Check permission if not current_user.is_editor: raise Forbidden() - if not isinstance(current_user, Account): raise Forbidden() - parser = reqparse.RequestParser() parser.add_argument("marked_name", type=str, required=False, location="json") parser.add_argument("marked_comment", type=str, required=False, location="json") args = parser.parse_args() - # Validate name and comment length if args.marked_name and len(args.marked_name) > 20: raise ValueError("Marked name cannot exceed 20 characters") if args.marked_comment and len(args.marked_comment) > 100: raise ValueError("Marked comment cannot exceed 100 characters") args = parser.parse_args() - # Prepare update data update_data = {} if args.get("marked_name") is not None: update_data["marked_name"] = args["marked_name"] if args.get("marked_comment") is not None: update_data["marked_comment"] = args["marked_comment"] - if not update_data: return {"message": "No valid fields to 
update"}, 400 - workflow_service = WorkflowService() - # Create a session and manage the transaction with Session(db.engine, expire_on_commit=False) as session: workflow = workflow_service.update_workflow( @@ -724,13 +648,10 @@ class WorkflowByIdApi(Resource): account_id=current_user.id, data=update_data, ) - if not workflow: raise NotFound("Workflow not found") - # Commit the transaction in the controller session.commit() - return workflow @setup_required @@ -744,12 +665,9 @@ class WorkflowByIdApi(Resource): # Check permission if not current_user.is_editor: raise Forbidden() - if not isinstance(current_user, Account): raise Forbidden() - workflow_service = WorkflowService() - # Create a session and manage the transaction with Session(db.engine) as session: try: @@ -764,7 +682,6 @@ class WorkflowByIdApi(Resource): abort(400, description=str(e)) except ValueError as e: raise NotFound(str(e)) - return None, 204 diff --git a/api/controllers/console/app/workflow_app_log.py b/api/controllers/console/app/workflow_app_log.py index 310146a5e7..7cb23b4170 100644 --- a/api/controllers/console/app/workflow_app_log.py +++ b/api/controllers/console/app/workflow_app_log.py @@ -51,14 +51,11 @@ class WorkflowAppLogApi(Resource): parser.add_argument("page", type=int_range(1, 99999), default=1, location="args") parser.add_argument("limit", type=int_range(1, 100), default=20, location="args") args = parser.parse_args() - args.status = WorkflowExecutionStatus(args.status) if args.status else None if args.created_at__before: args.created_at__before = isoparse(args.created_at__before) - if args.created_at__after: args.created_at__after = isoparse(args.created_at__after) - # get paginate workflow app logs workflow_app_service = WorkflowAppService() with Session(db.engine) as session: @@ -74,7 +71,6 @@ class WorkflowAppLogApi(Resource): created_by_end_user_session_id=args.created_by_end_user_session_id, created_by_account=args.created_by_account, ) - return workflow_app_log_pagination 
diff --git a/api/controllers/console/app/workflow_draft_variable.py b/api/controllers/console/app/workflow_draft_variable.py index 00d6fa3cbf..babb8939dd 100644 --- a/api/controllers/console/app/workflow_draft_variable.py +++ b/api/controllers/console/app/workflow_draft_variable.py @@ -78,12 +78,10 @@ _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS = { "edited": fields.Boolean(attribute=lambda model: model.edited), "visible": fields.Boolean, } - _WORKFLOW_DRAFT_VARIABLE_FIELDS = dict( _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS, value=fields.Raw(attribute=_serialize_var_value), ) - _WORKFLOW_DRAFT_ENV_VARIABLE_FIELDS = { "id": fields.String, "type": fields.String(attribute=lambda _: "env"), @@ -94,7 +92,6 @@ _WORKFLOW_DRAFT_ENV_VARIABLE_FIELDS = { "edited": fields.Boolean(attribute=lambda model: model.edited), "visible": fields.Boolean, } - _WORKFLOW_DRAFT_ENV_VARIABLE_LIST_FIELDS = { "items": fields.List(fields.Nested(_WORKFLOW_DRAFT_ENV_VARIABLE_FIELDS)), } @@ -108,7 +105,6 @@ _WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS = { "items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS), attribute=_get_items), "total": fields.Raw(), } - _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS = { "items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_FIELDS), attribute=_get_items), } @@ -116,9 +112,7 @@ _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS = { def _api_prerequisite(f): """Common prerequisites for all draft workflow variable APIs. - It ensures the following conditions are satisfied: - - Dify has been property setup. - The request user has logged in and initialized. - The requested app is a workflow or a chat flow. 
@@ -146,13 +140,11 @@ class WorkflowVariableCollectionApi(Resource): """ parser = _create_pagination_parser() args = parser.parse_args() - # fetch draft workflow by app_model workflow_service = WorkflowService() workflow_exist = workflow_service.is_workflow_exist(app_model=app_model) if not workflow_exist: raise DraftWorkflowNotExist() - # fetch draft workflow by app_model with Session(bind=db.engine, expire_on_commit=False) as session: draft_var_srv = WorkflowDraftVariableService( @@ -163,7 +155,6 @@ class WorkflowVariableCollectionApi(Resource): page=args.page, limit=args.limit, ) - return workflow_vars @_api_prerequisite @@ -187,7 +178,6 @@ def validate_node_id(node_id: str) -> NoReturn | None: # we mitigate the risk that user of the API depending on the implementation detail of the API. # # ref: [Hyrum's Law](https://www.hyrumslaw.com/) - raise InvalidArgumentError( f"invalid node_id, please use correspond api for conversation and system variables, node_id={node_id}", ) @@ -204,7 +194,6 @@ class NodeVariableCollectionApi(Resource): session=session, ) node_vars = draft_var_srv.list_node_variables(app_model.id, node_id) - return node_vars @_api_prerequisite @@ -256,28 +245,23 @@ class VariableApi(Resource): # "url": "http://127.0.0.1:5001/files/1602650a-4fe4-423c-85a2-af76c083e3c4/file-preview?timestamp=1750041099&nonce=...&sign=...=", # "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4" # } - parser = reqparse.RequestParser() parser.add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json") # Parse 'value' field as-is to maintain its original data structure parser.add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json") - draft_var_srv = WorkflowDraftVariableService( session=db.session(), ) args = parser.parse_args(strict=True) - variable = draft_var_srv.get_variable(variable_id=variable_id) if variable is None: raise NotFoundError(description=f"variable not found, 
id={variable_id}") if variable.app_id != app_model.id: raise NotFoundError(description=f"variable not found, id={variable_id}") - new_name = args.get(self._PATCH_NAME_FIELD, None) raw_value = args.get(self._PATCH_VALUE_FIELD, None) if new_name is None and raw_value is None: return variable - new_value = None if raw_value is not None: if variable.value_type == SegmentType.FILE: @@ -316,7 +300,6 @@ class VariableResetApi(Resource): draft_var_srv = WorkflowDraftVariableService( session=db.session(), ) - workflow_srv = WorkflowService() draft_workflow = workflow_srv.get_draft_workflow(app_model) if draft_workflow is None: @@ -328,7 +311,6 @@ class VariableResetApi(Resource): raise NotFoundError(description=f"variable not found, id={variable_id}") if variable.app_id != app_model.id: raise NotFoundError(description=f"variable not found, id={variable_id}") - resetted = draft_var_srv.reset_variable(draft_workflow, variable) db.session.commit() if resetted is None: @@ -385,7 +367,6 @@ class EnvironmentVariableCollectionApi(Resource): workflow = workflow_service.get_draft_workflow(app_model=app_model) if workflow is None: raise DraftWorkflowNotExist() - env_vars = workflow.environment_variables env_vars_list = [] for v in env_vars: @@ -404,7 +385,6 @@ class EnvironmentVariableCollectionApi(Resource): "editable": True, } ) - return {"items": env_vars_list} @@ -415,7 +395,6 @@ api.add_resource( api.add_resource(NodeVariableCollectionApi, "/apps//workflows/draft/nodes//variables") api.add_resource(VariableApi, "/apps//workflows/draft/variables/") api.add_resource(VariableResetApi, "/apps//workflows/draft/variables//reset") - api.add_resource(ConversationVariableCollectionApi, "/apps//workflows/draft/conversation-variables") api.add_resource(SystemVariableCollectionApi, "/apps//workflows/draft/system-variables") api.add_resource(EnvironmentVariableCollectionApi, "/apps//workflows/draft/environment-variables") diff --git a/api/controllers/console/app/workflow_run.py 
b/api/controllers/console/app/workflow_run.py index 9099700213..f4cebf2d12 100644 --- a/api/controllers/console/app/workflow_run.py +++ b/api/controllers/console/app/workflow_run.py @@ -33,10 +33,8 @@ class AdvancedChatAppWorkflowRunListApi(Resource): parser.add_argument("last_id", type=uuid_value, location="args") parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") args = parser.parse_args() - workflow_run_service = WorkflowRunService() result = workflow_run_service.get_paginate_advanced_chat_workflow_runs(app_model=app_model, args=args) - return result @@ -54,10 +52,8 @@ class WorkflowRunListApi(Resource): parser.add_argument("last_id", type=uuid_value, location="args") parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") args = parser.parse_args() - workflow_run_service = WorkflowRunService() result = workflow_run_service.get_paginate_workflow_runs(app_model=app_model, args=args) - return result @@ -72,10 +68,8 @@ class WorkflowRunDetailApi(Resource): Get workflow run detail """ run_id = str(run_id) - workflow_run_service = WorkflowRunService() workflow_run = workflow_run_service.get_workflow_run(app_model=app_model, run_id=run_id) - return workflow_run @@ -90,7 +84,6 @@ class WorkflowRunNodeExecutionListApi(Resource): Get workflow run node execution list """ run_id = str(run_id) - workflow_run_service = WorkflowRunService() user = cast("Account | EndUser", current_user) node_executions = workflow_run_service.get_workflow_run_node_executions( @@ -98,7 +91,6 @@ class WorkflowRunNodeExecutionListApi(Resource): run_id=run_id, user=user, ) - return {"data": node_executions} diff --git a/api/controllers/console/app/workflow_statistic.py b/api/controllers/console/app/workflow_statistic.py index 6c7c73707b..fe8fb7fc13 100644 --- a/api/controllers/console/app/workflow_statistic.py +++ b/api/controllers/console/app/workflow_statistic.py @@ -23,12 +23,10 @@ class 
WorkflowDailyRunsStatistic(Resource): @get_app_model def get(self, app_model): account = current_user - parser = reqparse.RequestParser() parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() - sql_query = """SELECT DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, COUNT(id) AS runs @@ -42,39 +40,28 @@ WHERE "app_id": app_model.id, "triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value, } - timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc - if args["start"]: start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = start_datetime.replace(second=0) - start_datetime_timezone = timezone.localize(start_datetime) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) - sql_query += " AND created_at >= :start" arg_dict["start"] = start_datetime_utc - if args["end"]: end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = end_datetime.replace(second=0) - end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) - sql_query += " AND created_at < :end" arg_dict["end"] = end_datetime_utc - sql_query += " GROUP BY date ORDER BY date" - response_data = [] - with db.engine.begin() as conn: rs = conn.execute(db.text(sql_query), arg_dict) for i in rs: response_data.append({"date": str(i.date), "runs": i.runs}) - return jsonify({"data": response_data}) @@ -85,12 +72,10 @@ class WorkflowDailyTerminalsStatistic(Resource): @get_app_model def get(self, app_model): account = current_user - parser = reqparse.RequestParser() parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() - sql_query = """SELECT DATE(DATE_TRUNC('day', 
created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, COUNT(DISTINCT workflow_runs.created_by) AS terminal_count @@ -104,39 +89,28 @@ WHERE "app_id": app_model.id, "triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value, } - timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc - if args["start"]: start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = start_datetime.replace(second=0) - start_datetime_timezone = timezone.localize(start_datetime) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) - sql_query += " AND created_at >= :start" arg_dict["start"] = start_datetime_utc - if args["end"]: end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = end_datetime.replace(second=0) - end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) - sql_query += " AND created_at < :end" arg_dict["end"] = end_datetime_utc - sql_query += " GROUP BY date ORDER BY date" - response_data = [] - with db.engine.begin() as conn: rs = conn.execute(db.text(sql_query), arg_dict) for i in rs: response_data.append({"date": str(i.date), "terminal_count": i.terminal_count}) - return jsonify({"data": response_data}) @@ -147,12 +121,10 @@ class WorkflowDailyTokenCostStatistic(Resource): @get_app_model def get(self, app_model): account = current_user - parser = reqparse.RequestParser() parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() - sql_query = """SELECT DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, SUM(workflow_runs.total_tokens) AS token_count @@ -166,34 +138,24 @@ WHERE "app_id": app_model.id, "triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value, } - timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc - if args["start"]: start_datetime = 
datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = start_datetime.replace(second=0) - start_datetime_timezone = timezone.localize(start_datetime) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) - sql_query += " AND created_at >= :start" arg_dict["start"] = start_datetime_utc - if args["end"]: end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = end_datetime.replace(second=0) - end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) - sql_query += " AND created_at < :end" arg_dict["end"] = end_datetime_utc - sql_query += " GROUP BY date ORDER BY date" - response_data = [] - with db.engine.begin() as conn: rs = conn.execute(db.text(sql_query), arg_dict) for i in rs: @@ -203,7 +165,6 @@ WHERE "token_count": i.token_count, } ) - return jsonify({"data": response_data}) @@ -214,12 +175,10 @@ class WorkflowAverageAppInteractionStatistic(Resource): @get_app_model(mode=[AppMode.WORKFLOW]) def get(self, app_model): account = current_user - parser = reqparse.RequestParser() parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() - sql_query = """SELECT AVG(sub.interactions) AS interactions, sub.date @@ -246,43 +205,33 @@ GROUP BY "app_id": app_model.id, "triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value, } - timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc - if args["start"]: start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = start_datetime.replace(second=0) - start_datetime_timezone = timezone.localize(start_datetime) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) - sql_query = sql_query.replace("{{start}}", " AND c.created_at >= :start") arg_dict["start"] = start_datetime_utc else: sql_query = sql_query.replace("{{start}}", "") - 
if args["end"]: end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = end_datetime.replace(second=0) - end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) - sql_query = sql_query.replace("{{end}}", " AND c.created_at < :end") arg_dict["end"] = end_datetime_utc else: sql_query = sql_query.replace("{{end}}", "") - response_data = [] - with db.engine.begin() as conn: rs = conn.execute(db.text(sql_query), arg_dict) for i in rs: response_data.append( {"date": str(i.date), "interactions": float(i.interactions.quantize(Decimal("0.01")))} ) - return jsonify({"data": response_data}) diff --git a/api/controllers/console/app/wraps.py b/api/controllers/console/app/wraps.py index 03b60610aa..ec5952ab01 100644 --- a/api/controllers/console/app/wraps.py +++ b/api/controllers/console/app/wraps.py @@ -23,33 +23,24 @@ def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[ def decorated_view(*args, **kwargs): if not kwargs.get("app_id"): raise ValueError("missing app_id in path parameters") - app_id = kwargs.get("app_id") app_id = str(app_id) - del kwargs["app_id"] - app_model = _load_app_model(app_id) - if not app_model: raise AppNotFoundError() - app_mode = AppMode.value_of(app_model.mode) if app_mode == AppMode.CHANNEL: raise AppNotFoundError() - if mode is not None: if isinstance(mode, list): modes = mode else: modes = [mode] - if app_mode not in modes: mode_values = {m.value for m in modes} raise AppNotFoundError(f"App mode is not in the supported list: {mode_values}") - kwargs["app_model"] = app_model - return view_func(*args, **kwargs) return decorated_view diff --git a/api/controllers/console/auth/activate.py b/api/controllers/console/auth/activate.py index 1795563ff7..db1b45ead7 100644 --- a/api/controllers/console/auth/activate.py +++ b/api/controllers/console/auth/activate.py @@ -19,11 +19,9 @@ class ActivateCheckApi(Resource): 
parser.add_argument("email", type=email, required=False, nullable=True, location="args") parser.add_argument("token", type=str, required=True, nullable=False, location="args") args = parser.parse_args() - workspaceId = args["workspace_id"] reg_email = args["email"] token = args["token"] - invitation = RegisterService.get_invitation_if_token_valid(workspaceId, reg_email, token) if invitation: data = invitation.get("data", {}) @@ -51,25 +49,19 @@ class ActivateApi(Resource): ) parser.add_argument("timezone", type=timezone, required=True, nullable=False, location="json") args = parser.parse_args() - invitation = RegisterService.get_invitation_if_token_valid(args["workspace_id"], args["email"], args["token"]) if invitation is None: raise AlreadyActivateError() - RegisterService.revoke_token(args["workspace_id"], args["email"], args["token"]) - account = invitation["account"] account.name = args["name"] - account.interface_language = args["interface_language"] account.timezone = args["timezone"] account.interface_theme = "light" account.status = AccountStatus.ACTIVE.value account.initialized_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) db.session.commit() - token_pair = AccountService.login(account, ip_address=extract_remote_ip(request)) - return {"result": "success", "data": token_pair.model_dump()} diff --git a/api/controllers/console/auth/data_source_bearer_auth.py b/api/controllers/console/auth/data_source_bearer_auth.py index b8c3c8f012..26c554b417 100644 --- a/api/controllers/console/auth/data_source_bearer_auth.py +++ b/api/controllers/console/auth/data_source_bearer_auth.py @@ -62,9 +62,7 @@ class ApiKeyAuthDataSourceBindingDelete(Resource): # The role of the current user in the table must be admin or owner if not current_user.is_admin_or_owner: raise Forbidden() - ApiKeyAuthService.delete_provider_auth(current_user.current_tenant_id, binding_id) - return {"result": "success"}, 204 diff --git a/api/controllers/console/auth/data_source_oauth.py 
b/api/controllers/console/auth/data_source_oauth.py index 4c9697cc32..7c0661e913 100644 --- a/api/controllers/console/auth/data_source_oauth.py +++ b/api/controllers/console/auth/data_source_oauth.py @@ -21,7 +21,6 @@ def get_oauth_providers(): client_secret=dify_config.NOTION_CLIENT_SECRET or "", redirect_uri=dify_config.CONSOLE_API_URL + "/console/api/oauth/data-source/callback/notion", ) - OAUTH_PROVIDERS = {"notion": notion_oauth} return OAUTH_PROVIDERS @@ -56,11 +55,9 @@ class OAuthDataSourceCallback(Resource): return {"error": "Invalid provider"}, 400 if "code" in request.args: code = request.args.get("code") - return redirect(f"{dify_config.CONSOLE_WEB_URL}?type=notion&code={code}") elif "error" in request.args: error = request.args.get("error") - return redirect(f"{dify_config.CONSOLE_WEB_URL}?type=notion&error={error}") else: return redirect(f"{dify_config.CONSOLE_WEB_URL}?type=notion&error=Access denied") @@ -84,7 +81,6 @@ class OAuthDataSourceBinding(Resource): f"An error occurred during the OAuthCallback process with {provider}: {e.response.text}" ) return {"error": "OAuth data source process failed"}, 400 - return {"result": "success"}, 200 @@ -105,7 +101,6 @@ class OAuthDataSourceSync(Resource): except requests.exceptions.HTTPError as e: logging.exception(f"An error occurred during the OAuthCallback process with {provider}: {e.response.text}") return {"error": "OAuth data source process failed"}, 400 - return {"result": "success"}, 200 diff --git a/api/controllers/console/auth/forgot_password.py b/api/controllers/console/auth/forgot_password.py index 3bbe3177fc..3b9b1e015b 100644 --- a/api/controllers/console/auth/forgot_password.py +++ b/api/controllers/console/auth/forgot_password.py @@ -36,16 +36,13 @@ class ForgotPasswordSendEmailApi(Resource): parser.add_argument("email", type=email, required=True, location="json") parser.add_argument("language", type=str, required=False, location="json") args = parser.parse_args() - ip_address = 
extract_remote_ip(request) if AccountService.is_email_send_ip_limit(ip_address): raise EmailSendIpLimitError() - if args["language"] is not None and args["language"] == "zh-Hans": language = "zh-Hans" else: language = "en-US" - with Session(db.engine) as session: account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none() token = None @@ -57,7 +54,6 @@ class ForgotPasswordSendEmailApi(Resource): raise AccountNotFound() else: token = AccountService.send_reset_password_email(account=account, email=args["email"], language=language) - return {"result": "success", "data": token} @@ -70,32 +66,24 @@ class ForgotPasswordCheckApi(Resource): parser.add_argument("code", type=str, required=True, location="json") parser.add_argument("token", type=str, required=True, nullable=False, location="json") args = parser.parse_args() - user_email = args["email"] - is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(args["email"]) if is_forgot_password_error_rate_limit: raise EmailPasswordResetLimitError() - token_data = AccountService.get_reset_password_data(args["token"]) if token_data is None: raise InvalidTokenError() - if user_email != token_data.get("email"): raise InvalidEmailError() - if args["code"] != token_data.get("code"): AccountService.add_forgot_password_error_rate_limit(args["email"]) raise EmailCodeError() - # Verified, revoke the first token AccountService.revoke_reset_password_token(args["token"]) - # Refresh token data by generating a new token _, new_token = AccountService.generate_reset_password_token( user_email, code=args["code"], additional_data={"phase": "reset"} ) - AccountService.reset_forgot_password_error_rate_limit(args["email"]) return {"is_valid": True, "email": token_data.get("email"), "token": new_token} @@ -109,11 +97,9 @@ class ForgotPasswordResetApi(Resource): parser.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json") 
parser.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json") args = parser.parse_args() - # Validate passwords match if args["new_password"] != args["password_confirm"]: raise PasswordMismatchError() - # Validate token and get reset data reset_data = AccountService.get_reset_password_data(args["token"]) if not reset_data: @@ -121,24 +107,18 @@ class ForgotPasswordResetApi(Resource): # Must use token in reset phase if reset_data.get("phase", "") != "reset": raise InvalidTokenError() - # Revoke token to prevent reuse AccountService.revoke_reset_password_token(args["token"]) - # Generate secure salt and hash password salt = secrets.token_bytes(16) password_hashed = hash_password(args["new_password"], salt) - email = reset_data.get("email", "") - with Session(db.engine) as session: account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none() - if account: self._update_existing_account(account, password_hashed, salt, session) else: self._create_new_account(email, args["password_confirm"]) - return {"result": "success"} def _update_existing_account(self, account, password_hashed, salt, session): @@ -146,7 +126,6 @@ class ForgotPasswordResetApi(Resource): account.password = base64.b64encode(password_hashed).decode() account.password_salt = base64.b64encode(salt).decode() session.commit() - # Create workspace if needed if ( not TenantService.get_join_tenants(account) diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py index 5f2a24322d..1af22ce5a8 100644 --- a/api/controllers/console/auth/login.py +++ b/api/controllers/console/auth/login.py @@ -49,23 +49,18 @@ class LoginApi(Resource): parser.add_argument("invite_token", type=str, required=False, default=None, location="json") parser.add_argument("language", type=str, required=False, default="en-US", location="json") args = parser.parse_args() - if dify_config.BILLING_ENABLED and 
BillingService.is_email_in_freeze(args["email"]): raise AccountInFreezeError() - is_login_error_rate_limit = AccountService.is_login_error_rate_limit(args["email"]) if is_login_error_rate_limit: raise EmailPasswordLoginLimitError() - invitation = args["invite_token"] if invitation: invitation = RegisterService.get_invitation_if_token_valid(None, args["email"], invitation) - if args["language"] is not None and args["language"] == "zh-Hans": language = "zh-Hans" else: language = "en-US" - try: if invitation: data = invitation.get("data", {}) @@ -90,7 +85,6 @@ class LoginApi(Resource): tenants = TenantService.get_join_tenants(account) if len(tenants) == 0: system_features = FeatureService.get_system_features() - if system_features.is_allow_create_workspace and not system_features.license.workspaces.is_available(): raise WorkspacesLimitExceeded() else: @@ -98,7 +92,6 @@ class LoginApi(Resource): "result": "fail", "data": "workspace not found, please contact system admin to invite you to join in a workspace", } - token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request)) AccountService.reset_login_error_rate_limit(args["email"]) return {"result": "success", "data": token_pair.model_dump()} @@ -123,7 +116,6 @@ class ResetPasswordSendEmailApi(Resource): parser.add_argument("email", type=email, required=True, location="json") parser.add_argument("language", type=str, required=False, location="json") args = parser.parse_args() - if args["language"] is not None and args["language"] == "zh-Hans": language = "zh-Hans" else: @@ -139,7 +131,6 @@ class ResetPasswordSendEmailApi(Resource): raise AccountNotFound() else: token = AccountService.send_reset_password_email(account=account, language=language) - return {"result": "success", "data": token} @@ -150,11 +141,9 @@ class EmailCodeLoginSendEmailApi(Resource): parser.add_argument("email", type=email, required=True, location="json") parser.add_argument("language", type=str, required=False, 
location="json") args = parser.parse_args() - ip_address = extract_remote_ip(request) if AccountService.is_email_send_ip_limit(ip_address): raise EmailSendIpLimitError() - if args["language"] is not None and args["language"] == "zh-Hans": language = "zh-Hans" else: @@ -163,7 +152,6 @@ class EmailCodeLoginSendEmailApi(Resource): account = AccountService.get_user_through_email(args["email"]) except AccountRegisterError as are: raise AccountInFreezeError() - if account is None: if FeatureService.get_system_features().is_allow_register: token = AccountService.send_email_code_login_email(email=args["email"], language=language) @@ -171,7 +159,6 @@ class EmailCodeLoginSendEmailApi(Resource): raise AccountNotFound() else: token = AccountService.send_email_code_login_email(account=account, language=language) - return {"result": "success", "data": token} @@ -183,19 +170,14 @@ class EmailCodeLoginApi(Resource): parser.add_argument("code", type=str, required=True, location="json") parser.add_argument("token", type=str, required=True, location="json") args = parser.parse_args() - user_email = args["email"] - token_data = AccountService.get_email_code_login_data(args["token"]) if token_data is None: raise InvalidTokenError() - if token_data["email"] != args["email"]: raise InvalidEmailError() - if token_data["code"] != args["code"]: raise EmailCodeError() - AccountService.revoke_email_code_login_token(args["token"]) try: account = AccountService.get_user_through_email(user_email) @@ -214,7 +196,6 @@ class EmailCodeLoginApi(Resource): TenantService.create_tenant_member(new_tenant, account, role="owner") account.current_tenant = new_tenant tenant_was_created.send(new_tenant) - if account is None: try: account = AccountService.create_account_and_tenant( @@ -236,7 +217,6 @@ class RefreshTokenApi(Resource): parser = reqparse.RequestParser() parser.add_argument("refresh_token", type=str, required=True, location="json") args = parser.parse_args() - try: new_token_pair = 
AccountService.refresh_token(args["refresh_token"]) return {"result": "success", "data": new_token_pair.model_dump()} diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py index 395367c9e2..1d2b1ea350 100644 --- a/api/controllers/console/auth/oauth.py +++ b/api/controllers/console/auth/oauth.py @@ -43,7 +43,6 @@ def get_oauth_providers(): client_secret=dify_config.GOOGLE_CLIENT_SECRET, redirect_uri=dify_config.CONSOLE_API_URL + "/console/api/oauth/authorize/google", ) - OAUTH_PROVIDERS = {"github": github_oauth, "google": google_oauth} return OAUTH_PROVIDERS @@ -56,7 +55,6 @@ class OAuthLogin(Resource): oauth_provider = OAUTH_PROVIDERS.get(provider) if not oauth_provider: return {"error": "Invalid provider"}, 400 - auth_url = oauth_provider.get_authorization_url(invite_token=invite_token) return redirect(auth_url) @@ -68,13 +66,11 @@ class OAuthCallback(Resource): oauth_provider = OAUTH_PROVIDERS.get(provider) if not oauth_provider: return {"error": "Invalid provider"}, 400 - code = request.args.get("code") state = request.args.get("state") invite_token = None if state: invite_token = state - try: token = oauth_provider.get_access_token(code) user_info = oauth_provider.get_user_info(token) @@ -82,16 +78,13 @@ class OAuthCallback(Resource): error_text = e.response.text if e.response else str(e) logging.exception(f"An error occurred during the OAuth process with {provider}: {error_text}") return {"error": "OAuth process failed"}, 400 - if invite_token and RegisterService.is_valid_invite_token(invite_token): invitation = RegisterService._get_invitation_by_token(token=invite_token) if invitation: invitation_email = invitation.get("email", None) if invitation_email != user_info.email: return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Invalid invitation token.") - return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin/invite-settings?invite_token={invite_token}") - try: account = _generate_account(provider, user_info) 
except AccountNotFoundError: @@ -103,16 +96,13 @@ class OAuthCallback(Resource): ) except AccountRegisterError as e: return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message={e.description}") - # Check account status if account.status == AccountStatus.BANNED.value: return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Account is banned.") - if account.status == AccountStatus.PENDING.value: account.status = AccountStatus.ACTIVE.value account.initialized_at = datetime.now(UTC).replace(tzinfo=None) db.session.commit() - try: TenantService.create_owner_tenant_if_not_exist(account) except Unauthorized: @@ -122,12 +112,10 @@ class OAuthCallback(Resource): f"{dify_config.CONSOLE_WEB_URL}/signin" "?message=Workspace not found, please contact system admin to invite you to join in a workspace." ) - token_pair = AccountService.login( account=account, ip_address=extract_remote_ip(request), ) - return redirect( f"{dify_config.CONSOLE_WEB_URL}?access_token={token_pair.access_token}&refresh_token={token_pair.refresh_token}" ) @@ -135,18 +123,15 @@ class OAuthCallback(Resource): def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Optional[Account]: account: Optional[Account] = Account.get_by_openid(provider, user_info.id) - if not account: with Session(db.engine) as session: account = session.execute(select(Account).filter_by(email=user_info.email)).scalar_one_or_none() - return account def _generate_account(provider: str, user_info: OAuthUserInfo): # Get account by openid or email. 
account = _get_account_by_openid_or_email(provider, user_info) - if account: tenants = TenantService.get_join_tenants(account) if not tenants: @@ -157,7 +142,6 @@ def _generate_account(provider: str, user_info: OAuthUserInfo): TenantService.create_tenant_member(new_tenant, account, role="owner") account.current_tenant = new_tenant tenant_was_created.send(new_tenant) - if not account: if not FeatureService.get_system_features().is_allow_register: raise AccountNotFoundError() @@ -165,7 +149,6 @@ def _generate_account(provider: str, user_info: OAuthUserInfo): account = RegisterService.register( email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider ) - # Set interface language preferred_lang = request.accept_languages.best_match(languages) if preferred_lang and preferred_lang in languages: @@ -174,10 +157,8 @@ def _generate_account(provider: str, user_info: OAuthUserInfo): interface_language = languages[0] account.interface_language = interface_language db.session.commit() - # Link account AccountService.link_account_integrate(provider, user_info.id, account) - return account diff --git a/api/controllers/console/billing/billing.py b/api/controllers/console/billing/billing.py index 4b0c82ae6c..7fb568ff44 100644 --- a/api/controllers/console/billing/billing.py +++ b/api/controllers/console/billing/billing.py @@ -17,9 +17,7 @@ class Subscription(Resource): parser.add_argument("plan", type=str, required=True, location="args", choices=["professional", "team"]) parser.add_argument("interval", type=str, required=True, location="args", choices=["month", "year"]) args = parser.parse_args() - BillingService.is_tenant_owner_or_admin(current_user) - return BillingService.get_subscription( args["plan"], args["interval"], current_user.email, current_user.current_tenant_id ) diff --git a/api/controllers/console/billing/compliance.py b/api/controllers/console/billing/compliance.py index 9679632ac7..6af40489a1 100644 --- 
a/api/controllers/console/billing/compliance.py +++ b/api/controllers/console/billing/compliance.py @@ -19,10 +19,8 @@ class ComplianceApi(Resource): parser = reqparse.RequestParser() parser.add_argument("doc_name", type=str, required=True, location="args") args = parser.parse_args() - ip_address = extract_remote_ip(request) device_info = request.headers.get("User-Agent", "Unknown device") - return BillingService.get_compliance_download_link( doc_name=args.doc_name, account_id=current_user.id, diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py index 7b0d9373cf..77e1e7d7a4 100644 --- a/api/controllers/console/datasets/data_source.py +++ b/api/controllers/console/datasets/data_source.py @@ -36,11 +36,9 @@ class DataSourceApi(Resource): ) .all() ) - base_url = request.url_root.rstrip("/") data_source_oauth_base_path = "/console/api/oauth/data-source" providers = ["notion"] - integrate_data = [] for provider in providers: # existing_integrate = next((ai for ai in data_source_integrates if ai.provider == provider), None) @@ -121,7 +119,6 @@ class DataSourceNotionListApi(Resource): raise NotFound("Dataset not found.") if dataset.data_source_type != "notion_import": raise ValueError("Dataset is not notion type.") - documents = session.scalars( select(Document).filter_by( dataset_id=dataset_id, @@ -182,7 +179,6 @@ class DataSourceNotionApi(Resource): ).scalar_one_or_none() if not data_source_binding: raise NotFound("Data source binding not found.") - extractor = NotionExtractor( notion_workspace_id=workspace_id, notion_obj_id=page_id, @@ -190,7 +186,6 @@ class DataSourceNotionApi(Resource): notion_access_token=data_source_binding.access_token, tenant_id=current_user.current_tenant_id, ) - text_docs = extractor.extract() return {"content": "\n".join([doc.page_content for doc in text_docs])}, 200 @@ -244,7 +239,6 @@ class DataSourceNotionDatasetSyncApi(Resource): dataset = DatasetService.get_dataset(dataset_id_str) 
if dataset is None: raise NotFound("Dataset not found.") - documents = DocumentService.get_document_by_dataset_id(dataset_id_str) for document in documents: document_indexing_sync_task.delay(dataset_id_str, document.id) @@ -261,7 +255,6 @@ class DataSourceNotionDocumentSyncApi(Resource): dataset = DatasetService.get_dataset(dataset_id_str) if dataset is None: raise NotFound("Dataset not found.") - document = DocumentService.get_document(dataset_id_str, document_id_str) if document is None: raise NotFound("Document not found.") diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index 1611214cb3..03ab37f044 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -65,17 +65,13 @@ class DatasetListApi(Resource): datasets, total = DatasetService.get_datasets( page, limit, current_user.current_tenant_id, current_user, search, tag_ids, include_all ) - # check embedding setting provider_manager = ProviderManager() configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id) - embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True) - model_names = [] for embedding_model in embedding_models: model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}") - data = marshal(datasets, dataset_detail_fields) for item in data: # convert embedding_model_provider to plugin standard format @@ -88,13 +84,11 @@ class DatasetListApi(Resource): item["embedding_available"] = False else: item["embedding_available"] = True - if item.get("permission") == "partial_members": part_users_list = DatasetPermissionService.get_dataset_partial_member_list(item["id"]) item.update({"partial_member_list": part_users_list}) else: item.update({"partial_member_list": []}) - response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page} return response, 200 @@ 
-147,11 +141,9 @@ class DatasetListApi(Resource): required=False, ) args = parser.parse_args() - # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator if not current_user.is_dataset_editor: raise Forbidden() - try: dataset = DatasetService.create_empty_dataset( tenant_id=current_user.current_tenant_id, @@ -166,7 +158,6 @@ class DatasetListApi(Resource): ) except services.errors.dataset.DatasetNameDuplicateError: raise DatasetNameDuplicateError() - return marshal(dataset, dataset_detail_fields), 201 @@ -191,17 +182,13 @@ class DatasetApi(Resource): if data.get("permission") == "partial_members": part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str) data.update({"partial_member_list": part_users_list}) - # check embedding setting provider_manager = ProviderManager() configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id) - embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True) - model_names = [] for embedding_model in embedding_models: model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}") - if data["indexing_technique"] == "high_quality": item_model = f"{data['embedding_model']}:{data['embedding_model_provider']}" if item_model in model_names: @@ -210,11 +197,9 @@ class DatasetApi(Resource): data["embedding_available"] = False else: data["embedding_available"] = True - if data.get("permission") == "partial_members": part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str) data.update({"partial_member_list": part_users_list}) - return data, 200 @setup_required @@ -226,7 +211,6 @@ class DatasetApi(Resource): dataset = DatasetService.get_dataset(dataset_id_str) if dataset is None: raise NotFound("Dataset not found.") - parser = reqparse.RequestParser() parser.add_argument( "name", @@ -256,7 +240,6 @@ class DatasetApi(Resource): ) 
parser.add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.") parser.add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.") - parser.add_argument( "external_retrieval_model", type=dict, @@ -265,7 +248,6 @@ class DatasetApi(Resource): location="json", help="Invalid external retrieval model.", ) - parser.add_argument( "external_knowledge_id", type=str, @@ -274,7 +256,6 @@ class DatasetApi(Resource): location="json", help="Invalid external knowledge id.", ) - parser.add_argument( "external_knowledge_api_id", type=str, @@ -285,7 +266,6 @@ class DatasetApi(Resource): ) args = parser.parse_args() data = request.get_json() - # check embedding model setting if ( data.get("indexing_technique") == "high_quality" @@ -295,20 +275,15 @@ class DatasetApi(Resource): DatasetService.check_embedding_model_setting( dataset.tenant_id, data.get("embedding_model_provider"), data.get("embedding_model") ) - # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator DatasetPermissionService.check_permission( current_user, dataset, data.get("permission"), data.get("partial_member_list") ) - dataset = DatasetService.update_dataset(dataset_id_str, args, current_user) - if dataset is None: raise NotFound("Dataset not found.") - result_data = marshal(dataset, dataset_detail_fields) tenant_id = current_user.current_tenant_id - if data.get("partial_member_list") and data.get("permission") == "partial_members": DatasetPermissionService.update_partial_member_list( tenant_id, dataset_id_str, data.get("partial_member_list") @@ -319,10 +294,8 @@ class DatasetApi(Resource): or data.get("permission") == DatasetPermissionEnum.ALL_TEAM ): DatasetPermissionService.clear_partial_member_list(dataset_id_str) - partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str) result_data.update({"partial_member_list": partial_member_list}) - return result_data, 200 
@setup_required @@ -331,11 +304,9 @@ class DatasetApi(Resource): @cloud_edition_billing_rate_limit_check("knowledge") def delete(self, dataset_id): dataset_id_str = str(dataset_id) - # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor or current_user.is_dataset_operator: raise Forbidden() - try: if DatasetService.delete_dataset(dataset_id_str, current_user): DatasetPermissionService.clear_partial_member_list(dataset_id_str) @@ -352,7 +323,6 @@ class DatasetUseCheckApi(Resource): @account_initialization_required def get(self, dataset_id): dataset_id_str = str(dataset_id) - dataset_is_using = DatasetService.dataset_use_check(dataset_id_str) return {"is_using": dataset_is_using}, 200 @@ -366,17 +336,13 @@ class DatasetQueryApi(Resource): dataset = DatasetService.get_dataset(dataset_id_str) if dataset is None: raise NotFound("Dataset not found.") - try: DatasetService.check_dataset_permission(dataset, current_user) except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) - page = request.args.get("page", default=1, type=int) limit = request.args.get("limit", default=20, type=int) - dataset_queries, total = DatasetService.get_dataset_queries(dataset_id=dataset.id, page=page, per_page=limit) - response = { "data": marshal(dataset_queries, dataset_query_detail_fields), "has_more": len(dataset_queries) == limit, @@ -419,10 +385,8 @@ class DatasetIndexingEstimateApi(Resource): .filter(UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id.in_(file_ids)) .all() ) - if file_details is None: raise NotFound("File not found.") - if file_details: for file_detail in file_details: extract_setting = ExtractSetting( @@ -482,7 +446,6 @@ class DatasetIndexingEstimateApi(Resource): raise ProviderNotInitializeError(ex.description) except Exception as e: raise IndexingEstimateError(str(e)) - return response.model_dump(), 200 @@ -496,20 +459,16 @@ class DatasetRelatedAppListApi(Resource): dataset 
= DatasetService.get_dataset(dataset_id_str) if dataset is None: raise NotFound("Dataset not found.") - try: DatasetService.check_dataset_permission(dataset, current_user) except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) - app_dataset_joins = DatasetService.get_related_apps(dataset.id) - related_apps = [] for app_dataset_join in app_dataset_joins: app_model = app_dataset_join.app if app_model: related_apps.append(app_model) - return {"data": related_apps, "total": len(related_apps)}, 200 @@ -585,20 +544,17 @@ class DatasetApiKeyApi(Resource): # The role of the current user in the ta table must be admin or owner if not current_user.is_admin_or_owner: raise Forbidden() - current_key_count = ( db.session.query(ApiToken) .filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id) .count() ) - if current_key_count >= self.max_keys: flask_restful.abort( 400, message=f"Cannot create more than {self.max_keys} API keys for this resource type.", code="max_keys_exceeded", ) - key = ApiToken.generate_api_key(self.token_prefix, 24) api_token = ApiToken() api_token.tenant_id = current_user.current_tenant_id @@ -617,11 +573,9 @@ class DatasetApiDeleteApi(Resource): @account_initialization_required def delete(self, api_key_id): api_key_id = str(api_key_id) - # The role of the current user in the ta table must be admin or owner if not current_user.is_admin_or_owner: raise Forbidden() - key = ( db.session.query(ApiToken) .filter( @@ -631,13 +585,10 @@ class DatasetApiDeleteApi(Resource): ) .first() ) - if key is None: flask_restful.abort(404, message="API key not found") - db.session.query(ApiToken).filter(ApiToken.id == api_key_id).delete() db.session.commit() - return {"result": "success"}, 204 @@ -757,7 +708,6 @@ class DatasetErrorDocs(Resource): if dataset is None: raise NotFound("Dataset not found.") results = DocumentService.get_error_documents_by_dataset_id(dataset_id_str) - return {"data": [marshal(item, 
document_status_fields) for item in results], "total": len(results)}, 200 @@ -774,9 +724,7 @@ class DatasetPermissionUserListApi(Resource): DatasetService.check_dataset_permission(dataset, current_user) except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) - partial_members_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str) - return { "data": partial_members_list, }, 200 diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index b2fcf3ce7b..ed50a97df4 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -60,37 +60,28 @@ class DocumentResource(Resource): dataset = DatasetService.get_dataset(dataset_id) if not dataset: raise NotFound("Dataset not found.") - try: DatasetService.check_dataset_permission(dataset, current_user) except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) - document = DocumentService.get_document(dataset_id, document_id) - if not document: raise NotFound("Document not found.") - if document.tenant_id != current_user.current_tenant_id: raise Forbidden("No permission.") - return document def get_batch_documents(self, dataset_id: str, batch: str) -> list[Document]: dataset = DatasetService.get_dataset(dataset_id) if not dataset: raise NotFound("Dataset not found.") - try: DatasetService.check_dataset_permission(dataset, current_user) except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) - documents = DocumentService.get_batch_documents(dataset_id, batch) - if not documents: raise NotFound("Documents not found.") - return documents @@ -100,9 +91,7 @@ class GetProcessRuleApi(Resource): @account_initialization_required def get(self): req_data = request.args - document_id = req_data.get("document_id") - # get default rules mode = DocumentService.DEFAULT_RULES["mode"] rules = 
DocumentService.DEFAULT_RULES["rules"] @@ -110,17 +99,13 @@ class GetProcessRuleApi(Resource): if document_id: # get the latest process rule document = db.get_or_404(Document, document_id) - dataset = DatasetService.get_dataset(document.dataset_id) - if not dataset: raise NotFound("Dataset not found.") - try: DatasetService.check_dataset_permission(dataset, current_user) except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) - # get the latest process rule dataset_process_rule = ( db.session.query(DatasetProcessRule) @@ -132,7 +117,6 @@ class GetProcessRuleApi(Resource): if dataset_process_rule: mode = dataset_process_rule.mode rules = dataset_process_rule.rules_dict - return {"mode": mode, "rules": rules, "limits": limits} @@ -166,31 +150,25 @@ class DatasetDocumentListApi(Resource): dataset = DatasetService.get_dataset(dataset_id) if not dataset: raise NotFound("Dataset not found.") - try: DatasetService.check_dataset_permission(dataset, current_user) except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) - query = select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=current_user.current_tenant_id) - if search: search = f"%{search}%" query = query.filter(Document.name.like(search)) - if sort.startswith("-"): sort_logic = desc sort = sort[1:] else: sort_logic = asc - if sort == "hit_count": sub_query = ( db.select(DocumentSegment.document_id, db.func.sum(DocumentSegment.hit_count).label("total_hit_count")) .group_by(DocumentSegment.document_id) .subquery() ) - query = query.outerjoin(sub_query, sub_query.c.document_id == Document.id).order_by( sort_logic(db.func.coalesce(sub_query.c.total_hit_count, 0)), sort_logic(Document.position), @@ -205,7 +183,6 @@ class DatasetDocumentListApi(Resource): desc(Document.created_at), desc(Document.position), ) - paginated_documents = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False) documents = paginated_documents.items if fetch: 
@@ -236,7 +213,6 @@ class DatasetDocumentListApi(Resource): "total": paginated_documents.total, "page": page, } - return response @setup_required @@ -247,21 +223,16 @@ class DatasetDocumentListApi(Resource): @cloud_edition_billing_rate_limit_check("knowledge") def post(self, dataset_id): dataset_id = str(dataset_id) - dataset = DatasetService.get_dataset(dataset_id) - if not dataset: raise NotFound("Dataset not found.") - # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_dataset_editor: raise Forbidden() - try: DatasetService.check_dataset_permission(dataset, current_user) except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) - parser = reqparse.RequestParser() parser.add_argument( "indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json" @@ -279,24 +250,19 @@ class DatasetDocumentListApi(Resource): ) args = parser.parse_args() knowledge_config = KnowledgeConfig(**args) - if not dataset.indexing_technique and not knowledge_config.indexing_technique: raise ValueError("indexing_technique is required.") - # validate args DocumentService.document_create_args_validate(knowledge_config) - try: documents, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, current_user) dataset = DatasetService.get_dataset(dataset_id) - except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) except QuotaExceededError: raise ProviderQuotaExceededError() except ModelCurrentlyNotSupportError: raise ProviderModelCurrentlyNotSupportError() - return {"dataset": dataset, "documents": documents, "batch": batch} @setup_required @@ -310,13 +276,11 @@ class DatasetDocumentListApi(Resource): raise NotFound("Dataset not found.") # check user's model setting DatasetService.check_dataset_model_setting(dataset) - try: document_ids = request.args.getlist("document_id") DocumentService.delete_documents(dataset, 
document_ids) except services.errors.document.DocumentIndexingError: raise DocumentIndexingError("Cannot delete document during indexing.") - return {"result": "success"}, 204 @@ -331,7 +295,6 @@ class DatasetInitApi(Resource): # The role of the current user in the ta table must be admin, owner, dataset_operator, or editor if not current_user.is_dataset_editor: raise Forbidden() - parser = reqparse.RequestParser() parser.add_argument( "indexing_technique", @@ -351,7 +314,6 @@ class DatasetInitApi(Resource): parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json") parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json") args = parser.parse_args() - # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator if not current_user.is_dataset_editor: raise Forbidden() @@ -373,10 +335,8 @@ class DatasetInitApi(Resource): ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) - # validate args DocumentService.document_create_args_validate(knowledge_config) - try: dataset, documents, batch = DocumentService.save_document_without_dataset_id( tenant_id=current_user.current_tenant_id, knowledge_config=knowledge_config, account=current_user @@ -387,9 +347,7 @@ class DatasetInitApi(Resource): raise ProviderQuotaExceededError() except ModelCurrentlyNotSupportError: raise ProviderModelCurrentlyNotSupportError() - response = {"dataset": dataset, "documents": documents, "batch": batch} - return response @@ -401,36 +359,27 @@ class DocumentIndexingEstimateApi(DocumentResource): dataset_id = str(dataset_id) document_id = str(document_id) document = self.get_document(dataset_id, document_id) - if document.indexing_status in {"completed", "error"}: raise DocumentAlreadyFinishedError() - data_process_rule = document.dataset_process_rule data_process_rule_dict = data_process_rule.to_dict() - response = {"tokens": 0, 
"total_price": 0, "currency": "USD", "total_segments": 0, "preview": []} - if document.data_source_type == "upload_file": data_source_info = document.data_source_info_dict if data_source_info and "upload_file_id" in data_source_info: file_id = data_source_info["upload_file_id"] - file = ( db.session.query(UploadFile) .filter(UploadFile.tenant_id == document.tenant_id, UploadFile.id == file_id) .first() ) - # raise error if file not found if not file: raise NotFound("File not found.") - extract_setting = ExtractSetting( datasource_type="upload_file", upload_file=file, document_model=document.doc_form ) - indexing_runner = IndexingRunner() - try: estimate_response = indexing_runner.indexing_estimate( current_user.current_tenant_id, @@ -452,7 +401,6 @@ class DocumentIndexingEstimateApi(DocumentResource): raise ProviderNotInitializeError(ex.description) except Exception as e: raise IndexingEstimateError(str(e)) - return response, 200 @@ -487,7 +435,6 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): pages.append(page) notion_info = {"workspace_id": data_source_info["notion_workspace_id"], "pages": pages} info_list.append(notion_info) - if document.data_source_type == "upload_file": file_id = data_source_info["upload_file_id"] file_detail = ( @@ -495,15 +442,12 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): .filter(UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id == file_id) .first() ) - if file_detail is None: raise NotFound("File not found.") - extract_setting = ExtractSetting( datasource_type="upload_file", upload_file=file_detail, document_model=document.doc_form ) extract_settings.append(extract_setting) - elif document.data_source_type == "notion_import": extract_setting = ExtractSetting( datasource_type="notion_import", @@ -530,7 +474,6 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): document_model=document.doc_form, ) extract_settings.append(extract_setting) - else: raise ValueError("Data source type not 
support") indexing_runner = IndexingRunner() @@ -608,7 +551,6 @@ class DocumentIndexingStatusApi(DocumentResource): dataset_id = str(dataset_id) document_id = str(document_id) document = self.get_document(dataset_id, document_id) - completed_segments = ( db.session.query(DocumentSegment) .filter( @@ -623,7 +565,6 @@ class DocumentIndexingStatusApi(DocumentResource): .filter(DocumentSegment.document_id == str(document_id), DocumentSegment.status != "re_segment") .count() ) - # Create a dictionary with document attributes and additional fields document_dict = { "id": document.id, @@ -652,11 +593,9 @@ class DocumentDetailApi(DocumentResource): dataset_id = str(dataset_id) document_id = str(document_id) document = self.get_document(dataset_id, document_id) - metadata = request.args.get("metadata", "all") if metadata not in self.METADATA_CHOICES: raise InvalidMetadataError(f"Invalid metadata value: {metadata}") - if metadata == "only": response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata_details} elif metadata == "without": @@ -727,7 +666,6 @@ class DocumentDetailApi(DocumentResource): "doc_form": document.doc_form, "doc_language": document.doc_language, } - return response, 200 @@ -740,31 +678,25 @@ class DocumentProcessingApi(DocumentResource): dataset_id = str(dataset_id) document_id = str(document_id) document = self.get_document(dataset_id, document_id) - # The role of the current user in the ta table must be admin, owner, dataset_operator, or editor if not current_user.is_dataset_editor: raise Forbidden() - if action == "pause": if document.indexing_status != "indexing": raise InvalidActionError("Document not in indexing state.") - document.paused_by = current_user.id document.paused_at = datetime.now(UTC).replace(tzinfo=None) document.is_paused = True db.session.commit() - elif action == "resume": if document.indexing_status not in {"paused", "error"}: raise InvalidActionError("Document not in paused or error 
state.") - document.paused_by = None document.paused_at = None document.is_paused = False db.session.commit() else: raise InvalidActionError() - return {"result": "success"}, 200 @@ -781,14 +713,11 @@ class DocumentDeleteApi(DocumentResource): raise NotFound("Dataset not found.") # check user's model setting DatasetService.check_dataset_model_setting(dataset) - document = self.get_document(dataset_id, document_id) - try: DocumentService.delete_document(document) except services.errors.document.DocumentIndexingError: raise DocumentIndexingError("Cannot delete document during indexing.") - return {"result": "success"}, 204 @@ -800,26 +729,19 @@ class DocumentMetadataApi(DocumentResource): dataset_id = str(dataset_id) document_id = str(document_id) document = self.get_document(dataset_id, document_id) - req_data = request.get_json() - doc_type = req_data.get("doc_type") doc_metadata = req_data.get("doc_metadata") - # The role of the current user in the ta table must be admin, owner, dataset_operator, or editor if not current_user.is_dataset_editor: raise Forbidden() - if doc_type is None or doc_metadata is None: raise ValueError("Both doc_type and doc_metadata must be provided.") - if doc_type not in DocumentService.DOCUMENT_METADATA_SCHEMA: raise ValueError("Invalid doc_type.") - if not isinstance(doc_metadata, dict): raise ValueError("doc_metadata must be a dictionary.") metadata_schema: dict = cast(dict, DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type]) - document.doc_metadata = {} if doc_type == "others": document.doc_metadata = doc_metadata @@ -828,11 +750,9 @@ class DocumentMetadataApi(DocumentResource): value = doc_metadata.get(key) if value is not None and isinstance(value, value_type): document.doc_metadata[key] = value - document.doc_type = doc_type document.updated_at = datetime.now(UTC).replace(tzinfo=None) db.session.commit() - return {"result": "success", "message": "Document metadata updated."}, 200 @@ -847,19 +767,14 @@ class 
DocumentStatusApi(DocumentResource): dataset = DatasetService.get_dataset(dataset_id) if dataset is None: raise NotFound("Dataset not found.") - # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_dataset_editor: raise Forbidden() - # check user's model setting DatasetService.check_dataset_model_setting(dataset) - # check user's permission DatasetService.check_dataset_permission(dataset, current_user) - document_ids = request.args.getlist("document_id") - try: DocumentService.batch_update_document_status(dataset, document_ids, action, current_user) except services.errors.document.DocumentIndexingError as e: @@ -868,7 +783,6 @@ class DocumentStatusApi(DocumentResource): raise InvalidActionError(str(e)) except NotFound as e: raise NotFound(str(e)) - return {"result": "success"}, 200 @@ -881,27 +795,21 @@ class DocumentPauseApi(DocumentResource): """pause document.""" dataset_id = str(dataset_id) document_id = str(document_id) - dataset = DatasetService.get_dataset(dataset_id) if not dataset: raise NotFound("Dataset not found.") - document = DocumentService.get_document(dataset.id, document_id) - # 404 if document not found if document is None: raise NotFound("Document Not Exists.") - # 403 if document is archived if DocumentService.check_archived(document): raise ArchivedDocumentImmutableError() - try: # pause document DocumentService.pause_document(document) except services.errors.document.DocumentIndexingError: raise DocumentIndexingError("Cannot pause completed document.") - return {"result": "success"}, 204 @@ -918,11 +826,9 @@ class DocumentRecoverApi(DocumentResource): if not dataset: raise NotFound("Dataset not found.") document = DocumentService.get_document(dataset.id, document_id) - # 404 if document not found if document is None: raise NotFound("Document Not Exists.") - # 403 if document is archived if DocumentService.check_archived(document): raise ArchivedDocumentImmutableError() @@ -931,7 +837,6 @@ class 
DocumentRecoverApi(DocumentResource): DocumentService.recover_document(document) except services.errors.document.DocumentIndexingError: raise DocumentIndexingError("Document is not in paused status.") - return {"result": "success"}, 204 @@ -942,7 +847,6 @@ class DocumentRetryApi(DocumentResource): @cloud_edition_billing_rate_limit_check("knowledge") def post(self, dataset_id): """retry document.""" - parser = reqparse.RequestParser() parser.add_argument("document_ids", type=list, required=True, nullable=False, location="json") args = parser.parse_args() @@ -954,17 +858,13 @@ class DocumentRetryApi(DocumentResource): for document_id in args["document_ids"]: try: document_id = str(document_id) - document = DocumentService.get_document(dataset.id, document_id) - # 404 if document not found if document is None: raise NotFound("Document Not Exists.") - # 403 if document is archived if DocumentService.check_archived(document): raise ArchivedDocumentImmutableError() - # 400 if document is completed if document.indexing_status == "completed": raise DocumentAlreadyFinishedError() @@ -974,7 +874,6 @@ class DocumentRetryApi(DocumentResource): continue # retry document DocumentService.retry_document(dataset_id, retry_documents) - return {"result": "success"}, 204 @@ -992,12 +891,10 @@ class DocumentRenameApi(DocumentResource): parser = reqparse.RequestParser() parser.add_argument("name", type=str, required=True, nullable=False, location="json") args = parser.parse_args() - try: document = DocumentService.rename_document(dataset_id, document_id, args["name"]) except services.errors.document.DocumentIndexingError: raise DocumentIndexingError("Cannot delete document during indexing.") - return document @@ -1024,7 +921,6 @@ class WebsiteDocumentSyncApi(DocumentResource): raise ArchivedDocumentImmutableError() # sync document DocumentService.sync_website_document(dataset_id, document) - return {"result": "success"}, 200 @@ -1048,5 +944,4 @@ api.add_resource(DocumentPauseApi, 
"/datasets//documents//documents//processing/resume") api.add_resource(DocumentRetryApi, "/datasets//retry") api.add_resource(DocumentRenameApi, "/datasets//documents//rename") - api.add_resource(WebsiteDocumentSyncApi, "/datasets//documents//website-sync") diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index 48142dbe73..02f2fd544f 100644 --- a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -49,17 +49,13 @@ class DatasetDocumentSegmentListApi(Resource): dataset = DatasetService.get_dataset(dataset_id) if not dataset: raise NotFound("Dataset not found.") - try: DatasetService.check_dataset_permission(dataset, current_user) except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) - document = DocumentService.get_document(dataset_id, document_id) - if not document: raise NotFound("Document not found.") - parser = reqparse.RequestParser() parser.add_argument("limit", type=int, default=20, location="args") parser.add_argument("status", type=str, action="append", default=[], location="args") @@ -67,15 +63,12 @@ class DatasetDocumentSegmentListApi(Resource): parser.add_argument("enabled", type=str, default="all", location="args") parser.add_argument("keyword", type=str, default=None, location="args") parser.add_argument("page", type=int, default=1, location="args") - args = parser.parse_args() - page = args["page"] limit = min(args["limit"], 100) status_list = args["status"] hit_count_gte = args["hit_count_gte"] keyword = args["keyword"] - query = ( select(DocumentSegment) .filter( @@ -84,24 +77,18 @@ class DatasetDocumentSegmentListApi(Resource): ) .order_by(DocumentSegment.position.asc()) ) - if status_list: query = query.filter(DocumentSegment.status.in_(status_list)) - if hit_count_gte is not None: query = query.filter(DocumentSegment.hit_count >= hit_count_gte) - if keyword: query = 
query.where(DocumentSegment.content.ilike(f"%{keyword}%")) - if args["enabled"].lower() != "all": if args["enabled"].lower() == "true": query = query.filter(DocumentSegment.enabled == True) elif args["enabled"].lower() == "false": query = query.filter(DocumentSegment.enabled == False) - segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False) - response = { "data": marshal(segments.items, segment_fields), "limit": limit, @@ -129,7 +116,6 @@ class DatasetDocumentSegmentListApi(Resource): if not document: raise NotFound("Document not found.") segment_ids = request.args.getlist("segment_id") - # The role of the current user in the ta table must be admin, owner, dataset_operator, or editor if not current_user.is_dataset_editor: raise Forbidden() @@ -161,7 +147,6 @@ class DatasetDocumentSegmentApi(Resource): # The role of the current user in the ta table must be admin, owner, dataset_operator, or editor if not current_user.is_dataset_editor: raise Forbidden() - try: DatasetService.check_dataset_permission(dataset, current_user) except services.errors.account.NoPermissionError as e: @@ -183,7 +168,6 @@ class DatasetDocumentSegmentApi(Resource): except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) segment_ids = request.args.getlist("segment_id") - document_indexing_cache_key = "document_{}_indexing".format(document.id) cache_result = redis_client.get(document_indexing_cache_key) if cache_result is not None: @@ -370,13 +354,11 @@ class DatasetDocumentSegmentBatchImportApi(Resource): # check file if "file" not in request.files: raise NoFileUploadedError() - if len(request.files) > 1: raise TooManyFilesError() # check file type if not file.filename or not file.filename.lower().endswith(".csv"): raise ValueError("Invalid file type. 
Only CSV files are allowed") - try: # Skip the first row df = pd.read_csv(file) @@ -410,7 +392,6 @@ class DatasetDocumentSegmentBatchImportApi(Resource): cache_result = redis_client.get(indexing_cache_key) if cache_result is None: raise ValueError("The job does not exist.") - return {"job_id": job_id, "job_status": cache_result.decode()}, 200 @@ -502,13 +483,10 @@ class ChildChunkAddApi(Resource): parser.add_argument("limit", type=int, default=20, location="args") parser.add_argument("keyword", type=str, default=None, location="args") parser.add_argument("page", type=int, default=1, location="args") - args = parser.parse_args() - page = args["page"] limit = min(args["limit"], 100) keyword = args["keyword"] - child_chunks = SegmentService.get_child_chunks(segment_id, document_id, dataset_id, page, limit, keyword) return { "data": marshal(child_chunks.items, child_chunk_fields), diff --git a/api/controllers/console/datasets/external.py b/api/controllers/console/datasets/external.py index cf9081e154..9d5d9c0e16 100644 --- a/api/controllers/console/datasets/external.py +++ b/api/controllers/console/datasets/external.py @@ -29,7 +29,6 @@ class ExternalApiTemplateListApi(Resource): page = request.args.get("page", default=1, type=int) limit = request.args.get("limit", default=20, type=int) search = request.args.get("keyword", default=None, type=str) - external_knowledge_apis, total = ExternalDatasetService.get_external_knowledge_apis( page, limit, current_user.current_tenant_id, search ) @@ -62,20 +61,16 @@ class ExternalApiTemplateListApi(Resource): required=True, ) args = parser.parse_args() - ExternalDatasetService.validate_api_list(args["settings"]) - # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator if not current_user.is_dataset_editor: raise Forbidden() - try: external_knowledge_api = ExternalDatasetService.create_external_knowledge_api( tenant_id=current_user.current_tenant_id, user_id=current_user.id, args=args ) 
except services.errors.dataset.DatasetNameDuplicateError: raise DatasetNameDuplicateError() - return external_knowledge_api.to_dict(), 201 @@ -88,7 +83,6 @@ class ExternalApiTemplateApi(Resource): external_knowledge_api = ExternalDatasetService.get_external_knowledge_api(external_knowledge_api_id) if external_knowledge_api is None: raise NotFound("API template not found.") - return external_knowledge_api.to_dict(), 200 @setup_required @@ -96,7 +90,6 @@ class ExternalApiTemplateApi(Resource): @account_initialization_required def patch(self, external_knowledge_api_id): external_knowledge_api_id = str(external_knowledge_api_id) - parser = reqparse.RequestParser() parser.add_argument( "name", @@ -114,14 +107,12 @@ class ExternalApiTemplateApi(Resource): ) args = parser.parse_args() ExternalDatasetService.validate_api_list(args["settings"]) - external_knowledge_api = ExternalDatasetService.update_external_knowledge_api( tenant_id=current_user.current_tenant_id, user_id=current_user.id, external_knowledge_api_id=external_knowledge_api_id, args=args, ) - return external_knowledge_api.to_dict(), 200 @setup_required @@ -129,11 +120,9 @@ class ExternalApiTemplateApi(Resource): @account_initialization_required def delete(self, external_knowledge_api_id): external_knowledge_api_id = str(external_knowledge_api_id) - # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor or current_user.is_dataset_operator: raise Forbidden() - ExternalDatasetService.delete_external_knowledge_api(current_user.current_tenant_id, external_knowledge_api_id) return {"result": "success"}, 204 @@ -144,7 +133,6 @@ class ExternalApiUseCheckApi(Resource): @account_initialization_required def get(self, external_knowledge_api_id): external_knowledge_api_id = str(external_knowledge_api_id) - external_knowledge_api_is_using, count = ExternalDatasetService.external_knowledge_api_use_check( external_knowledge_api_id ) @@ -159,7 +147,6 @@ class 
ExternalDatasetCreateApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - parser = reqparse.RequestParser() parser.add_argument("external_knowledge_api_id", type=str, required=True, nullable=False, location="json") parser.add_argument("external_knowledge_id", type=str, required=True, nullable=False, location="json") @@ -172,13 +159,10 @@ class ExternalDatasetCreateApi(Resource): ) parser.add_argument("description", type=str, required=False, nullable=True, location="json") parser.add_argument("external_retrieval_model", type=dict, required=False, location="json") - args = parser.parse_args() - # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator if not current_user.is_dataset_editor: raise Forbidden() - try: dataset = ExternalDatasetService.create_external_dataset( tenant_id=current_user.current_tenant_id, @@ -187,7 +171,6 @@ class ExternalDatasetCreateApi(Resource): ) except services.errors.dataset.DatasetNameDuplicateError: raise DatasetNameDuplicateError() - return marshal(dataset, dataset_detail_fields), 201 @@ -200,20 +183,16 @@ class ExternalKnowledgeHitTestingApi(Resource): dataset = DatasetService.get_dataset(dataset_id_str) if dataset is None: raise NotFound("Dataset not found.") - try: DatasetService.check_dataset_permission(dataset, current_user) except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) - parser = reqparse.RequestParser() parser.add_argument("query", type=str, location="json") parser.add_argument("external_retrieval_model", type=dict, required=False, location="json") parser.add_argument("metadata_filtering_conditions", type=dict, required=False, location="json") args = parser.parse_args() - HitTestingService.hit_testing_args_check(args) - try: response = HitTestingService.external_retrieve( dataset=dataset, @@ -222,7 +201,6 @@ class ExternalKnowledgeHitTestingApi(Resource): 
external_retrieval_model=args["external_retrieval_model"], metadata_filtering_conditions=args["metadata_filtering_conditions"], ) - return response except Exception as e: raise InternalServerError(str(e)) @@ -241,7 +219,6 @@ class BedrockRetrievalApi(Resource): ) parser.add_argument("knowledge_id", nullable=False, required=True, type=str) args = parser.parse_args() - # Call the knowledge retrieval service result = ExternalDatasetTestService.knowledge_retrieval( args["retrieval_setting"], args["query"], args["knowledge_id"] diff --git a/api/controllers/console/datasets/hit_testing.py b/api/controllers/console/datasets/hit_testing.py index fba5d4c0f3..d9cdb9bfd4 100644 --- a/api/controllers/console/datasets/hit_testing.py +++ b/api/controllers/console/datasets/hit_testing.py @@ -17,11 +17,9 @@ class HitTestingApi(Resource, DatasetsHitTestingBase): @cloud_edition_billing_rate_limit_check("knowledge") def post(self, dataset_id): dataset_id_str = str(dataset_id) - dataset = self.get_and_validate_dataset(dataset_id_str) args = self.parse_args() self.hit_testing_args_check(args) - return self.perform_hit_testing(dataset, args) diff --git a/api/controllers/console/datasets/hit_testing_base.py b/api/controllers/console/datasets/hit_testing_base.py index 3b4c076863..fda817a11a 100644 --- a/api/controllers/console/datasets/hit_testing_base.py +++ b/api/controllers/console/datasets/hit_testing_base.py @@ -30,12 +30,10 @@ class DatasetsHitTestingBase: dataset = DatasetService.get_dataset(dataset_id) if dataset is None: raise NotFound("Dataset not found.") - try: DatasetService.check_dataset_permission(dataset, current_user) except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) - return dataset @staticmethod @@ -45,7 +43,6 @@ class DatasetsHitTestingBase: @staticmethod def parse_args(): parser = reqparse.RequestParser() - parser.add_argument("query", type=str, location="json") parser.add_argument("retrieval_model", type=dict, required=False, 
location="json") parser.add_argument("external_retrieval_model", type=dict, required=False, location="json") diff --git a/api/controllers/console/datasets/metadata.py b/api/controllers/console/datasets/metadata.py index b1a83aa371..9b77e54011 100644 --- a/api/controllers/console/datasets/metadata.py +++ b/api/controllers/console/datasets/metadata.py @@ -26,13 +26,11 @@ class DatasetMetadataCreateApi(Resource): parser.add_argument("name", type=str, required=True, nullable=True, location="json") args = parser.parse_args() metadata_args = MetadataArgs(**args) - dataset_id_str = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id_str) if dataset is None: raise NotFound("Dataset not found.") DatasetService.check_dataset_permission(dataset, current_user) - metadata = MetadataService.create_metadata(dataset_id_str, metadata_args) return metadata, 201 @@ -58,14 +56,12 @@ class DatasetMetadataApi(Resource): parser = reqparse.RequestParser() parser.add_argument("name", type=str, required=True, nullable=True, location="json") args = parser.parse_args() - dataset_id_str = str(dataset_id) metadata_id_str = str(metadata_id) dataset = DatasetService.get_dataset(dataset_id_str) if dataset is None: raise NotFound("Dataset not found.") DatasetService.check_dataset_permission(dataset, current_user) - metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, args.get("name")) return metadata, 200 @@ -80,7 +76,6 @@ class DatasetMetadataApi(Resource): if dataset is None: raise NotFound("Dataset not found.") DatasetService.check_dataset_permission(dataset, current_user) - MetadataService.delete_metadata(dataset_id_str, metadata_id_str) return {"result": "success"}, 204 @@ -106,7 +101,6 @@ class DatasetMetadataBuiltInFieldActionApi(Resource): if dataset is None: raise NotFound("Dataset not found.") DatasetService.check_dataset_permission(dataset, current_user) - if action == "enable": MetadataService.enable_built_in_field(dataset) elif action == 
"disable": @@ -125,14 +119,11 @@ class DocumentMetadataEditApi(Resource): if dataset is None: raise NotFound("Dataset not found.") DatasetService.check_dataset_permission(dataset, current_user) - parser = reqparse.RequestParser() parser.add_argument("operation_data", type=list, required=True, nullable=True, location="json") args = parser.parse_args() metadata_args = MetadataOperationData(**args) - MetadataService.update_documents_metadata(dataset, metadata_args) - return 200 diff --git a/api/controllers/console/explore/audio.py b/api/controllers/console/explore/audio.py index d564a00a76..35c0d46be1 100644 --- a/api/controllers/console/explore/audio.py +++ b/api/controllers/console/explore/audio.py @@ -30,12 +30,9 @@ from services.errors.audio import ( class ChatAudioApi(InstalledAppResource): def post(self, installed_app): app_model = installed_app.app - file = request.files["file"] - try: response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=None) - return response except services.errors.app_model_config.AppModelConfigBrokenError: logging.exception("App model config broken.") @@ -75,11 +72,9 @@ class ChatTextApi(InstalledAppResource): parser.add_argument("text", type=str, location="json") parser.add_argument("streaming", type=bool, location="json") args = parser.parse_args() - message_id = args.get("message_id", None) text = args.get("text", None) voice = args.get("voice", None) - response = AudioService.transcript_tts(app_model=app_model, text=text, voice=voice, message_id=message_id) return response except services.errors.app_model_config.AppModelConfigBrokenError: diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py index 4367da1162..02c507e345 100644 --- a/api/controllers/console/explore/completion.py +++ b/api/controllers/console/explore/completion.py @@ -39,7 +39,6 @@ class CompletionApi(InstalledAppResource): app_model = installed_app.app if app_model.mode != "completion": raise 
NotCompletionAppError() - parser = reqparse.RequestParser() parser.add_argument("inputs", type=dict, required=True, location="json") parser.add_argument("query", type=str, location="json", default="") @@ -47,18 +46,14 @@ class CompletionApi(InstalledAppResource): parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json") args = parser.parse_args() - streaming = args["response_mode"] == "streaming" args["auto_generate_name"] = False - installed_app.last_used_at = datetime.now(UTC).replace(tzinfo=None) db.session.commit() - try: response = AppGenerateService.generate( app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=streaming ) - return helper.compact_generate_response(response) except services.errors.conversation.ConversationNotExistsError: raise NotFound("Conversation Not Exists.") @@ -87,9 +82,7 @@ class CompletionStopApi(InstalledAppResource): app_model = installed_app.app if app_model.mode != "completion": raise NotCompletionAppError() - AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) - return {"result": "success"}, 200 @@ -99,7 +92,6 @@ class ChatApi(InstalledAppResource): app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() - parser = reqparse.RequestParser() parser.add_argument("inputs", type=dict, required=True, location="json") parser.add_argument("query", type=str, required=True, location="json") @@ -108,17 +100,13 @@ class ChatApi(InstalledAppResource): parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json") parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json") args = parser.parse_args() - args["auto_generate_name"] = False - installed_app.last_used_at = 
datetime.now(UTC).replace(tzinfo=None) db.session.commit() - try: response = AppGenerateService.generate( app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True ) - return helper.compact_generate_response(response) except services.errors.conversation.ConversationNotExistsError: raise NotFound("Conversation Not Exists.") @@ -150,7 +138,5 @@ class ChatStopApi(InstalledAppResource): app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() - AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) - return {"result": "success"}, 200 diff --git a/api/controllers/console/explore/conversation.py b/api/controllers/console/explore/conversation.py index d7c161cc6d..fa979a647d 100644 --- a/api/controllers/console/explore/conversation.py +++ b/api/controllers/console/explore/conversation.py @@ -23,17 +23,14 @@ class ConversationListApi(InstalledAppResource): app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() - parser = reqparse.RequestParser() parser.add_argument("last_id", type=uuid_value, location="args") parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") parser.add_argument("pinned", type=str, choices=["true", "false", None], location="args") args = parser.parse_args() - pinned = None if "pinned" in args and args["pinned"] is not None: pinned = args["pinned"] == "true" - try: with Session(db.engine) as session: return WebConversationService.pagination_by_last_id( @@ -55,14 +52,12 @@ class ConversationApi(InstalledAppResource): app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() - conversation_id = str(c_id) try: ConversationService.delete(app_model, conversation_id, current_user) except 
ConversationNotExistsError: raise NotFound("Conversation Not Exists.") WebConversationService.unpin(app_model, conversation_id, current_user) - return {"result": "success"}, 204 @@ -73,14 +68,11 @@ class ConversationRenameApi(InstalledAppResource): app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() - conversation_id = str(c_id) - parser = reqparse.RequestParser() parser.add_argument("name", type=str, required=False, location="json") parser.add_argument("auto_generate", type=bool, required=False, default=False, location="json") args = parser.parse_args() - try: return ConversationService.rename( app_model, conversation_id, current_user, args["name"], args["auto_generate"] @@ -95,14 +87,11 @@ class ConversationPinApi(InstalledAppResource): app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() - conversation_id = str(c_id) - try: WebConversationService.pin(app_model, conversation_id, current_user) except ConversationNotExistsError: raise NotFound("Conversation Not Exists.") - return {"result": "success"} @@ -112,8 +101,6 @@ class ConversationUnPinApi(InstalledAppResource): app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() - conversation_id = str(c_id) WebConversationService.unpin(app_model, conversation_id, current_user) - return {"result": "success"} diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py index 9d0c08564e..14470436d6 100644 --- a/api/controllers/console/explore/installed_app.py +++ b/api/controllers/console/explore/installed_app.py @@ -30,7 +30,6 @@ class InstalledAppsListApi(Resource): def get(self): app_id = request.args.get("app_id", default=None, type=str) current_tenant_id = current_user.current_tenant_id - if 
app_id: installed_apps = ( db.session.query(InstalledApp) @@ -39,7 +38,6 @@ class InstalledAppsListApi(Resource): ) else: installed_apps = db.session.query(InstalledApp).filter(InstalledApp.tenant_id == current_tenant_id).all() - current_user.role = TenantService.get_user_role(current_user, current_user.current_tenant) installed_app_list: list[dict[str, Any]] = [ { @@ -54,7 +52,6 @@ class InstalledAppsListApi(Resource): for installed_app in installed_apps if installed_app.app is not None ] - # filter out apps that user doesn't have access to if FeatureService.get_system_features().webapp_auth.enabled: user_id = current_user.id @@ -75,7 +72,6 @@ class InstalledAppsListApi(Resource): res.append(installed_app) installed_app_list = res logger.debug(f"installed_app_list: {installed_app_list}, user_id: {user_id}") - installed_app_list.sort( key=lambda app: ( -app["is_pinned"], @@ -83,7 +79,6 @@ class InstalledAppsListApi(Resource): -app["last_used_at"].timestamp() if app["last_used_at"] is not None else 0, ) ) - return {"installed_apps": installed_app_list} @login_required @@ -93,30 +88,23 @@ class InstalledAppsListApi(Resource): parser = reqparse.RequestParser() parser.add_argument("app_id", type=str, required=True, help="Invalid app_id") args = parser.parse_args() - recommended_app = db.session.query(RecommendedApp).filter(RecommendedApp.app_id == args["app_id"]).first() if recommended_app is None: raise NotFound("App not found") - current_tenant_id = current_user.current_tenant_id app = db.session.query(App).filter(App.id == args["app_id"]).first() - if app is None: raise NotFound("App not found") - if not app.is_public: raise Forbidden("You can't install a non-public app") - installed_app = ( db.session.query(InstalledApp) .filter(and_(InstalledApp.app_id == args["app_id"], InstalledApp.tenant_id == current_tenant_id)) .first() ) - if installed_app is None: # todo: position recommended_app.install_count += 1 - new_installed_app = InstalledApp( app_id=args["app_id"], 
tenant_id=current_tenant_id, @@ -126,7 +114,6 @@ class InstalledAppsListApi(Resource): ) db.session.add(new_installed_app) db.session.commit() - return {"message": "App installed successfully"} @@ -139,25 +126,20 @@ class InstalledAppApi(InstalledAppResource): def delete(self, installed_app): if installed_app.app_owner_tenant_id == current_user.current_tenant_id: raise BadRequest("You can't uninstall an app owned by the current tenant") - db.session.delete(installed_app) db.session.commit() - return {"result": "success", "message": "App uninstalled successfully"}, 204 def patch(self, installed_app): parser = reqparse.RequestParser() parser.add_argument("is_pinned", type=inputs.boolean) args = parser.parse_args() - commit_args = False if "is_pinned" in args: installed_app.is_pinned = args["is_pinned"] commit_args = True - if commit_args: db.session.commit() - return {"result": "success", "message": "App info updated successfully"} diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py index 822777604a..1ff3adc723 100644 --- a/api/controllers/console/explore/message.py +++ b/api/controllers/console/explore/message.py @@ -37,17 +37,14 @@ class MessageListApi(InstalledAppResource): @marshal_with(message_infinite_scroll_pagination_fields) def get(self, installed_app): app_model = installed_app.app - app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() - parser = reqparse.RequestParser() parser.add_argument("conversation_id", required=True, type=uuid_value, location="args") parser.add_argument("first_id", type=uuid_value, location="args") parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") args = parser.parse_args() - try: return MessageService.pagination_by_first_id( app_model, current_user, args["conversation_id"], args["first_id"], args["limit"] @@ -61,14 +58,11 @@ class 
MessageListApi(InstalledAppResource): class MessageFeedbackApi(InstalledAppResource): def post(self, installed_app, message_id): app_model = installed_app.app - message_id = str(message_id) - parser = reqparse.RequestParser() parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json") parser.add_argument("content", type=str, location="json") args = parser.parse_args() - try: MessageService.create_feedback( app_model=app_model, @@ -79,7 +73,6 @@ class MessageFeedbackApi(InstalledAppResource): ) except services.errors.message.MessageNotExistsError: raise NotFound("Message Not Exists.") - return {"result": "success"} @@ -88,17 +81,13 @@ class MessageMoreLikeThisApi(InstalledAppResource): app_model = installed_app.app if app_model.mode != "completion": raise NotCompletionAppError() - message_id = str(message_id) - parser = reqparse.RequestParser() parser.add_argument( "response_mode", type=str, required=True, choices=["blocking", "streaming"], location="args" ) args = parser.parse_args() - streaming = args["response_mode"] == "streaming" - try: response = AppGenerateService.generate_more_like_this( app_model=app_model, @@ -133,9 +122,7 @@ class MessageSuggestedQuestionApi(InstalledAppResource): app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() - message_id = str(message_id) - try: questions = MessageService.get_suggested_questions_after_answer( app_model=app_model, user=current_user, message_id=message_id, invoke_from=InvokeFrom.EXPLORE @@ -157,5 +144,4 @@ class MessageSuggestedQuestionApi(InstalledAppResource): except Exception: logging.exception("internal server error.") raise InternalServerError() - return {"data": questions} diff --git a/api/controllers/console/explore/parameter.py b/api/controllers/console/explore/parameter.py index a1280d91d1..dbc4ebd2cb 100644 --- a/api/controllers/console/explore/parameter.py +++ 
b/api/controllers/console/explore/parameter.py @@ -16,26 +16,20 @@ class AppParameterApi(InstalledAppResource): def get(self, installed_app: InstalledApp): """Retrieve app parameters.""" app_model = installed_app.app - if app_model is None: raise AppUnavailableError() - if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: workflow = app_model.workflow if workflow is None: raise AppUnavailableError() - features_dict = workflow.features_dict user_input_form = workflow.user_input_form(to_old_structure=True) else: app_model_config = app_model.app_model_config if app_model_config is None: raise AppUnavailableError() - features_dict = app_model_config.to_dict() - user_input_form = features_dict.get("user_input_form", []) - return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form) diff --git a/api/controllers/console/explore/recommended_app.py b/api/controllers/console/explore/recommended_app.py index ce85f495aa..35cc869075 100644 --- a/api/controllers/console/explore/recommended_app.py +++ b/api/controllers/console/explore/recommended_app.py @@ -17,7 +17,6 @@ app_fields = { "icon_url": AppIconUrlField, "icon_background": fields.String, } - recommended_app_fields = { "app": fields.Nested(app_fields, attribute="app"), "app_id": fields.String, @@ -29,7 +28,6 @@ recommended_app_fields = { "position": fields.Integer, "is_listed": fields.Boolean, } - recommended_app_list_fields = { "recommended_apps": fields.List(fields.Nested(recommended_app_fields)), "categories": fields.List(fields.String), @@ -45,14 +43,12 @@ class RecommendedAppListApi(Resource): parser = reqparse.RequestParser() parser.add_argument("language", type=str, location="args") args = parser.parse_args() - if args.get("language") and args.get("language") in languages: language_prefix = args.get("language") elif current_user and current_user.interface_language: language_prefix = current_user.interface_language else: language_prefix = languages[0] 
- return RecommendedAppService.get_recommended_apps_and_categories(language_prefix) diff --git a/api/controllers/console/explore/saved_message.py b/api/controllers/console/explore/saved_message.py index 339e7007a0..23ea749497 100644 --- a/api/controllers/console/explore/saved_message.py +++ b/api/controllers/console/explore/saved_message.py @@ -12,7 +12,6 @@ from services.errors.message import MessageNotExistsError from services.saved_message_service import SavedMessageService feedback_fields = {"rating": fields.String} - message_fields = { "id": fields.String, "inputs": fields.Raw, @@ -36,42 +35,33 @@ class SavedMessageListApi(InstalledAppResource): app_model = installed_app.app if app_model.mode != "completion": raise NotCompletionAppError() - parser = reqparse.RequestParser() parser.add_argument("last_id", type=uuid_value, location="args") parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") args = parser.parse_args() - return SavedMessageService.pagination_by_last_id(app_model, current_user, args["last_id"], args["limit"]) def post(self, installed_app): app_model = installed_app.app if app_model.mode != "completion": raise NotCompletionAppError() - parser = reqparse.RequestParser() parser.add_argument("message_id", type=uuid_value, required=True, location="json") args = parser.parse_args() - try: SavedMessageService.save(app_model, current_user, args["message_id"]) except MessageNotExistsError: raise NotFound("Message Not Exists.") - return {"result": "success"} class SavedMessageApi(InstalledAppResource): def delete(self, installed_app, message_id): app_model = installed_app.app - message_id = str(message_id) - if app_model.mode != "completion": raise NotCompletionAppError() - SavedMessageService.delete(app_model, current_user, message_id) - return {"result": "success"}, 204 diff --git a/api/controllers/console/explore/workflow.py b/api/controllers/console/explore/workflow.py index 3f625e6609..744035a37a 100644 --- 
a/api/controllers/console/explore/workflow.py +++ b/api/controllers/console/explore/workflow.py @@ -38,17 +38,14 @@ class InstalledAppWorkflowRunApi(InstalledAppResource): app_mode = AppMode.value_of(app_model.mode) if app_mode != AppMode.WORKFLOW: raise NotWorkflowAppError() - parser = reqparse.RequestParser() parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") parser.add_argument("files", type=list, required=False, location="json") args = parser.parse_args() - try: response = AppGenerateService.generate( app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True ) - return helper.compact_generate_response(response) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) @@ -76,7 +73,5 @@ class InstalledAppWorkflowTaskStopApi(InstalledAppResource): app_mode = AppMode.value_of(app_model.mode) if app_mode != AppMode.WORKFLOW: raise NotWorkflowAppError() - AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) - return {"result": "success"} diff --git a/api/controllers/console/explore/wraps.py b/api/controllers/console/explore/wraps.py index afbd78bd5b..27a60c859f 100644 --- a/api/controllers/console/explore/wraps.py +++ b/api/controllers/console/explore/wraps.py @@ -20,12 +20,9 @@ def installed_app_required(view=None): def decorated(*args, **kwargs): if not kwargs.get("installed_app_id"): raise ValueError("missing installed_app_id in path parameters") - installed_app_id = kwargs.get("installed_app_id") installed_app_id = str(installed_app_id) - del kwargs["installed_app_id"] - installed_app = ( db.session.query(InstalledApp) .filter( @@ -33,16 +30,12 @@ def installed_app_required(view=None): ) .first() ) - if installed_app is None: raise NotFound("Installed app not found") - if not installed_app.app: db.session.delete(installed_app) db.session.commit() - raise NotFound("Installed app not found") - return view(installed_app, *args, 
**kwargs) return decorated @@ -66,7 +59,6 @@ def user_allowed_to_access_app(view=None): ) if not res: raise AppAccessDeniedError() - return view(installed_app, *args, **kwargs) return decorated @@ -78,7 +70,6 @@ def user_allowed_to_access_app(view=None): class InstalledAppResource(Resource): # must be reversed if there are multiple decorators - method_decorators = [ user_allowed_to_access_app, installed_app_required, diff --git a/api/controllers/console/extension.py b/api/controllers/console/extension.py index 07a241ef86..a4a3cb4f8e 100644 --- a/api/controllers/console/extension.py +++ b/api/controllers/console/extension.py @@ -19,7 +19,6 @@ class CodeBasedExtensionAPI(Resource): parser = reqparse.RequestParser() parser.add_argument("module", type=str, required=True, location="args") args = parser.parse_args() - return {"module": args["module"], "data": CodeBasedExtensionService.get_code_based_extension(args["module"])} @@ -42,14 +41,12 @@ class APIBasedExtensionAPI(Resource): parser.add_argument("api_endpoint", type=str, required=True, location="json") parser.add_argument("api_key", type=str, required=True, location="json") args = parser.parse_args() - extension_data = APIBasedExtension( tenant_id=current_user.current_tenant_id, name=args["name"], api_endpoint=args["api_endpoint"], api_key=args["api_key"], ) - return APIBasedExtensionService.save(extension_data) @@ -61,7 +58,6 @@ class APIBasedExtensionDetailAPI(Resource): def get(self, id): api_based_extension_id = str(id) tenant_id = current_user.current_tenant_id - return APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id) @setup_required @@ -71,21 +67,16 @@ class APIBasedExtensionDetailAPI(Resource): def post(self, id): api_based_extension_id = str(id) tenant_id = current_user.current_tenant_id - extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id) - parser = reqparse.RequestParser() parser.add_argument("name", type=str, 
required=True, location="json") parser.add_argument("api_endpoint", type=str, required=True, location="json") parser.add_argument("api_key", type=str, required=True, location="json") args = parser.parse_args() - extension_data_from_db.name = args["name"] extension_data_from_db.api_endpoint = args["api_endpoint"] - if args["api_key"] != HIDDEN_VALUE: extension_data_from_db.api_key = args["api_key"] - return APIBasedExtensionService.save(extension_data_from_db) @setup_required @@ -94,15 +85,11 @@ class APIBasedExtensionDetailAPI(Resource): def delete(self, id): api_based_extension_id = str(id) tenant_id = current_user.current_tenant_id - extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id) - APIBasedExtensionService.delete(extension_data_from_db) - return {"result": "success"}, 204 api.add_resource(CodeBasedExtensionAPI, "/code-based-extension") - api.add_resource(APIBasedExtensionAPI, "/api-based-extension") api.add_resource(APIBasedExtensionDetailAPI, "/api-based-extension/") diff --git a/api/controllers/console/files.py b/api/controllers/console/files.py index 66b6214f82..fd815afa84 100644 --- a/api/controllers/console/files.py +++ b/api/controllers/console/files.py @@ -52,22 +52,16 @@ class FileApi(Resource): file = request.files["file"] source_str = request.form.get("source") source: Literal["datasets"] | None = "datasets" if source_str == "datasets" else None - if "file" not in request.files: raise NoFileUploadedError() - if len(request.files) > 1: raise TooManyFilesError() - if not file.filename: raise FilenameNotExistsError - if source == "datasets" and not current_user.is_dataset_editor: raise Forbidden() - if source not in ("datasets", None): source = None - try: upload_file = FileService.upload_file( filename=file.filename, @@ -80,7 +74,6 @@ class FileApi(Resource): raise FileTooLargeError(file_too_large_error.description) except services.errors.file.UnsupportedFileTypeError: raise 
UnsupportedFileTypeError() - return upload_file, 201 diff --git a/api/controllers/console/init_validate.py b/api/controllers/console/init_validate.py index b19e331d2e..add3ce544c 100644 --- a/api/controllers/console/init_validate.py +++ b/api/controllers/console/init_validate.py @@ -29,15 +29,12 @@ class InitValidateAPI(Resource): tenant_count = TenantService.get_tenant_count() if tenant_count > 0: raise AlreadySetupError() - parser = reqparse.RequestParser() parser.add_argument("password", type=StrLen(30), required=True, location="json") input_password = parser.parse_args()["password"] - if input_password != os.environ.get("INIT_PASSWORD"): session["is_init_validated"] = False raise InitValidateFailedError() - session["is_init_validated"] = True return {"result": "success"}, 201 @@ -47,10 +44,8 @@ def get_init_validate_status(): if os.environ.get("INIT_PASSWORD"): if session.get("is_init_validated"): return True - with Session(db.engine) as db_session: return db_session.execute(select(DifySetup)).scalar_one_or_none() - return True diff --git a/api/controllers/console/remote_files.py b/api/controllers/console/remote_files.py index b8cf019e4f..a9c8a89aa3 100644 --- a/api/controllers/console/remote_files.py +++ b/api/controllers/console/remote_files.py @@ -41,9 +41,7 @@ class RemoteFileUploadApi(Resource): parser = reqparse.RequestParser() parser.add_argument("url", type=str, required=True, help="URL is required") args = parser.parse_args() - url = args["url"] - try: resp = ssrf_proxy.head(url=url) if resp.status_code != httpx.codes.OK: @@ -52,14 +50,10 @@ class RemoteFileUploadApi(Resource): raise RemoteFileUploadError(f"Failed to fetch file from {url}: {resp.text}") except httpx.RequestError as e: raise RemoteFileUploadError(f"Failed to fetch file from {url}: {str(e)}") - file_info = helpers.guess_file_info_from_response(resp) - if not FileService.is_file_size_within_limit(extension=file_info.extension, file_size=file_info.size): raise FileTooLargeError - content = 
resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content - try: user = cast(Account, current_user) upload_file = FileService.upload_file( @@ -73,7 +67,6 @@ class RemoteFileUploadApi(Resource): raise FileTooLargeError(file_too_large_error.description) except services.errors.file.UnsupportedFileTypeError: raise UnsupportedFileTypeError() - return { "id": upload_file.id, "name": upload_file.name, diff --git a/api/controllers/console/setup.py b/api/controllers/console/setup.py index e1f19a87a3..4f83374109 100644 --- a/api/controllers/console/setup.py +++ b/api/controllers/console/setup.py @@ -27,26 +27,21 @@ class SetupApi(Resource): # is set up if get_setup_status(): raise AlreadySetupError() - # is tenant created tenant_count = TenantService.get_tenant_count() if tenant_count > 0: raise AlreadySetupError() - if not get_init_validate_status(): raise NotInitValidateError() - parser = reqparse.RequestParser() parser.add_argument("email", type=email, required=True, location="json") parser.add_argument("name", type=StrLen(30), required=True, location="json") parser.add_argument("password", type=valid_password, required=True, location="json") args = parser.parse_args() - # setup RegisterService.setup( email=args["email"], name=args["name"], password=args["password"], ip_address=extract_remote_ip(request) ) - return {"result": "success"}, 201 diff --git a/api/controllers/console/tag/tags.py b/api/controllers/console/tag/tags.py index cb5dedca21..49d71219fa 100644 --- a/api/controllers/console/tag/tags.py +++ b/api/controllers/console/tag/tags.py @@ -26,7 +26,6 @@ class TagListApi(Resource): tag_type = request.args.get("type", type=str, default="") keyword = request.args.get("keyword", default=None, type=str) tags = TagService.get_tags(tag_type, current_user.current_tenant_id, keyword) - return tags, 200 @setup_required @@ -36,7 +35,6 @@ class TagListApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor if not 
(current_user.is_editor or current_user.is_dataset_editor): raise Forbidden() - parser = reqparse.RequestParser() parser.add_argument( "name", nullable=False, required=True, help="Name must be between 1 to 50 characters.", type=_validate_name @@ -46,9 +44,7 @@ class TagListApi(Resource): ) args = parser.parse_args() tag = TagService.save_tags(args) - response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0} - return response, 200 @@ -61,18 +57,14 @@ class TagUpdateDeleteApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor if not (current_user.is_editor or current_user.is_dataset_editor): raise Forbidden() - parser = reqparse.RequestParser() parser.add_argument( "name", nullable=False, required=True, help="Name must be between 1 to 50 characters.", type=_validate_name ) args = parser.parse_args() tag = TagService.update_tags(args, tag_id) - binding_count = TagService.get_tag_binding_count(tag_id) - response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count} - return response, 200 @setup_required @@ -83,9 +75,7 @@ class TagUpdateDeleteApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - TagService.delete_tag(tag_id) - return 204 @@ -97,7 +87,6 @@ class TagBindingCreateApi(Resource): # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator if not (current_user.is_editor or current_user.is_dataset_editor): raise Forbidden() - parser = reqparse.RequestParser() parser.add_argument( "tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs is required." 
@@ -110,7 +99,6 @@ class TagBindingCreateApi(Resource): ) args = parser.parse_args() TagService.save_tag_binding(args) - return 200 @@ -122,7 +110,6 @@ class TagBindingDeleteApi(Resource): # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator if not (current_user.is_editor or current_user.is_dataset_editor): raise Forbidden() - parser = reqparse.RequestParser() parser.add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.") parser.add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.") @@ -131,7 +118,6 @@ class TagBindingDeleteApi(Resource): ) args = parser.parse_args() TagService.delete_tag_binding(args) - return 200 diff --git a/api/controllers/console/version.py b/api/controllers/console/version.py index 447cc358f8..514f6c21c0 100644 --- a/api/controllers/console/version.py +++ b/api/controllers/console/version.py @@ -16,7 +16,6 @@ class VersionApi(Resource): parser.add_argument("current_version", type=str, required=True, location="args") args = parser.parse_args() check_update_url = dify_config.CHECK_UPDATE_URL - result = { "version": dify_config.project.version, "release_date": "", @@ -27,17 +26,14 @@ class VersionApi(Resource): "model_load_balancing_enabled": dify_config.MODEL_LB_ENABLED, }, } - if not check_update_url: return result - try: response = requests.get(check_update_url, {"current_version": args.get("current_version")}) except Exception as error: logging.warning("Check update version error: {}.".format(str(error))) result["version"] = args.get("current_version") return result - content = json.loads(response.content) if _has_new_version(latest_version=content["version"], current_version=f"{args.get('current_version')}"): result["version"] = content["version"] @@ -51,7 +47,6 @@ def _has_new_version(*, latest_version: str, current_version: str) -> bool: try: latest = version.parse(latest_version) current = 
version.parse(current_version) - # Compare versions return latest > current except version.InvalidVersion: diff --git a/api/controllers/console/workspace/__init__.py b/api/controllers/console/workspace/__init__.py index 072e904caf..df07291570 100644 --- a/api/controllers/console/workspace/__init__.py +++ b/api/controllers/console/workspace/__init__.py @@ -17,7 +17,6 @@ def plugin_permission_required( def decorated(*args, **kwargs): user = current_user tenant_id = user.current_tenant_id - with Session(db.engine) as session: permission = ( session.query(TenantPluginPermission) @@ -26,11 +25,9 @@ def plugin_permission_required( ) .first() ) - if not permission: # no permission set, allow access for everyone return view(*args, **kwargs) - if install_required: if permission.install_permission == TenantPluginPermission.InstallPermission.NOBODY: raise Forbidden() @@ -39,7 +36,6 @@ def plugin_permission_required( raise Forbidden() if permission.install_permission == TenantPluginPermission.InstallPermission.EVERYONE: pass - if debug_required: if permission.debug_permission == TenantPluginPermission.DebugPermission.NOBODY: raise Forbidden() @@ -48,7 +44,6 @@ def plugin_permission_required( raise Forbidden() if permission.debug_permission == TenantPluginPermission.DebugPermission.EVERYONE: pass - return view(*args, **kwargs) return decorated diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index a9dbf44456..823f783ee4 100644 --- a/api/controllers/console/workspace/account.py +++ b/api/controllers/console/workspace/account.py @@ -37,23 +37,17 @@ class AccountInitApi(Resource): @login_required def post(self): account = current_user - if account.status == "active": raise AccountAlreadyInitedError() - parser = reqparse.RequestParser() - if dify_config.EDITION == "CLOUD": parser.add_argument("invitation_code", type=str, location="json") - parser.add_argument("interface_language", type=supported_language, required=True, 
location="json") parser.add_argument("timezone", type=timezone, required=True, location="json") args = parser.parse_args() - if dify_config.EDITION == "CLOUD": if not args["invitation_code"]: raise ValueError("invitation_code is required") - # check invitation code invitation_code = ( db.session.query(InvitationCode) @@ -63,22 +57,18 @@ class AccountInitApi(Resource): ) .first() ) - if not invitation_code: raise InvalidInvitationCodeError() - invitation_code.status = "used" invitation_code.used_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) invitation_code.used_by_tenant_id = account.current_tenant_id invitation_code.used_by_account_id = account.id - account.interface_language = args["interface_language"] account.timezone = args["timezone"] account.interface_theme = "light" account.status = "active" account.initialized_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) db.session.commit() - return {"result": "success"} @@ -101,13 +91,10 @@ class AccountNameApi(Resource): parser = reqparse.RequestParser() parser.add_argument("name", type=str, required=True, location="json") args = parser.parse_args() - # Validate account name length if len(args["name"]) < 3 or len(args["name"]) > 30: raise ValueError("Account name must be between 3 and 30 characters.") - updated_account = AccountService.update_account(current_user, name=args["name"]) - return updated_account @@ -120,9 +107,7 @@ class AccountAvatarApi(Resource): parser = reqparse.RequestParser() parser.add_argument("avatar", type=str, required=True, location="json") args = parser.parse_args() - updated_account = AccountService.update_account(current_user, avatar=args["avatar"]) - return updated_account @@ -135,9 +120,7 @@ class AccountInterfaceLanguageApi(Resource): parser = reqparse.RequestParser() parser.add_argument("interface_language", type=supported_language, required=True, location="json") args = parser.parse_args() - updated_account = AccountService.update_account(current_user, 
interface_language=args["interface_language"]) - return updated_account @@ -150,9 +133,7 @@ class AccountInterfaceThemeApi(Resource): parser = reqparse.RequestParser() parser.add_argument("interface_theme", type=str, choices=["light", "dark"], required=True, location="json") args = parser.parse_args() - updated_account = AccountService.update_account(current_user, interface_theme=args["interface_theme"]) - return updated_account @@ -165,13 +146,10 @@ class AccountTimezoneApi(Resource): parser = reqparse.RequestParser() parser.add_argument("timezone", type=str, required=True, location="json") args = parser.parse_args() - # Validate timezone string, e.g. America/New_York, Asia/Shanghai if args["timezone"] not in pytz.all_timezones: raise ValueError("Invalid timezone string.") - updated_account = AccountService.update_account(current_user, timezone=args["timezone"]) - return updated_account @@ -186,15 +164,12 @@ class AccountPasswordApi(Resource): parser.add_argument("new_password", type=str, required=True, location="json") parser.add_argument("repeat_new_password", type=str, required=True, location="json") args = parser.parse_args() - if args["new_password"] != args["repeat_new_password"]: raise RepeatPasswordNotMatchError() - try: AccountService.update_account_password(current_user, args["password"], args["new_password"]) except ServiceCurrentPasswordIncorrectError: raise CurrentPasswordIncorrectError() - return {"result": "success"} @@ -205,7 +180,6 @@ class AccountIntegrateApi(Resource): "is_bound": fields.Boolean, "link": fields.String, } - integrate_list_fields = { "data": fields.List(fields.Nested(integrate_fields)), } @@ -216,13 +190,10 @@ class AccountIntegrateApi(Resource): @marshal_with(integrate_list_fields) def get(self): account = current_user - account_integrates = db.session.query(AccountIntegrate).filter(AccountIntegrate.account_id == account.id).all() - base_url = request.url_root.rstrip("/") oauth_base_path = "/console/api/oauth/login" providers = 
["github", "google"] - integrate_data = [] for provider in providers: existing_integrate = next((ai for ai in account_integrates if ai.provider == provider), None) @@ -246,7 +217,6 @@ class AccountIntegrateApi(Resource): "link": f"{base_url}{oauth_base_path}/{provider}", } ) - return {"data": integrate_data} @@ -256,10 +226,8 @@ class AccountDeleteVerifyApi(Resource): @account_initialization_required def get(self): account = current_user - token, code = AccountService.generate_account_deletion_verification_code(account) AccountService.send_account_deletion_verification_email(account, code) - return {"result": "success", "data": token} @@ -269,17 +237,13 @@ class AccountDeleteApi(Resource): @account_initialization_required def post(self): account = current_user - parser = reqparse.RequestParser() parser.add_argument("token", type=str, required=True, location="json") parser.add_argument("code", type=str, required=True, location="json") args = parser.parse_args() - if not AccountService.verify_account_deletion_code(args["token"], args["code"]): raise InvalidAccountDeletionCodeError() - AccountService.delete_account(account) - return {"result": "success"} @@ -290,9 +254,7 @@ class AccountDeleteUpdateFeedbackApi(Resource): parser.add_argument("email", type=str, required=True, location="json") parser.add_argument("feedback", type=str, required=True, location="json") args = parser.parse_args() - BillingService.update_account_deletion_feedback(args["email"], args["feedback"]) - return {"result": "success"} @@ -309,7 +271,6 @@ class EducationVerifyApi(Resource): @marshal_with(verify_fields) def get(self): account = current_user - return BillingService.EducationIdentity.verify(account.id, account.email) @@ -325,13 +286,11 @@ class EducationApi(Resource): @cloud_edition_billing_enabled def post(self): account = current_user - parser = reqparse.RequestParser() parser.add_argument("token", type=str, required=True, location="json") parser.add_argument("institution", type=str, 
required=True, location="json") parser.add_argument("role", type=str, required=True, location="json") args = parser.parse_args() - return BillingService.EducationIdentity.activate(account, args["token"], args["institution"], args["role"]) @setup_required @@ -342,7 +301,6 @@ class EducationApi(Resource): @marshal_with(status_fields) def get(self): account = current_user - return BillingService.EducationIdentity.is_active(account.id) @@ -365,7 +323,6 @@ class EducationAutoCompleteApi(Resource): parser.add_argument("page", type=int, required=False, location="args", default=0) parser.add_argument("limit", type=int, required=False, location="args", default=20) args = parser.parse_args() - return BillingService.EducationIdentity.autocomplete(args["keywords"], args["page"], args["limit"]) diff --git a/api/controllers/console/workspace/agent_providers.py b/api/controllers/console/workspace/agent_providers.py index 88c37767e3..5034c5da76 100644 --- a/api/controllers/console/workspace/agent_providers.py +++ b/api/controllers/console/workspace/agent_providers.py @@ -14,10 +14,8 @@ class AgentProviderListApi(Resource): @account_initialization_required def get(self): user = current_user - user_id = user.id tenant_id = user.current_tenant_id - return jsonable_encoder(AgentService.list_agent_providers(user_id, tenant_id)) diff --git a/api/controllers/console/workspace/endpoint.py b/api/controllers/console/workspace/endpoint.py index eb53dcb16e..3aadae463b 100644 --- a/api/controllers/console/workspace/endpoint.py +++ b/api/controllers/console/workspace/endpoint.py @@ -18,17 +18,14 @@ class EndpointCreateApi(Resource): user = current_user if not user.is_admin_or_owner: raise Forbidden() - parser = reqparse.RequestParser() parser.add_argument("plugin_unique_identifier", type=str, required=True) parser.add_argument("settings", type=dict, required=True) parser.add_argument("name", type=str, required=True) args = parser.parse_args() - plugin_unique_identifier = 
args["plugin_unique_identifier"] settings = args["settings"] name = args["name"] - try: return { "success": EndpointService.create_endpoint( @@ -49,15 +46,12 @@ class EndpointListApi(Resource): @account_initialization_required def get(self): user = current_user - parser = reqparse.RequestParser() parser.add_argument("page", type=int, required=True, location="args") parser.add_argument("page_size", type=int, required=True, location="args") args = parser.parse_args() - page = args["page"] page_size = args["page_size"] - return jsonable_encoder( { "endpoints": EndpointService.list_endpoints( @@ -76,17 +70,14 @@ class EndpointListForSinglePluginApi(Resource): @account_initialization_required def get(self): user = current_user - parser = reqparse.RequestParser() parser.add_argument("page", type=int, required=True, location="args") parser.add_argument("page_size", type=int, required=True, location="args") parser.add_argument("plugin_id", type=str, required=True, location="args") args = parser.parse_args() - page = args["page"] page_size = args["page_size"] plugin_id = args["plugin_id"] - return jsonable_encoder( { "endpoints": EndpointService.list_endpoints_for_single_plugin( @@ -106,16 +97,12 @@ class EndpointDeleteApi(Resource): @account_initialization_required def post(self): user = current_user - parser = reqparse.RequestParser() parser.add_argument("endpoint_id", type=str, required=True) args = parser.parse_args() - if not user.is_admin_or_owner: raise Forbidden() - endpoint_id = args["endpoint_id"] - return { "success": EndpointService.delete_endpoint( tenant_id=user.current_tenant_id, user_id=user.id, endpoint_id=endpoint_id @@ -129,20 +116,16 @@ class EndpointUpdateApi(Resource): @account_initialization_required def post(self): user = current_user - parser = reqparse.RequestParser() parser.add_argument("endpoint_id", type=str, required=True) parser.add_argument("settings", type=dict, required=True) parser.add_argument("name", type=str, required=True) args = 
parser.parse_args() - endpoint_id = args["endpoint_id"] settings = args["settings"] name = args["name"] - if not user.is_admin_or_owner: raise Forbidden() - return { "success": EndpointService.update_endpoint( tenant_id=user.current_tenant_id, @@ -160,16 +143,12 @@ class EndpointEnableApi(Resource): @account_initialization_required def post(self): user = current_user - parser = reqparse.RequestParser() parser.add_argument("endpoint_id", type=str, required=True) args = parser.parse_args() - endpoint_id = args["endpoint_id"] - if not user.is_admin_or_owner: raise Forbidden() - return { "success": EndpointService.enable_endpoint( tenant_id=user.current_tenant_id, user_id=user.id, endpoint_id=endpoint_id @@ -183,16 +162,12 @@ class EndpointDisableApi(Resource): @account_initialization_required def post(self): user = current_user - parser = reqparse.RequestParser() parser.add_argument("endpoint_id", type=str, required=True) args = parser.parse_args() - endpoint_id = args["endpoint_id"] - if not user.is_admin_or_owner: raise Forbidden() - return { "success": EndpointService.disable_endpoint( tenant_id=user.current_tenant_id, user_id=user.id, endpoint_id=endpoint_id diff --git a/api/controllers/console/workspace/load_balancing_config.py b/api/controllers/console/workspace/load_balancing_config.py index b4eb5e246b..d8c28727ba 100644 --- a/api/controllers/console/workspace/load_balancing_config.py +++ b/api/controllers/console/workspace/load_balancing_config.py @@ -17,9 +17,7 @@ class LoadBalancingCredentialsValidateApi(Resource): def post(self, provider: str): if not TenantAccountRole.is_privileged_role(current_user.current_role): raise Forbidden() - tenant_id = current_user.current_tenant_id - parser = reqparse.RequestParser() parser.add_argument("model", type=str, required=True, nullable=False, location="json") parser.add_argument( @@ -32,13 +30,10 @@ class LoadBalancingCredentialsValidateApi(Resource): ) parser.add_argument("credentials", type=dict, required=True, 
nullable=False, location="json") args = parser.parse_args() - # validate model load balancing credentials model_load_balancing_service = ModelLoadBalancingService() - result = True error = "" - try: model_load_balancing_service.validate_load_balancing_credentials( tenant_id=tenant_id, @@ -50,12 +45,9 @@ class LoadBalancingCredentialsValidateApi(Resource): except CredentialsValidateFailedError as ex: result = False error = str(ex) - response = {"result": "success" if result else "error"} - if not result: response["error"] = error - return response @@ -66,9 +58,7 @@ class LoadBalancingConfigCredentialsValidateApi(Resource): def post(self, provider: str, config_id: str): if not TenantAccountRole.is_privileged_role(current_user.current_role): raise Forbidden() - tenant_id = current_user.current_tenant_id - parser = reqparse.RequestParser() parser.add_argument("model", type=str, required=True, nullable=False, location="json") parser.add_argument( @@ -81,13 +71,10 @@ class LoadBalancingConfigCredentialsValidateApi(Resource): ) parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") args = parser.parse_args() - # validate model load balancing config credentials model_load_balancing_service = ModelLoadBalancingService() - result = True error = "" - try: model_load_balancing_service.validate_load_balancing_credentials( tenant_id=tenant_id, @@ -100,12 +87,9 @@ class LoadBalancingConfigCredentialsValidateApi(Resource): except CredentialsValidateFailedError as ex: result = False error = str(ex) - response = {"result": "success" if result else "error"} - if not result: response["error"] = error - return response @@ -114,7 +98,6 @@ api.add_resource( LoadBalancingCredentialsValidateApi, "/workspaces/current/model-providers//models/load-balancing-configs/credentials-validate", ) - api.add_resource( LoadBalancingConfigCredentialsValidateApi, "/workspaces/current/model-providers//models/load-balancing-configs//credentials-validate", diff --git 
a/api/controllers/console/workspace/members.py b/api/controllers/console/workspace/members.py index 48225ac90d..b5190732cb 100644 --- a/api/controllers/console/workspace/members.py +++ b/api/controllers/console/workspace/members.py @@ -46,22 +46,17 @@ class MemberInviteEmailApi(Resource): parser.add_argument("role", type=str, required=True, default="admin", location="json") parser.add_argument("language", type=str, required=False, location="json") args = parser.parse_args() - invitee_emails = args["emails"] invitee_role = args["role"] interface_language = args["language"] if not TenantAccountRole.is_non_owner_role(invitee_role): return {"code": "invalid-role", "message": "Invalid role"}, 400 - inviter = current_user invitation_results = [] console_web_url = dify_config.CONSOLE_WEB_URL - workspace_members = FeatureService.get_features(tenant_id=inviter.current_tenant.id).workspace_members - if not workspace_members.is_available(len(invitee_emails)): raise WorkspaceMembersLimitExceeded() - for invitee_email in invitee_emails: try: token = RegisterService.invite_new_member( @@ -81,7 +76,6 @@ class MemberInviteEmailApi(Resource): ) except Exception as e: invitation_results.append({"status": "failed", "email": invitee_email, "message": str(e)}) - return { "result": "success", "invitation_results": invitation_results, @@ -110,7 +104,6 @@ class MemberCancelInviteApi(Resource): return {"code": "member-not-found", "message": str(e)}, 404 except Exception as e: raise ValueError(str(e)) - return {"result": "success", "tenant_id": str(current_user.current_tenant.id)}, 200 @@ -125,22 +118,17 @@ class MemberUpdateRoleApi(Resource): parser.add_argument("role", type=str, required=True, location="json") args = parser.parse_args() new_role = args["role"] - if not TenantAccountRole.is_valid_role(new_role): return {"code": "invalid-role", "message": "Invalid role"}, 400 - member = db.session.get(Account, str(member_id)) if not member: abort(404) - try: assert member is not None, 
"Member not found" TenantService.update_member_role(current_user.current_tenant, member, new_role, current_user) except Exception as e: raise ValueError(str(e)) - # todo: 403 - return {"result": "success"} diff --git a/api/controllers/console/workspace/model_providers.py b/api/controllers/console/workspace/model_providers.py index ff0fcbda6e..10d629a5f6 100644 --- a/api/controllers/console/workspace/model_providers.py +++ b/api/controllers/console/workspace/model_providers.py @@ -21,7 +21,6 @@ class ModelProviderListApi(Resource): @account_initialization_required def get(self): tenant_id = current_user.current_tenant_id - parser = reqparse.RequestParser() parser.add_argument( "model_type", @@ -32,10 +31,8 @@ class ModelProviderListApi(Resource): location="args", ) args = parser.parse_args() - model_provider_service = ModelProviderService() provider_list = model_provider_service.get_provider_list(tenant_id=tenant_id, model_type=args.get("model_type")) - return jsonable_encoder({"data": provider_list}) @@ -45,10 +42,8 @@ class ModelProviderCredentialApi(Resource): @account_initialization_required def get(self, provider: str): tenant_id = current_user.current_tenant_id - model_provider_service = ModelProviderService() credentials = model_provider_service.get_provider_credentials(tenant_id=tenant_id, provider=provider) - return {"credentials": credentials} @@ -60,14 +55,10 @@ class ModelProviderValidateApi(Resource): parser = reqparse.RequestParser() parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") args = parser.parse_args() - tenant_id = current_user.current_tenant_id - model_provider_service = ModelProviderService() - result = True error = "" - try: model_provider_service.provider_credentials_validate( tenant_id=tenant_id, provider=provider, credentials=args["credentials"] @@ -75,12 +66,9 @@ class ModelProviderValidateApi(Resource): except CredentialsValidateFailedError as ex: result = False error = str(ex) - response = 
{"result": "success" if result else "error"} - if not result: response["error"] = error or "Unknown error" - return response @@ -91,20 +79,16 @@ class ModelProviderApi(Resource): def post(self, provider: str): if not current_user.is_admin_or_owner: raise Forbidden() - parser = reqparse.RequestParser() parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") args = parser.parse_args() - model_provider_service = ModelProviderService() - try: model_provider_service.save_provider_credentials( tenant_id=current_user.current_tenant_id, provider=provider, credentials=args["credentials"] ) except CredentialsValidateFailedError as ex: raise ValueError(str(ex)) - return {"result": "success"}, 201 @setup_required @@ -113,10 +97,8 @@ class ModelProviderApi(Resource): def delete(self, provider: str): if not current_user.is_admin_or_owner: raise Forbidden() - model_provider_service = ModelProviderService() model_provider_service.remove_provider_credentials(tenant_id=current_user.current_tenant_id, provider=provider) - return {"result": "success"}, 204 @@ -145,9 +127,7 @@ class PreferredProviderTypeUpdateApi(Resource): def post(self, provider: str): if not current_user.is_admin_or_owner: raise Forbidden() - tenant_id = current_user.current_tenant_id - parser = reqparse.RequestParser() parser.add_argument( "preferred_provider_type", @@ -158,12 +138,10 @@ class PreferredProviderTypeUpdateApi(Resource): location="json", ) args = parser.parse_args() - model_provider_service = ModelProviderService() model_provider_service.switch_preferred_provider( tenant_id=tenant_id, provider=provider, preferred_provider_type=args["preferred_provider_type"] ) - return {"result": "success"} @@ -185,11 +163,9 @@ class ModelProviderPaymentCheckoutUrlApi(Resource): api.add_resource(ModelProviderListApi, "/workspaces/current/model-providers") - api.add_resource(ModelProviderCredentialApi, "/workspaces/current/model-providers//credentials") 
api.add_resource(ModelProviderValidateApi, "/workspaces/current/model-providers//credentials/validate") api.add_resource(ModelProviderApi, "/workspaces/current/model-providers/") - api.add_resource( PreferredProviderTypeUpdateApi, "/workspaces/current/model-providers//preferred-provider-type" ) diff --git a/api/controllers/console/workspace/models.py b/api/controllers/console/workspace/models.py index 37d0f6c764..3eb809bf96 100644 --- a/api/controllers/console/workspace/models.py +++ b/api/controllers/console/workspace/models.py @@ -29,14 +29,11 @@ class DefaultModelApi(Resource): location="args", ) args = parser.parse_args() - tenant_id = current_user.current_tenant_id - model_provider_service = ModelProviderService() default_model_entity = model_provider_service.get_default_model_of_model_type( tenant_id=tenant_id, model_type=args["model_type"] ) - return jsonable_encoder({"data": default_model_entity}) @setup_required @@ -45,25 +42,19 @@ class DefaultModelApi(Resource): def post(self): if not current_user.is_admin_or_owner: raise Forbidden() - parser = reqparse.RequestParser() parser.add_argument("model_settings", type=list, required=True, nullable=False, location="json") args = parser.parse_args() - tenant_id = current_user.current_tenant_id - model_provider_service = ModelProviderService() model_settings = args["model_settings"] for model_setting in model_settings: if "model_type" not in model_setting or model_setting["model_type"] not in [mt.value for mt in ModelType]: raise ValueError("invalid model type") - if "provider" not in model_setting: continue - if "model" not in model_setting: raise ValueError("invalid model") - try: model_provider_service.update_default_model_of_model_type( tenant_id=tenant_id, @@ -77,7 +68,6 @@ class DefaultModelApi(Resource): f" model:{model_setting.get('model')}" ) raise ex - return {"result": "success"} @@ -87,10 +77,8 @@ class ModelProviderModelApi(Resource): @account_initialization_required def get(self, provider): tenant_id 
= current_user.current_tenant_id - model_provider_service = ModelProviderService() models = model_provider_service.get_models_by_provider(tenant_id=tenant_id, provider=provider) - return jsonable_encoder({"data": models}) @setup_required @@ -99,9 +87,7 @@ class ModelProviderModelApi(Resource): def post(self, provider: str): if not current_user.is_admin_or_owner: raise Forbidden() - tenant_id = current_user.current_tenant_id - parser = reqparse.RequestParser() parser.add_argument("model", type=str, required=True, nullable=False, location="json") parser.add_argument( @@ -116,9 +102,7 @@ class ModelProviderModelApi(Resource): parser.add_argument("load_balancing", type=dict, required=False, nullable=True, location="json") parser.add_argument("config_from", type=str, required=False, nullable=True, location="json") args = parser.parse_args() - model_load_balancing_service = ModelLoadBalancingService() - if ( "load_balancing" in args and args["load_balancing"] @@ -127,7 +111,6 @@ class ModelProviderModelApi(Resource): ): if "configs" not in args["load_balancing"]: raise ValueError("invalid load balancing configs") - # save load balancing configs model_load_balancing_service.update_load_balancing_configs( tenant_id=tenant_id, @@ -136,7 +119,6 @@ class ModelProviderModelApi(Resource): model_type=args["model_type"], configs=args["load_balancing"]["configs"], ) - # enable load balancing model_load_balancing_service.enable_model_load_balancing( tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] @@ -146,10 +128,8 @@ class ModelProviderModelApi(Resource): model_load_balancing_service.disable_model_load_balancing( tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] ) - if args.get("config_from", "") != "predefined-model": model_provider_service = ModelProviderService() - try: model_provider_service.save_model_credentials( tenant_id=tenant_id, @@ -164,7 +144,6 @@ class ModelProviderModelApi(Resource): 
f" model: {args.get('model')}, model_type: {args.get('model_type')}" ) raise ValueError(str(ex)) - return {"result": "success"}, 200 @setup_required @@ -173,9 +152,7 @@ class ModelProviderModelApi(Resource): def delete(self, provider: str): if not current_user.is_admin_or_owner: raise Forbidden() - tenant_id = current_user.current_tenant_id - parser = reqparse.RequestParser() parser.add_argument("model", type=str, required=True, nullable=False, location="json") parser.add_argument( @@ -187,12 +164,10 @@ class ModelProviderModelApi(Resource): location="json", ) args = parser.parse_args() - model_provider_service = ModelProviderService() model_provider_service.remove_model_credentials( tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] ) - return {"result": "success"}, 204 @@ -202,7 +177,6 @@ class ModelProviderModelCredentialApi(Resource): @account_initialization_required def get(self, provider: str): tenant_id = current_user.current_tenant_id - parser = reqparse.RequestParser() parser.add_argument("model", type=str, required=True, nullable=False, location="args") parser.add_argument( @@ -214,17 +188,14 @@ class ModelProviderModelCredentialApi(Resource): location="args", ) args = parser.parse_args() - model_provider_service = ModelProviderService() credentials = model_provider_service.get_model_credentials( tenant_id=tenant_id, provider=provider, model_type=args["model_type"], model=args["model"] ) - model_load_balancing_service = ModelLoadBalancingService() is_load_balancing_enabled, load_balancing_configs = model_load_balancing_service.get_load_balancing_configs( tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] ) - return { "credentials": credentials, "load_balancing": {"enabled": is_load_balancing_enabled, "configs": load_balancing_configs}, @@ -237,7 +208,6 @@ class ModelProviderModelEnableApi(Resource): @account_initialization_required def patch(self, provider: str): tenant_id = 
current_user.current_tenant_id - parser = reqparse.RequestParser() parser.add_argument("model", type=str, required=True, nullable=False, location="json") parser.add_argument( @@ -249,12 +219,10 @@ class ModelProviderModelEnableApi(Resource): location="json", ) args = parser.parse_args() - model_provider_service = ModelProviderService() model_provider_service.enable_model( tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] ) - return {"result": "success"} @@ -264,7 +232,6 @@ class ModelProviderModelDisableApi(Resource): @account_initialization_required def patch(self, provider: str): tenant_id = current_user.current_tenant_id - parser = reqparse.RequestParser() parser.add_argument("model", type=str, required=True, nullable=False, location="json") parser.add_argument( @@ -276,12 +243,10 @@ class ModelProviderModelDisableApi(Resource): location="json", ) args = parser.parse_args() - model_provider_service = ModelProviderService() model_provider_service.disable_model( tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] ) - return {"result": "success"} @@ -291,7 +256,6 @@ class ModelProviderModelValidateApi(Resource): @account_initialization_required def post(self, provider: str): tenant_id = current_user.current_tenant_id - parser = reqparse.RequestParser() parser.add_argument("model", type=str, required=True, nullable=False, location="json") parser.add_argument( @@ -304,12 +268,9 @@ class ModelProviderModelValidateApi(Resource): ) parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") args = parser.parse_args() - model_provider_service = ModelProviderService() - result = True error = "" - try: model_provider_service.model_credentials_validate( tenant_id=tenant_id, @@ -321,12 +282,9 @@ class ModelProviderModelValidateApi(Resource): except CredentialsValidateFailedError as ex: result = False error = str(ex) - response = {"result": "success" if result 
else "error"} - if not result: response["error"] = error or "" - return response @@ -338,14 +296,11 @@ class ModelProviderModelParameterRuleApi(Resource): parser = reqparse.RequestParser() parser.add_argument("model", type=str, required=True, nullable=False, location="args") args = parser.parse_args() - tenant_id = current_user.current_tenant_id - model_provider_service = ModelProviderService() parameter_rules = model_provider_service.get_model_parameter_rules( tenant_id=tenant_id, provider=provider, model=args["model"] ) - return jsonable_encoder({"data": parameter_rules}) @@ -355,10 +310,8 @@ class ModelProviderAvailableModelApi(Resource): @account_initialization_required def get(self, model_type): tenant_id = current_user.current_tenant_id - model_provider_service = ModelProviderService() models = model_provider_service.get_models_by_model_type(tenant_id=tenant_id, model_type=model_type) - return jsonable_encoder({"data": models}) @@ -379,7 +332,6 @@ api.add_resource( api.add_resource( ModelProviderModelValidateApi, "/workspaces/current/model-providers//models/credentials/validate" ) - api.add_resource( ModelProviderModelParameterRuleApi, "/workspaces/current/model-providers//models/parameter-rules" ) diff --git a/api/controllers/console/workspace/plugin.py b/api/controllers/console/workspace/plugin.py index c0a4734828..e7ffff97b9 100644 --- a/api/controllers/console/workspace/plugin.py +++ b/api/controllers/console/workspace/plugin.py @@ -25,7 +25,6 @@ class PluginDebuggingKeyApi(Resource): @plugin_permission_required(debug_required=True) def get(self): tenant_id = current_user.current_tenant_id - try: return { "key": PluginService.get_debugging_key(tenant_id), @@ -50,7 +49,6 @@ class PluginListApi(Resource): plugins_with_total = PluginService.list_with_total(tenant_id, args["page"], args["page_size"]) except PluginDaemonClientSideError as e: raise ValueError(e) - return jsonable_encoder({"plugins": plugins_with_total.list, "total": plugins_with_total.total}) 
@@ -62,12 +60,10 @@ class PluginListLatestVersionsApi(Resource): req = reqparse.RequestParser() req.add_argument("plugin_ids", type=list, required=True, location="json") args = req.parse_args() - try: versions = PluginService.list_latest_versions(args["plugin_ids"]) except PluginDaemonClientSideError as e: raise ValueError(e) - return jsonable_encoder({"versions": versions}) @@ -77,16 +73,13 @@ class PluginListInstallationsFromIdsApi(Resource): @account_initialization_required def post(self): tenant_id = current_user.current_tenant_id - parser = reqparse.RequestParser() parser.add_argument("plugin_ids", type=list, required=True, location="json") args = parser.parse_args() - try: plugins = PluginService.list_installations_from_ids(tenant_id, args["plugin_ids"]) except PluginDaemonClientSideError as e: raise ValueError(e) - return jsonable_encoder({"plugins": plugins}) @@ -97,12 +90,10 @@ class PluginIconApi(Resource): req.add_argument("tenant_id", type=str, required=True, location="args") req.add_argument("filename", type=str, required=True, location="args") args = req.parse_args() - try: icon_bytes, mimetype = PluginService.get_asset(args["tenant_id"], args["filename"]) except PluginDaemonClientSideError as e: raise ValueError(e) - icon_cache_max_age = dify_config.TOOL_ICON_CACHE_MAX_AGE return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age) @@ -114,19 +105,15 @@ class PluginUploadFromPkgApi(Resource): @plugin_permission_required(install_required=True) def post(self): tenant_id = current_user.current_tenant_id - file = request.files["pkg"] - # check file size if file.content_length > dify_config.PLUGIN_MAX_PACKAGE_SIZE: raise ValueError("File size exceeds the maximum allowed size") - content = file.read() try: response = PluginService.upload_pkg(tenant_id, content) except PluginDaemonClientSideError as e: raise ValueError(e) - return jsonable_encoder(response) @@ -137,18 +124,15 @@ class PluginUploadFromGithubApi(Resource): 
@plugin_permission_required(install_required=True) def post(self): tenant_id = current_user.current_tenant_id - parser = reqparse.RequestParser() parser.add_argument("repo", type=str, required=True, location="json") parser.add_argument("version", type=str, required=True, location="json") parser.add_argument("package", type=str, required=True, location="json") args = parser.parse_args() - try: response = PluginService.upload_pkg_from_github(tenant_id, args["repo"], args["version"], args["package"]) except PluginDaemonClientSideError as e: raise ValueError(e) - return jsonable_encoder(response) @@ -159,19 +143,15 @@ class PluginUploadFromBundleApi(Resource): @plugin_permission_required(install_required=True) def post(self): tenant_id = current_user.current_tenant_id - file = request.files["bundle"] - # check file size if file.content_length > dify_config.PLUGIN_MAX_BUNDLE_SIZE: raise ValueError("File size exceeds the maximum allowed size") - content = file.read() try: response = PluginService.upload_bundle(tenant_id, content) except PluginDaemonClientSideError as e: raise ValueError(e) - return jsonable_encoder(response) @@ -182,21 +162,17 @@ class PluginInstallFromPkgApi(Resource): @plugin_permission_required(install_required=True) def post(self): tenant_id = current_user.current_tenant_id - parser = reqparse.RequestParser() parser.add_argument("plugin_unique_identifiers", type=list, required=True, location="json") args = parser.parse_args() - # check if all plugin_unique_identifiers are valid string for plugin_unique_identifier in args["plugin_unique_identifiers"]: if not isinstance(plugin_unique_identifier, str): raise ValueError("Invalid plugin unique identifier") - try: response = PluginService.install_from_local_pkg(tenant_id, args["plugin_unique_identifiers"]) except PluginDaemonClientSideError as e: raise ValueError(e) - return jsonable_encoder(response) @@ -207,14 +183,12 @@ class PluginInstallFromGithubApi(Resource): 
@plugin_permission_required(install_required=True) def post(self): tenant_id = current_user.current_tenant_id - parser = reqparse.RequestParser() parser.add_argument("repo", type=str, required=True, location="json") parser.add_argument("version", type=str, required=True, location="json") parser.add_argument("package", type=str, required=True, location="json") parser.add_argument("plugin_unique_identifier", type=str, required=True, location="json") args = parser.parse_args() - try: response = PluginService.install_from_github( tenant_id, @@ -225,7 +199,6 @@ class PluginInstallFromGithubApi(Resource): ) except PluginDaemonClientSideError as e: raise ValueError(e) - return jsonable_encoder(response) @@ -236,21 +209,17 @@ class PluginInstallFromMarketplaceApi(Resource): @plugin_permission_required(install_required=True) def post(self): tenant_id = current_user.current_tenant_id - parser = reqparse.RequestParser() parser.add_argument("plugin_unique_identifiers", type=list, required=True, location="json") args = parser.parse_args() - # check if all plugin_unique_identifiers are valid string for plugin_unique_identifier in args["plugin_unique_identifiers"]: if not isinstance(plugin_unique_identifier, str): raise ValueError("Invalid plugin unique identifier") - try: response = PluginService.install_from_marketplace_pkg(tenant_id, args["plugin_unique_identifiers"]) except PluginDaemonClientSideError as e: raise ValueError(e) - return jsonable_encoder(response) @@ -261,11 +230,9 @@ class PluginFetchMarketplacePkgApi(Resource): @plugin_permission_required(install_required=True) def get(self): tenant_id = current_user.current_tenant_id - parser = reqparse.RequestParser() parser.add_argument("plugin_unique_identifier", type=str, required=True, location="args") args = parser.parse_args() - try: return jsonable_encoder( { @@ -286,11 +253,9 @@ class PluginFetchManifestApi(Resource): @plugin_permission_required(install_required=True) def get(self): tenant_id = 
current_user.current_tenant_id - parser = reqparse.RequestParser() parser.add_argument("plugin_unique_identifier", type=str, required=True, location="args") args = parser.parse_args() - try: return jsonable_encoder( { @@ -310,12 +275,10 @@ class PluginFetchInstallTasksApi(Resource): @plugin_permission_required(install_required=True) def get(self): tenant_id = current_user.current_tenant_id - parser = reqparse.RequestParser() parser.add_argument("page", type=int, required=True, location="args") parser.add_argument("page_size", type=int, required=True, location="args") args = parser.parse_args() - try: return jsonable_encoder( {"tasks": PluginService.fetch_install_tasks(tenant_id, args["page"], args["page_size"])} @@ -331,7 +294,6 @@ class PluginFetchInstallTaskApi(Resource): @plugin_permission_required(install_required=True) def get(self, task_id: str): tenant_id = current_user.current_tenant_id - try: return jsonable_encoder({"task": PluginService.fetch_install_task(tenant_id, task_id)}) except PluginDaemonClientSideError as e: @@ -345,7 +307,6 @@ class PluginDeleteInstallTaskApi(Resource): @plugin_permission_required(install_required=True) def post(self, task_id: str): tenant_id = current_user.current_tenant_id - try: return {"success": PluginService.delete_install_task(tenant_id, task_id)} except PluginDaemonClientSideError as e: @@ -359,7 +320,6 @@ class PluginDeleteAllInstallTaskItemsApi(Resource): @plugin_permission_required(install_required=True) def post(self): tenant_id = current_user.current_tenant_id - try: return {"success": PluginService.delete_all_install_task_items(tenant_id)} except PluginDaemonClientSideError as e: @@ -373,7 +333,6 @@ class PluginDeleteInstallTaskItemApi(Resource): @plugin_permission_required(install_required=True) def post(self, task_id: str, identifier: str): tenant_id = current_user.current_tenant_id - try: return {"success": PluginService.delete_install_task_item(tenant_id, task_id, identifier)} except 
PluginDaemonClientSideError as e: @@ -387,12 +346,10 @@ class PluginUpgradeFromMarketplaceApi(Resource): @plugin_permission_required(install_required=True) def post(self): tenant_id = current_user.current_tenant_id - parser = reqparse.RequestParser() parser.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json") parser.add_argument("new_plugin_unique_identifier", type=str, required=True, location="json") args = parser.parse_args() - try: return jsonable_encoder( PluginService.upgrade_plugin_with_marketplace( @@ -410,7 +367,6 @@ class PluginUpgradeFromGithubApi(Resource): @plugin_permission_required(install_required=True) def post(self): tenant_id = current_user.current_tenant_id - parser = reqparse.RequestParser() parser.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json") parser.add_argument("new_plugin_unique_identifier", type=str, required=True, location="json") @@ -418,7 +374,6 @@ class PluginUpgradeFromGithubApi(Resource): parser.add_argument("version", type=str, required=True, location="json") parser.add_argument("package", type=str, required=True, location="json") args = parser.parse_args() - try: return jsonable_encoder( PluginService.upgrade_plugin_with_github( @@ -443,9 +398,7 @@ class PluginUninstallApi(Resource): req = reqparse.RequestParser() req.add_argument("plugin_installation_id", type=str, required=True, location="json") args = req.parse_args() - tenant_id = current_user.current_tenant_id - try: return {"success": PluginService.uninstall(tenant_id, args["plugin_installation_id"])} except PluginDaemonClientSideError as e: @@ -460,17 +413,13 @@ class PluginChangePermissionApi(Resource): user = current_user if not user.is_admin_or_owner: raise Forbidden() - req = reqparse.RequestParser() req.add_argument("install_permission", type=str, required=True, location="json") req.add_argument("debug_permission", type=str, required=True, location="json") args = req.parse_args() - 
install_permission = TenantPluginPermission.InstallPermission(args["install_permission"]) debug_permission = TenantPluginPermission.DebugPermission(args["debug_permission"]) - tenant_id = user.current_tenant_id - return {"success": PluginPermissionService.change_permission(tenant_id, install_permission, debug_permission)} @@ -480,7 +429,6 @@ class PluginFetchPermissionApi(Resource): @account_initialization_required def get(self): tenant_id = current_user.current_tenant_id - permission = PluginPermissionService.get_permission(tenant_id) if not permission: return jsonable_encoder( @@ -489,7 +437,6 @@ class PluginFetchPermissionApi(Resource): "debug_permission": TenantPluginPermission.DebugPermission.EVERYONE, } ) - return jsonable_encoder( { "install_permission": permission.install_permission, @@ -506,10 +453,8 @@ class PluginFetchDynamicSelectOptionsApi(Resource): # check if the user is admin or owner if not current_user.is_admin_or_owner: raise Forbidden() - tenant_id = current_user.current_tenant_id user_id = current_user.id - parser = reqparse.RequestParser() parser.add_argument("plugin_id", type=str, required=True, location="args") parser.add_argument("provider", type=str, required=True, location="args") @@ -517,7 +462,6 @@ class PluginFetchDynamicSelectOptionsApi(Resource): parser.add_argument("parameter", type=str, required=True, location="args") parser.add_argument("provider_type", type=str, required=True, location="args") args = parser.parse_args() - try: options = PluginParameterService.get_dynamic_select_options( tenant_id, @@ -530,7 +474,6 @@ class PluginFetchDynamicSelectOptionsApi(Resource): ) except PluginDaemonClientSideError as e: raise ValueError(e) - return jsonable_encoder({"options": options}) @@ -555,8 +498,6 @@ api.add_resource(PluginDeleteAllInstallTaskItemsApi, "/workspaces/current/plugin api.add_resource(PluginDeleteInstallTaskItemApi, "/workspaces/current/plugin/tasks//delete/") api.add_resource(PluginUninstallApi, 
"/workspaces/current/plugin/uninstall") api.add_resource(PluginFetchMarketplacePkgApi, "/workspaces/current/plugin/marketplace/pkg") - api.add_resource(PluginChangePermissionApi, "/workspaces/current/plugin/permission/change") api.add_resource(PluginFetchPermissionApi, "/workspaces/current/plugin/permission/fetch") - api.add_resource(PluginFetchDynamicSelectOptionsApi, "/workspaces/current/plugin/parameters/dynamic-options") diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index 2b1379bfb2..2bbe22e23e 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -26,10 +26,8 @@ class ToolProviderListApi(Resource): @account_initialization_required def get(self): user = current_user - user_id = user.id tenant_id = user.current_tenant_id - req = reqparse.RequestParser() req.add_argument( "type", @@ -40,7 +38,6 @@ class ToolProviderListApi(Resource): location="args", ) args = req.parse_args() - return ToolCommonService.list_tool_providers(user_id, tenant_id, args.get("type", None)) @@ -50,9 +47,7 @@ class ToolBuiltinProviderListToolsApi(Resource): @account_initialization_required def get(self, provider): user = current_user - tenant_id = user.current_tenant_id - return jsonable_encoder( BuiltinToolManageService.list_builtin_tool_provider_tools( tenant_id, @@ -67,10 +62,8 @@ class ToolBuiltinProviderInfoApi(Resource): @account_initialization_required def get(self, provider): user = current_user - user_id = user.id tenant_id = user.current_tenant_id - return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(user_id, tenant_id, provider)) @@ -80,13 +73,10 @@ class ToolBuiltinProviderDeleteApi(Resource): @account_initialization_required def post(self, provider): user = current_user - if not user.is_admin_or_owner: raise Forbidden() - user_id = user.id tenant_id = user.current_tenant_id - return 
BuiltinToolManageService.delete_builtin_tool_provider( user_id, tenant_id, @@ -100,18 +90,13 @@ class ToolBuiltinProviderUpdateApi(Resource): @account_initialization_required def post(self, provider): user = current_user - if not user.is_admin_or_owner: raise Forbidden() - user_id = user.id tenant_id = user.current_tenant_id - parser = reqparse.RequestParser() parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") - args = parser.parse_args() - with Session(db.engine) as session: result = BuiltinToolManageService.update_builtin_tool_provider( session=session, @@ -130,7 +115,6 @@ class ToolBuiltinProviderGetCredentialsApi(Resource): @account_initialization_required def get(self, provider): tenant_id = current_user.current_tenant_id - return BuiltinToolManageService.get_builtin_tool_provider_credentials( tenant_id=tenant_id, provider_name=provider, @@ -151,13 +135,10 @@ class ToolApiProviderAddApi(Resource): @account_initialization_required def post(self): user = current_user - if not user.is_admin_or_owner: raise Forbidden() - user_id = user.id tenant_id = user.current_tenant_id - parser = reqparse.RequestParser() parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") parser.add_argument("schema_type", type=str, required=True, nullable=False, location="json") @@ -167,9 +148,7 @@ class ToolApiProviderAddApi(Resource): parser.add_argument("privacy_policy", type=str, required=False, nullable=True, location="json") parser.add_argument("labels", type=list[str], required=False, nullable=True, location="json", default=[]) parser.add_argument("custom_disclaimer", type=str, required=False, nullable=True, location="json") - args = parser.parse_args() - return ApiToolManageService.create_api_tool_provider( user_id, tenant_id, @@ -190,16 +169,11 @@ class ToolApiProviderGetRemoteSchemaApi(Resource): @account_initialization_required def get(self): user = current_user - user_id = user.id tenant_id = 
user.current_tenant_id - parser = reqparse.RequestParser() - parser.add_argument("url", type=str, required=True, nullable=False, location="args") - args = parser.parse_args() - return ApiToolManageService.get_api_tool_provider_remote_schema( user_id, tenant_id, @@ -213,16 +187,11 @@ class ToolApiProviderListToolsApi(Resource): @account_initialization_required def get(self): user = current_user - user_id = user.id tenant_id = user.current_tenant_id - parser = reqparse.RequestParser() - parser.add_argument("provider", type=str, required=True, nullable=False, location="args") - args = parser.parse_args() - return jsonable_encoder( ApiToolManageService.list_api_tool_provider_tools( user_id, @@ -238,13 +207,10 @@ class ToolApiProviderUpdateApi(Resource): @account_initialization_required def post(self): user = current_user - if not user.is_admin_or_owner: raise Forbidden() - user_id = user.id tenant_id = user.current_tenant_id - parser = reqparse.RequestParser() parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") parser.add_argument("schema_type", type=str, required=True, nullable=False, location="json") @@ -255,9 +221,7 @@ class ToolApiProviderUpdateApi(Resource): parser.add_argument("privacy_policy", type=str, required=True, nullable=True, location="json") parser.add_argument("labels", type=list[str], required=False, nullable=True, location="json") parser.add_argument("custom_disclaimer", type=str, required=True, nullable=True, location="json") - args = parser.parse_args() - return ApiToolManageService.update_api_tool_provider( user_id, tenant_id, @@ -279,19 +243,13 @@ class ToolApiProviderDeleteApi(Resource): @account_initialization_required def post(self): user = current_user - if not user.is_admin_or_owner: raise Forbidden() - user_id = user.id tenant_id = user.current_tenant_id - parser = reqparse.RequestParser() - parser.add_argument("provider", type=str, required=True, nullable=False, location="json") - args = 
parser.parse_args() - return ApiToolManageService.delete_api_tool_provider( user_id, tenant_id, @@ -305,16 +263,11 @@ class ToolApiProviderGetApi(Resource): @account_initialization_required def get(self): user = current_user - user_id = user.id tenant_id = user.current_tenant_id - parser = reqparse.RequestParser() - parser.add_argument("provider", type=str, required=True, nullable=False, location="args") - args = parser.parse_args() - return ApiToolManageService.get_api_tool_provider( user_id, tenant_id, @@ -328,9 +281,7 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource): @account_initialization_required def get(self, provider): user = current_user - tenant_id = user.current_tenant_id - return BuiltinToolManageService.list_builtin_provider_credentials_schema(provider, tenant_id) @@ -340,11 +291,8 @@ class ToolApiProviderSchemaApi(Resource): @account_initialization_required def post(self): parser = reqparse.RequestParser() - parser.add_argument("schema", type=str, required=True, nullable=False, location="json") - args = parser.parse_args() - return ApiToolManageService.parser_api_schema( schema=args["schema"], ) @@ -356,16 +304,13 @@ class ToolApiProviderPreviousTestApi(Resource): @account_initialization_required def post(self): parser = reqparse.RequestParser() - parser.add_argument("tool_name", type=str, required=True, nullable=False, location="json") parser.add_argument("provider_name", type=str, required=False, nullable=False, location="json") parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") parser.add_argument("parameters", type=dict, required=True, nullable=False, location="json") parser.add_argument("schema_type", type=str, required=True, nullable=False, location="json") parser.add_argument("schema", type=str, required=True, nullable=False, location="json") - args = parser.parse_args() - return ApiToolManageService.test_api_tool_preview( current_user.current_tenant_id, args["provider_name"] or "", @@ 
-383,13 +328,10 @@ class ToolWorkflowProviderCreateApi(Resource): @account_initialization_required def post(self): user = current_user - if not user.is_admin_or_owner: raise Forbidden() - user_id = user.id tenant_id = user.current_tenant_id - reqparser = reqparse.RequestParser() reqparser.add_argument("workflow_app_id", type=uuid_value, required=True, nullable=False, location="json") reqparser.add_argument("name", type=alphanumeric, required=True, nullable=False, location="json") @@ -399,9 +341,7 @@ class ToolWorkflowProviderCreateApi(Resource): reqparser.add_argument("parameters", type=list[dict], required=True, nullable=False, location="json") reqparser.add_argument("privacy_policy", type=str, required=False, nullable=True, location="json", default="") reqparser.add_argument("labels", type=list[str], required=False, nullable=True, location="json") - args = reqparser.parse_args() - return WorkflowToolManageService.create_workflow_tool( user_id=user_id, tenant_id=tenant_id, @@ -422,13 +362,10 @@ class ToolWorkflowProviderUpdateApi(Resource): @account_initialization_required def post(self): user = current_user - if not user.is_admin_or_owner: raise Forbidden() - user_id = user.id tenant_id = user.current_tenant_id - reqparser = reqparse.RequestParser() reqparser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json") reqparser.add_argument("name", type=alphanumeric, required=True, nullable=False, location="json") @@ -438,12 +375,9 @@ class ToolWorkflowProviderUpdateApi(Resource): reqparser.add_argument("parameters", type=list[dict], required=True, nullable=False, location="json") reqparser.add_argument("privacy_policy", type=str, required=False, nullable=True, location="json", default="") reqparser.add_argument("labels", type=list[str], required=False, nullable=True, location="json") - args = reqparser.parse_args() - if not args["workflow_tool_id"]: raise ValueError("incorrect workflow_tool_id") - return 
WorkflowToolManageService.update_workflow_tool( user_id, tenant_id, @@ -464,18 +398,13 @@ class ToolWorkflowProviderDeleteApi(Resource): @account_initialization_required def post(self): user = current_user - if not user.is_admin_or_owner: raise Forbidden() - user_id = user.id tenant_id = user.current_tenant_id - reqparser = reqparse.RequestParser() reqparser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json") - args = reqparser.parse_args() - return WorkflowToolManageService.delete_workflow_tool( user_id, tenant_id, @@ -489,16 +418,12 @@ class ToolWorkflowProviderGetApi(Resource): @account_initialization_required def get(self): user = current_user - user_id = user.id tenant_id = user.current_tenant_id - parser = reqparse.RequestParser() parser.add_argument("workflow_tool_id", type=uuid_value, required=False, nullable=True, location="args") parser.add_argument("workflow_app_id", type=uuid_value, required=False, nullable=True, location="args") - args = parser.parse_args() - if args.get("workflow_tool_id"): tool = WorkflowToolManageService.get_workflow_tool_by_tool_id( user_id, @@ -513,7 +438,6 @@ class ToolWorkflowProviderGetApi(Resource): ) else: raise ValueError("incorrect workflow_tool_id or workflow_app_id") - return jsonable_encoder(tool) @@ -523,15 +447,11 @@ class ToolWorkflowProviderListToolApi(Resource): @account_initialization_required def get(self): user = current_user - user_id = user.id tenant_id = user.current_tenant_id - parser = reqparse.RequestParser() parser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="args") - args = parser.parse_args() - return jsonable_encoder( WorkflowToolManageService.list_single_workflow_tools( user_id, @@ -547,10 +467,8 @@ class ToolBuiltinListApi(Resource): @account_initialization_required def get(self): user = current_user - user_id = user.id tenant_id = user.current_tenant_id - return jsonable_encoder( [ provider.to_dict() @@ 
-568,10 +486,8 @@ class ToolApiListApi(Resource): @account_initialization_required def get(self): user = current_user - user_id = user.id tenant_id = user.current_tenant_id - return jsonable_encoder( [ provider.to_dict() @@ -589,10 +505,8 @@ class ToolWorkflowListApi(Resource): @account_initialization_required def get(self): user = current_user - user_id = user.id tenant_id = user.current_tenant_id - return jsonable_encoder( [ provider.to_dict() @@ -615,7 +529,6 @@ class ToolLabelsApi(Resource): # tool provider api.add_resource(ToolProviderListApi, "/workspaces/current/tool-providers") - # builtin tool provider api.add_resource(ToolBuiltinProviderListToolsApi, "/workspaces/current/tool-provider/builtin//tools") api.add_resource(ToolBuiltinProviderInfoApi, "/workspaces/current/tool-provider/builtin//info") @@ -629,7 +542,6 @@ api.add_resource( "/workspaces/current/tool-provider/builtin//credentials_schema", ) api.add_resource(ToolBuiltinProviderIconApi, "/workspaces/current/tool-provider/builtin//icon") - # api tool provider api.add_resource(ToolApiProviderAddApi, "/workspaces/current/tool-provider/api/add") api.add_resource(ToolApiProviderGetRemoteSchemaApi, "/workspaces/current/tool-provider/api/remote") @@ -639,16 +551,13 @@ api.add_resource(ToolApiProviderDeleteApi, "/workspaces/current/tool-provider/ap api.add_resource(ToolApiProviderGetApi, "/workspaces/current/tool-provider/api/get") api.add_resource(ToolApiProviderSchemaApi, "/workspaces/current/tool-provider/api/schema") api.add_resource(ToolApiProviderPreviousTestApi, "/workspaces/current/tool-provider/api/test/pre") - # workflow tool provider api.add_resource(ToolWorkflowProviderCreateApi, "/workspaces/current/tool-provider/workflow/create") api.add_resource(ToolWorkflowProviderUpdateApi, "/workspaces/current/tool-provider/workflow/update") api.add_resource(ToolWorkflowProviderDeleteApi, "/workspaces/current/tool-provider/workflow/delete") api.add_resource(ToolWorkflowProviderGetApi, 
"/workspaces/current/tool-provider/workflow/get") api.add_resource(ToolWorkflowProviderListToolApi, "/workspaces/current/tool-provider/workflow/tools") - api.add_resource(ToolBuiltinListApi, "/workspaces/current/tools/builtin") api.add_resource(ToolApiListApi, "/workspaces/current/tools/api") api.add_resource(ToolWorkflowListApi, "/workspaces/current/tools/workflow") - api.add_resource(ToolLabelsApi, "/workspaces/current/tool-labels") diff --git a/api/controllers/console/workspace/workspace.py b/api/controllers/console/workspace/workspace.py index 19999e7361..7a352e4e6c 100644 --- a/api/controllers/console/workspace/workspace.py +++ b/api/controllers/console/workspace/workspace.py @@ -37,7 +37,6 @@ provider_fields = { "is_valid": fields.Boolean, "token_is_set": fields.Boolean, } - tenant_fields = { "id": fields.String, "name": fields.String, @@ -49,7 +48,6 @@ tenant_fields = { "trial_end_reason": fields.String, "custom_config": fields.Raw(attribute="custom_config"), } - tenants_fields = { "id": fields.String, "name": fields.String, @@ -58,7 +56,6 @@ tenants_fields = { "created_at": TimestampField, "current": fields.Boolean, } - workspace_fields = {"id": fields.String, "name": fields.String, "status": fields.String, "created_at": TimestampField} @@ -69,10 +66,8 @@ class TenantListApi(Resource): def get(self): tenants = TenantService.get_join_tenants(current_user) tenant_dicts = [] - for tenant in tenants: features = FeatureService.get_features(tenant.id) - # Create a dictionary with tenant attributes tenant_dict = { "id": tenant.id, @@ -82,9 +77,7 @@ class TenantListApi(Resource): "plan": features.billing.subscription.plan if features.billing.enabled else "sandbox", "current": tenant.id == current_user.current_tenant_id, } - tenant_dicts.append(tenant_dict) - return {"workspaces": marshal(tenant_dicts, tenants_fields)}, 200 @@ -96,14 +89,11 @@ class WorkspaceListApi(Resource): parser.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, 
location="args") parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args") args = parser.parse_args() - stmt = select(Tenant).order_by(Tenant.created_at.desc()) tenants = db.paginate(select=stmt, page=args["page"], per_page=args["limit"], error_out=False) has_more = False - if tenants.has_next: has_more = True - return { "data": marshal(tenants.items, workspace_fields), "has_more": has_more, @@ -121,9 +111,7 @@ class TenantApi(Resource): def get(self): if request.path == "/info": logging.warning("Deprecated URL /info was used.") - tenant = current_user.current_tenant - if tenant.status == TenantStatus.ARCHIVE: tenants = TenantService.get_join_tenants(current_user) # if there is any tenant, switch to the first one @@ -133,7 +121,6 @@ class TenantApi(Resource): # else, raise Unauthorized else: raise Unauthorized("workspace is archived") - return WorkspaceService.get_tenant_info(tenant), 200 @@ -145,17 +132,14 @@ class SwitchWorkspaceApi(Resource): parser = reqparse.RequestParser() parser.add_argument("tenant_id", type=str, required=True, location="json") args = parser.parse_args() - # check if tenant_id is valid, 403 if not try: TenantService.switch_tenant(current_user, args["tenant_id"]) except Exception: raise AccountNotLinkTenantError("Account not link tenant") - new_tenant = db.session.query(Tenant).get(args["tenant_id"]) # Get new tenant if new_tenant is None: raise ValueError("Tenant not found") - return {"result": "success", "new_tenant": marshal(WorkspaceService.get_tenant_info(new_tenant), tenant_fields)} @@ -169,19 +153,15 @@ class CustomConfigWorkspaceApi(Resource): parser.add_argument("remove_webapp_brand", type=bool, location="json") parser.add_argument("replace_webapp_logo", type=str, location="json") args = parser.parse_args() - tenant = db.get_or_404(Tenant, current_user.current_tenant_id) - custom_config_dict = { "remove_webapp_brand": args["remove_webapp_brand"], "replace_webapp_logo": 
args["replace_webapp_logo"] if args["replace_webapp_logo"] is not None else tenant.custom_config_dict.get("replace_webapp_logo"), } - tenant.custom_config_dict = custom_config_dict db.session.commit() - return {"result": "success", "tenant": marshal(WorkspaceService.get_tenant_info(tenant), tenant_fields)} @@ -193,21 +173,16 @@ class WebappLogoWorkspaceApi(Resource): def post(self): # get file from request file = request.files["file"] - # check file if "file" not in request.files: raise NoFileUploadedError() - if len(request.files) > 1: raise TooManyFilesError() - if not file.filename: raise FilenameNotExistsError - extension = file.filename.split(".")[-1] if extension.lower() not in {"svg", "png"}: raise UnsupportedFileTypeError() - try: upload_file = FileService.upload_file( filename=file.filename, @@ -215,12 +190,10 @@ class WebappLogoWorkspaceApi(Resource): mimetype=file.mimetype, user=current_user, ) - except services.errors.file.FileTooLargeError as file_too_large_error: raise FileTooLargeError(file_too_large_error.description) except services.errors.file.UnsupportedFileTypeError: raise UnsupportedFileTypeError() - return {"id": upload_file.id}, 201 @@ -233,11 +206,9 @@ class WorkspaceInfoApi(Resource): parser = reqparse.RequestParser() parser.add_argument("name", type=str, required=True, location="json") args = parser.parse_args() - tenant = db.get_or_404(Tenant, current_user.current_tenant_id) tenant.name = args["name"] db.session.commit() - return {"result": "success", "tenant": marshal(WorkspaceService.get_tenant_info(tenant), tenant_fields)} diff --git a/api/controllers/console/wraps.py b/api/controllers/console/wraps.py index ca122772de..4419216f62 100644 --- a/api/controllers/console/wraps.py +++ b/api/controllers/console/wraps.py @@ -24,10 +24,8 @@ def account_initialization_required(view): def decorated(*args, **kwargs): # check account initialization account = current_user - if account.status == AccountStatus.UNINITIALIZED: raise 
AccountNotInitializedError() - return view(*args, **kwargs) return decorated @@ -38,7 +36,6 @@ def only_edition_cloud(view): def decorated(*args, **kwargs): if dify_config.EDITION != "CLOUD": abort(404) - return view(*args, **kwargs) return decorated @@ -49,7 +46,6 @@ def only_edition_enterprise(view): def decorated(*args, **kwargs): if not dify_config.ENTERPRISE_ENABLED: abort(404) - return view(*args, **kwargs) return decorated @@ -60,7 +56,6 @@ def only_edition_self_hosted(view): def decorated(*args, **kwargs): if dify_config.EDITION != "SELF_HOSTED": abort(404) - return view(*args, **kwargs) return decorated @@ -110,7 +105,6 @@ def cloud_edition_billing_resource_check(resource: str): abort(403, "The annotation quota has reached the limit of your subscription.") else: return view(*args, **kwargs) - return view(*args, **kwargs) return decorated @@ -132,7 +126,6 @@ def cloud_edition_billing_knowledge_limit_check(resource: str): ) else: return view(*args, **kwargs) - return view(*args, **kwargs) return decorated @@ -149,13 +142,9 @@ def cloud_edition_billing_rate_limit_check(resource: str): if knowledge_rate_limit.enabled: current_time = int(time.time() * 1000) key = f"rate_limit_{current_user.current_tenant_id}" - redis_client.zadd(key, {current_time: current_time}) - redis_client.zremrangebyscore(key, 0, current_time - 60000) - request_count = redis_client.zcard(key) - if request_count > knowledge_rate_limit.limit: # add ratelimit record rate_limit_log = RateLimitLog( @@ -180,10 +169,8 @@ def cloud_utm_record(view): def decorated(*args, **kwargs): try: features = FeatureService.get_features(current_user.current_tenant_id) - if features.billing.enabled: utm_info = request.cookies.get("utm_info") - if utm_info: utm_info_dict: dict = json.loads(utm_info) OperationService.record_utm(current_user.current_tenant_id, utm_info_dict) @@ -206,7 +193,6 @@ def setup_required(view): raise NotInitValidateError() elif dify_config.EDITION == "SELF_HOSTED" and not 
db.session.query(DifySetup).first(): raise NotSetupError() - return view(*args, **kwargs) return decorated @@ -218,7 +204,6 @@ def enterprise_license_required(view): settings = FeatureService.get_system_features() if settings.license.status in [LicenseStatus.INACTIVE, LicenseStatus.EXPIRED, LicenseStatus.LOST]: raise UnauthorizedAndForceLogout("Your license is invalid. Please contact your administrator.") - return view(*args, **kwargs) return decorated @@ -230,7 +215,6 @@ def email_password_login_enabled(view): features = FeatureService.get_system_features() if features.enable_email_password_login: return view(*args, **kwargs) - # otherwise, return 403 abort(403) diff --git a/api/controllers/files/__init__.py b/api/controllers/files/__init__.py index d4c3245708..ca53090b25 100644 --- a/api/controllers/files/__init__.py +++ b/api/controllers/files/__init__.py @@ -4,6 +4,4 @@ from libs.external_api import ExternalApi bp = Blueprint("files", __name__) api = ExternalApi(bp) - - from . import image_preview, tool_files, upload diff --git a/api/controllers/files/image_preview.py b/api/controllers/files/image_preview.py index 46c19e1fbb..b29284469f 100644 --- a/api/controllers/files/image_preview.py +++ b/api/controllers/files/image_preview.py @@ -18,14 +18,11 @@ class ImagePreviewApi(Resource): def get(self, file_id): file_id = str(file_id) - timestamp = request.args.get("timestamp") nonce = request.args.get("nonce") sign = request.args.get("sign") - if not timestamp or not nonce or not sign: return {"content": "Invalid request."}, 400 - try: generator, mimetype = FileService.get_image_preview( file_id=file_id, @@ -35,25 +32,20 @@ class ImagePreviewApi(Resource): ) except services.errors.file.UnsupportedFileTypeError: raise UnsupportedFileTypeError() - return Response(generator, mimetype=mimetype) class FilePreviewApi(Resource): def get(self, file_id): file_id = str(file_id) - parser = reqparse.RequestParser() parser.add_argument("timestamp", type=str, required=True, 
location="args") parser.add_argument("nonce", type=str, required=True, location="args") parser.add_argument("sign", type=str, required=True, location="args") parser.add_argument("as_attachment", type=bool, required=False, default=False, location="args") - args = parser.parse_args() - if not args["timestamp"] or not args["nonce"] or not args["sign"]: return {"content": "Invalid request."}, 400 - try: generator, upload_file = FileService.get_file_generator_by_file_id( file_id=file_id, @@ -63,7 +55,6 @@ class FilePreviewApi(Resource): ) except services.errors.file.UnsupportedFileTypeError: raise UnsupportedFileTypeError() - response = Response( generator, mimetype=upload_file.mime_type, @@ -90,27 +81,22 @@ class FilePreviewApi(Resource): encoded_filename = quote(upload_file.name) response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}" response.headers["Content-Type"] = "application/octet-stream" - return response class WorkspaceWebappLogoApi(Resource): def get(self, workspace_id): workspace_id = str(workspace_id) - custom_config = TenantService.get_custom_config(workspace_id) webapp_logo_file_id = custom_config.get("replace_webapp_logo") if custom_config is not None else None - if not webapp_logo_file_id: raise NotFound("webapp logo is not found") - try: generator, mimetype = FileService.get_public_image_preview( webapp_logo_file_id, ) except services.errors.file.UnsupportedFileTypeError: raise UnsupportedFileTypeError() - return Response(generator, mimetype=mimetype) diff --git a/api/controllers/files/tool_files.py b/api/controllers/files/tool_files.py index 1c3430ef4f..05ab8523c4 100644 --- a/api/controllers/files/tool_files.py +++ b/api/controllers/files/tool_files.py @@ -14,31 +14,25 @@ from models import db as global_db class ToolFilePreviewApi(Resource): def get(self, file_id, extension): file_id = str(file_id) - parser = reqparse.RequestParser() - parser.add_argument("timestamp", type=str, required=True, location="args") 
parser.add_argument("nonce", type=str, required=True, location="args") parser.add_argument("sign", type=str, required=True, location="args") parser.add_argument("as_attachment", type=bool, required=False, default=False, location="args") - args = parser.parse_args() if not verify_tool_file_signature( file_id=file_id, timestamp=args["timestamp"], nonce=args["nonce"], sign=args["sign"] ): raise Forbidden("Invalid request.") - try: tool_file_manager = ToolFileManager(engine=global_db.engine) stream, tool_file = tool_file_manager.get_file_generator_by_tool_file_id( file_id, ) - if not stream or not tool_file: raise NotFound("file is not found") except Exception: raise UnsupportedFileTypeError() - response = Response( stream, mimetype=tool_file.mimetype, @@ -50,7 +44,6 @@ class ToolFilePreviewApi(Resource): if args["as_attachment"]: encoded_filename = quote(tool_file.name) response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}" - return response diff --git a/api/controllers/files/upload.py b/api/controllers/files/upload.py index f1a15793c7..c26cea398c 100644 --- a/api/controllers/files/upload.py +++ b/api/controllers/files/upload.py @@ -21,26 +21,20 @@ class PluginUploadFileApi(Resource): def post(self): # get file from request file = request.files["file"] - timestamp = request.args.get("timestamp") nonce = request.args.get("nonce") sign = request.args.get("sign") tenant_id = request.args.get("tenant_id") if not tenant_id: raise Forbidden("Invalid request.") - user_id = request.args.get("user_id") user = get_user(tenant_id, user_id) - filename = file.filename mimetype = file.mimetype - if not filename or not mimetype: raise Forbidden("Invalid request.") - if not timestamp or not nonce or not sign: raise Forbidden("Invalid request.") - if not verify_plugin_file_signature( filename=filename, mimetype=mimetype, @@ -51,7 +45,6 @@ class PluginUploadFileApi(Resource): sign=sign, ): raise Forbidden("Invalid request.") - try: tool_file = 
ToolFileManager().create_file_by_raw( user_id=user.id, @@ -61,10 +54,8 @@ class PluginUploadFileApi(Resource): filename=filename, conversation_id=None, ) - extension = guess_extension(tool_file.mimetype) or ".bin" preview_url = ToolFileManager.sign_file(tool_file_id=tool_file.id, extension=extension) - # Create a dictionary with all the necessary attributes result = { "id": tool_file.id, @@ -80,13 +71,11 @@ class PluginUploadFileApi(Resource): "extension": extension, "preview_url": preview_url, } - return result, 201 except services.errors.file.FileTooLargeError as file_too_large_error: raise FileTooLargeError(file_too_large_error.description) except services.errors.file.UnsupportedFileTypeError: raise UnsupportedFileTypeError() - return tool_file, 201 diff --git a/api/controllers/inner_api/__init__.py b/api/controllers/inner_api/__init__.py index d51db4322a..62de76e30d 100644 --- a/api/controllers/inner_api/__init__.py +++ b/api/controllers/inner_api/__init__.py @@ -4,7 +4,6 @@ from libs.external_api import ExternalApi bp = Blueprint("inner_api", __name__, url_prefix="/inner/api") api = ExternalApi(bp) - from . 
import mail from .plugin import plugin from .workspace import workspace diff --git a/api/controllers/inner_api/mail.py b/api/controllers/inner_api/mail.py index ce3373d65c..0445807f21 100644 --- a/api/controllers/inner_api/mail.py +++ b/api/controllers/inner_api/mail.py @@ -19,7 +19,6 @@ class EnterpriseMail(Resource): parser.add_argument("body", type=str, required=True) parser.add_argument("substitutions", type=dict, required=False) args = parser.parse_args() - EnterpriseMailService.send_mail(DifyMail(**args)) return {"message": "success"}, 200 diff --git a/api/controllers/inner_api/plugin/plugin.py b/api/controllers/inner_api/plugin/plugin.py index 327e9ce834..88b76c2c58 100644 --- a/api/controllers/inner_api/plugin/plugin.py +++ b/api/controllers/inner_api/plugin/plugin.py @@ -243,7 +243,6 @@ class PluginInvokeAppApi(Resource): inputs=payload.inputs, files=payload.files, ) - return length_prefixed_response(0xF, PluginAppBackwardsInvocation.convert_to_event_stream(response)) diff --git a/api/controllers/inner_api/plugin/wraps.py b/api/controllers/inner_api/plugin/wraps.py index 50408e0929..dd36eeea7b 100644 --- a/api/controllers/inner_api/plugin/wraps.py +++ b/api/controllers/inner_api/plugin/wraps.py @@ -20,7 +20,6 @@ def get_user(tenant_id: str, user_id: str | None) -> Account | EndUser: with Session(db.engine) as session: if not user_id: user_id = "DEFAULT-USER" - if user_id == "DEFAULT-USER": user_model = session.query(EndUser).filter(EndUser.session_id == "DEFAULT-USER").first() if not user_model: @@ -41,7 +40,6 @@ def get_user(tenant_id: str, user_id: str | None) -> Account | EndUser: raise ValueError("user not found") except Exception: raise ValueError("user not found") - return user_model @@ -53,21 +51,15 @@ def get_user_tenant(view: Optional[Callable] = None): parser = reqparse.RequestParser() parser.add_argument("tenant_id", type=str, required=True, location="json") parser.add_argument("user_id", type=str, required=True, location="json") - kwargs = 
parser.parse_args() - user_id = kwargs.get("user_id") tenant_id = kwargs.get("tenant_id") - if not tenant_id: raise ValueError("tenant_id is required") - if not user_id: user_id = "DEFAULT-USER" - del kwargs["tenant_id"] del kwargs["user_id"] - try: tenant_model = ( db.session.query(Tenant) @@ -78,18 +70,13 @@ def get_user_tenant(view: Optional[Callable] = None): ) except Exception: raise ValueError("tenant not found") - if not tenant_model: raise ValueError("tenant not found") - kwargs["tenant_model"] = tenant_model - user = get_user(tenant_id, user_id) kwargs["user_model"] = user - current_app.login_manager._update_request_context_with_user(user) # type: ignore user_logged_in.send(current_app._get_current_object(), user=_get_user()) # type: ignore - return view_func(*args, **kwargs) return decorated_view @@ -107,12 +94,10 @@ def plugin_data(view: Optional[Callable] = None, *, payload_type: type[BaseModel data = request.get_json() except Exception: raise ValueError("invalid json") - try: payload = payload_type(**data) except Exception as e: raise ValueError(f"invalid payload: {str(e)}") - kwargs["payload"] = payload return view_func(*args, **kwargs) diff --git a/api/controllers/inner_api/workspace/workspace.py b/api/controllers/inner_api/workspace/workspace.py index 77568b75f1..efdec6265c 100644 --- a/api/controllers/inner_api/workspace/workspace.py +++ b/api/controllers/inner_api/workspace/workspace.py @@ -19,16 +19,12 @@ class EnterpriseWorkspace(Resource): parser.add_argument("name", type=str, required=True, location="json") parser.add_argument("owner_email", type=str, required=True, location="json") args = parser.parse_args() - account = db.session.query(Account).filter_by(email=args["owner_email"]).first() if account is None: return {"message": "owner account not found."}, 404 - tenant = TenantService.create_tenant(args["name"], is_from_dashboard=True) TenantService.create_tenant_member(tenant, account, role="owner") - tenant_was_created.send(tenant) - resp = 
{ "id": tenant.id, "name": tenant.name, @@ -37,7 +33,6 @@ class EnterpriseWorkspace(Resource): "created_at": tenant.created_at.isoformat() + "Z" if tenant.created_at else None, "updated_at": tenant.updated_at.isoformat() + "Z" if tenant.updated_at else None, } - return { "message": "enterprise workspace created.", "tenant": resp, @@ -51,11 +46,8 @@ class EnterpriseWorkspaceNoOwnerEmail(Resource): parser = reqparse.RequestParser() parser.add_argument("name", type=str, required=True, location="json") args = parser.parse_args() - tenant = TenantService.create_tenant(args["name"], is_from_dashboard=True) - tenant_was_created.send(tenant) - resp = { "id": tenant.id, "name": tenant.name, @@ -66,7 +58,6 @@ class EnterpriseWorkspaceNoOwnerEmail(Resource): "created_at": tenant.created_at.isoformat() + "Z" if tenant.created_at else None, "updated_at": tenant.updated_at.isoformat() + "Z" if tenant.updated_at else None, } - return { "message": "enterprise workspace created.", "tenant": resp, diff --git a/api/controllers/inner_api/wraps.py b/api/controllers/inner_api/wraps.py index f3a9312dd0..1003ed198d 100644 --- a/api/controllers/inner_api/wraps.py +++ b/api/controllers/inner_api/wraps.py @@ -15,12 +15,10 @@ def enterprise_inner_api_only(view): def decorated(*args, **kwargs): if not dify_config.INNER_API: abort(404) - # get header 'X-Inner-Api-Key' inner_api_key = request.headers.get("X-Inner-Api-Key") if not inner_api_key or inner_api_key != dify_config.INNER_API_KEY: abort(401) - return view(*args, **kwargs) return decorated @@ -31,32 +29,23 @@ def enterprise_inner_api_user_auth(view): def decorated(*args, **kwargs): if not dify_config.INNER_API: return view(*args, **kwargs) - # get header 'X-Inner-Api-Key' authorization = request.headers.get("Authorization") if not authorization: return view(*args, **kwargs) - parts = authorization.split(":") if len(parts) != 2: return view(*args, **kwargs) - user_id, token = parts if " " in user_id: user_id = user_id.split(" ")[1] - 
inner_api_key = request.headers.get("X-Inner-Api-Key", "") - data_to_sign = f"DIFY {user_id}" - signature = hmac_new(inner_api_key.encode("utf-8"), data_to_sign.encode("utf-8"), sha1) signature_base64 = b64encode(signature.digest()).decode("utf-8") - if signature_base64 != token: return view(*args, **kwargs) - kwargs["user"] = db.session.query(EndUser).filter(EndUser.id == user_id).first() - return view(*args, **kwargs) return decorated @@ -67,12 +56,10 @@ def plugin_inner_api_only(view): def decorated(*args, **kwargs): if not dify_config.PLUGIN_DAEMON_KEY: abort(404) - # get header 'X-Inner-Api-Key' inner_api_key = request.headers.get("X-Inner-Api-Key") if not inner_api_key or inner_api_key != dify_config.INNER_API_KEY_FOR_PLUGIN: abort(404) - return view(*args, **kwargs) return decorated diff --git a/api/controllers/service_api/__init__.py b/api/controllers/service_api/__init__.py index d964e27819..540b04cd23 100644 --- a/api/controllers/service_api/__init__.py +++ b/api/controllers/service_api/__init__.py @@ -4,7 +4,6 @@ from libs.external_api import ExternalApi bp = Blueprint("service_api", __name__, url_prefix="/v1") api = ExternalApi(bp) - from . 
import index from .app import annotation, app, audio, completion, conversation, file, message, site, workflow from .dataset import dataset, document, hit_testing, metadata, segment, upload_file diff --git a/api/controllers/service_api/app/annotation.py b/api/controllers/service_api/app/annotation.py index 595ae118ef..d167149fe2 100644 --- a/api/controllers/service_api/app/annotation.py +++ b/api/controllers/service_api/app/annotation.py @@ -38,13 +38,11 @@ class AnnotationReplyActionStatusApi(Resource): cache_result = redis_client.get(app_annotation_job_key) if cache_result is None: raise ValueError("The job does not exist.") - job_status = cache_result.decode() error_msg = "" if job_status == "error": app_annotation_error_key = "{}_app_annotation_error_{}".format(action, str(job_id)) error_msg = redis_client.get(app_annotation_error_key).decode() - return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200 @@ -54,7 +52,6 @@ class AnnotationListApi(Resource): page = request.args.get("page", default=1, type=int) limit = request.args.get("limit", default=20, type=int) keyword = request.args.get("keyword", default="", type=str) - annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_model.id, page, limit, keyword) response = { "data": marshal(annotation_list, annotation_fields), @@ -82,7 +79,6 @@ class AnnotationUpdateDeleteApi(Resource): def put(self, app_model: App, annotation_id): if not current_user.is_editor: raise Forbidden() - annotation_id = str(annotation_id) parser = reqparse.RequestParser() parser.add_argument("question", required=True, type=str, location="json") @@ -95,7 +91,6 @@ class AnnotationUpdateDeleteApi(Resource): def delete(self, app_model: App, annotation_id): if not current_user.is_editor: raise Forbidden() - annotation_id = str(annotation_id) AppAnnotationService.delete_app_annotation(app_model.id, annotation_id) return {"result": "success"}, 204 diff --git a/api/controllers/service_api/app/app.py 
b/api/controllers/service_api/app/app.py index 89222d5e83..524209df3c 100644 --- a/api/controllers/service_api/app/app.py +++ b/api/controllers/service_api/app/app.py @@ -20,18 +20,14 @@ class AppParameterApi(Resource): workflow = app_model.workflow if workflow is None: raise AppUnavailableError() - features_dict = workflow.features_dict user_input_form = workflow.user_input_form(to_old_structure=True) else: app_model_config = app_model.app_model_config if app_model_config is None: raise AppUnavailableError() - features_dict = app_model_config.to_dict() - user_input_form = features_dict.get("user_input_form", []) - return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form) diff --git a/api/controllers/service_api/app/audio.py b/api/controllers/service_api/app/audio.py index 848863cf1b..4904d85ed3 100644 --- a/api/controllers/service_api/app/audio.py +++ b/api/controllers/service_api/app/audio.py @@ -34,10 +34,8 @@ class AudioApi(Resource): @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM)) def post(self, app_model: App, end_user: EndUser): file = request.files["file"] - try: response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=end_user) - return response except services.errors.app_model_config.AppModelConfigBrokenError: logging.exception("App model config broken.") @@ -75,14 +73,12 @@ class TextApi(Resource): parser.add_argument("text", type=str, location="json") parser.add_argument("streaming", type=bool, location="json") args = parser.parse_args() - message_id = args.get("message_id", None) text = args.get("text", None) voice = args.get("voice", None) response = AudioService.transcript_tts( app_model=app_model, text=text, voice=voice, end_user=end_user.external_user_id, message_id=message_id ) - return response except services.errors.app_model_config.AppModelConfigBrokenError: logging.exception("App model config broken.") diff --git 
a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py index 1d9890199d..f922547eb4 100644 --- a/api/controllers/service_api/app/completion.py +++ b/api/controllers/service_api/app/completion.py @@ -36,20 +36,15 @@ class CompletionApi(Resource): def post(self, app_model: App, end_user: EndUser): if app_model.mode != "completion": raise AppUnavailableError() - parser = reqparse.RequestParser() parser.add_argument("inputs", type=dict, required=True, location="json") parser.add_argument("query", type=str, location="json", default="") parser.add_argument("files", type=list, required=False, location="json") parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json") - args = parser.parse_args() - streaming = args["response_mode"] == "streaming" - args["auto_generate_name"] = False - try: response = AppGenerateService.generate( app_model=app_model, @@ -58,7 +53,6 @@ class CompletionApi(Resource): invoke_from=InvokeFrom.SERVICE_API, streaming=streaming, ) - return helper.compact_generate_response(response) except services.errors.conversation.ConversationNotExistsError: raise NotFound("Conversation Not Exists.") @@ -87,9 +81,7 @@ class CompletionStopApi(Resource): def post(self, app_model: App, end_user: EndUser, task_id): if app_model.mode != "completion": raise AppUnavailableError() - AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id) - return {"result": "success"}, 200 @@ -99,7 +91,6 @@ class ChatApi(Resource): app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() - parser = reqparse.RequestParser() parser.add_argument("inputs", type=dict, required=True, location="json") parser.add_argument("query", type=str, required=True, location="json") @@ -108,16 +99,12 @@ class 
ChatApi(Resource): parser.add_argument("conversation_id", type=uuid_value, location="json") parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json") parser.add_argument("auto_generate_name", type=bool, required=False, default=True, location="json") - args = parser.parse_args() - streaming = args["response_mode"] == "streaming" - try: response = AppGenerateService.generate( app_model=app_model, user=end_user, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=streaming ) - return helper.compact_generate_response(response) except services.errors.conversation.ConversationNotExistsError: raise NotFound("Conversation Not Exists.") @@ -149,9 +136,7 @@ class ChatStopApi(Resource): app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() - AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id) - return {"result": "success"}, 200 diff --git a/api/controllers/service_api/app/conversation.py b/api/controllers/service_api/app/conversation.py index 36a7905572..df7ae7e5df 100644 --- a/api/controllers/service_api/app/conversation.py +++ b/api/controllers/service_api/app/conversation.py @@ -29,7 +29,6 @@ class ConversationApi(Resource): app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() - parser = reqparse.RequestParser() parser.add_argument("last_id", type=uuid_value, location="args") parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") @@ -42,7 +41,6 @@ class ConversationApi(Resource): location="args", ) args = parser.parse_args() - try: with Session(db.engine) as session: return ConversationService.pagination_by_last_id( @@ -65,9 +63,7 @@ class ConversationDetailApi(Resource): app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, 
AppMode.ADVANCED_CHAT}: raise NotChatAppError() - conversation_id = str(c_id) - try: ConversationService.delete(app_model, conversation_id, end_user) except services.errors.conversation.ConversationNotExistsError: @@ -82,14 +78,11 @@ class ConversationRenameApi(Resource): app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() - conversation_id = str(c_id) - parser = reqparse.RequestParser() parser.add_argument("name", type=str, required=False, location="json") parser.add_argument("auto_generate", type=bool, required=False, default=False, location="json") args = parser.parse_args() - try: return ConversationService.rename(app_model, conversation_id, end_user, args["name"], args["auto_generate"]) except services.errors.conversation.ConversationNotExistsError: @@ -104,14 +97,11 @@ class ConversationVariablesApi(Resource): app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() - conversation_id = str(c_id) - parser = reqparse.RequestParser() parser.add_argument("last_id", type=uuid_value, location="args") parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") args = parser.parse_args() - try: return ConversationService.get_conversational_variable( app_model, conversation_id, end_user, args["limit"], args["last_id"] diff --git a/api/controllers/service_api/app/file.py b/api/controllers/service_api/app/file.py index b0fd8e65ef..a8803fd030 100644 --- a/api/controllers/service_api/app/file.py +++ b/api/controllers/service_api/app/file.py @@ -21,20 +21,15 @@ class FileApi(Resource): @marshal_with(file_fields) def post(self, app_model: App, end_user: EndUser): file = request.files["file"] - # check file if "file" not in request.files: raise NoFileUploadedError() - if not file.mimetype: raise UnsupportedFileTypeError() - if len(request.files) > 1: raise 
TooManyFilesError() - if not file.filename: raise FilenameNotExistsError - try: upload_file = FileService.upload_file( filename=file.filename, @@ -46,7 +41,6 @@ class FileApi(Resource): raise FileTooLargeError(file_too_large_error.description) except services.errors.file.UnsupportedFileTypeError: raise UnsupportedFileTypeError() - return upload_file, 201 diff --git a/api/controllers/service_api/app/message.py b/api/controllers/service_api/app/message.py index d90fa2081f..736bbacec7 100644 --- a/api/controllers/service_api/app/message.py +++ b/api/controllers/service_api/app/message.py @@ -39,7 +39,6 @@ class MessageListApi(Resource): "status": fields.String, "error": fields.String, } - message_infinite_scroll_pagination_fields = { "limit": fields.Integer, "has_more": fields.Boolean, @@ -52,13 +51,11 @@ class MessageListApi(Resource): app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() - parser = reqparse.RequestParser() parser.add_argument("conversation_id", required=True, type=uuid_value, location="args") parser.add_argument("first_id", type=uuid_value, location="args") parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") args = parser.parse_args() - try: return MessageService.pagination_by_first_id( app_model, end_user, args["conversation_id"], args["first_id"], args["limit"] @@ -73,12 +70,10 @@ class MessageFeedbackApi(Resource): @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) def post(self, app_model: App, end_user: EndUser, message_id): message_id = str(message_id) - parser = reqparse.RequestParser() parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json") parser.add_argument("content", type=str, location="json") args = parser.parse_args() - try: MessageService.create_feedback( app_model=app_model, @@ -89,7 +84,6 @@ class 
MessageFeedbackApi(Resource): ) except services.errors.message.MessageNotExistsError: raise NotFound("Message Not Exists.") - return {"result": "success"} @@ -112,7 +106,6 @@ class MessageSuggestedApi(Resource): app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() - try: questions = MessageService.get_suggested_questions_after_answer( app_model=app_model, user=end_user, message_id=message_id, invoke_from=InvokeFrom.SERVICE_API @@ -124,7 +117,6 @@ class MessageSuggestedApi(Resource): except Exception: logging.exception("internal server error.") raise InternalServerError() - return {"result": "success", "data": questions} diff --git a/api/controllers/service_api/app/site.py b/api/controllers/service_api/app/site.py index e752dfee30..fabdb8b6e4 100644 --- a/api/controllers/service_api/app/site.py +++ b/api/controllers/service_api/app/site.py @@ -17,13 +17,10 @@ class AppSiteApi(Resource): def get(self, app_model: App): """Retrieve app site info.""" site = db.session.query(Site).filter(Site.app_id == app_model.id).first() - if not site: raise Forbidden() - if app_model.tenant.status == TenantStatus.ARCHIVE: raise Forbidden() - return site diff --git a/api/controllers/service_api/app/workflow.py b/api/controllers/service_api/app/workflow.py index efb4acc5fb..427522e549 100644 --- a/api/controllers/service_api/app/workflow.py +++ b/api/controllers/service_api/app/workflow.py @@ -36,7 +36,6 @@ from services.errors.llm import InvokeRateLimitError from services.workflow_app_service import WorkflowAppService logger = logging.getLogger(__name__) - workflow_run_fields = { "id": fields.String, "workflow_id": fields.String, @@ -62,7 +61,6 @@ class WorkflowRunDetailApi(Resource): app_mode = AppMode.value_of(app_model.mode) if app_mode not in [AppMode.WORKFLOW, AppMode.ADVANCED_CHAT]: raise NotWorkflowAppError() - workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == 
workflow_run_id).first() return workflow_run @@ -76,20 +74,16 @@ class WorkflowRunApi(Resource): app_mode = AppMode.value_of(app_model.mode) if app_mode != AppMode.WORKFLOW: raise NotWorkflowAppError() - parser = reqparse.RequestParser() parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") parser.add_argument("files", type=list, required=False, location="json") parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") args = parser.parse_args() - streaming = args.get("response_mode") == "streaming" - try: response = AppGenerateService.generate( app_model=app_model, user=end_user, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=streaming ) - return helper.compact_generate_response(response) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) @@ -117,9 +111,7 @@ class WorkflowTaskStopApi(Resource): app_mode = AppMode.value_of(app_model.mode) if app_mode != AppMode.WORKFLOW: raise NotWorkflowAppError() - AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id) - return {"result": "success"} @@ -152,14 +144,11 @@ class WorkflowAppLogApi(Resource): parser.add_argument("page", type=int_range(1, 99999), default=1, location="args") parser.add_argument("limit", type=int_range(1, 100), default=20, location="args") args = parser.parse_args() - args.status = WorkflowExecutionStatus(args.status) if args.status else None if args.created_at__before: args.created_at__before = isoparse(args.created_at__before) - if args.created_at__after: args.created_at__after = isoparse(args.created_at__after) - # get paginate workflow app logs workflow_app_service = WorkflowAppService() with Session(db.engine) as session: @@ -175,7 +164,6 @@ class WorkflowAppLogApi(Resource): created_by_end_user_session_id=args.created_by_end_user_session_id, created_by_account=args.created_by_account, ) - return workflow_app_log_pagination diff --git 
a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index a499719fc3..1fc5b7df9f 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -39,27 +39,22 @@ class DatasetListApi(DatasetApiResource): def get(self, tenant_id): """Resource for getting datasets.""" - page = request.args.get("page", default=1, type=int) limit = request.args.get("limit", default=20, type=int) # provider = request.args.get("provider", default="vendor") search = request.args.get("keyword", default=None, type=str) tag_ids = request.args.getlist("tag_ids") include_all = request.args.get("include_all", default="false").lower() == "true" - datasets, total = DatasetService.get_datasets( page, limit, tenant_id, current_user, search, tag_ids, include_all ) # check embedding setting provider_manager = ProviderManager() configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id) - embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True) - model_names = [] for embedding_model in embedding_models: model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}") - data = marshal(datasets, dataset_detail_fields) for item in data: if item["indexing_technique"] == "high_quality" and item["embedding_model_provider"]: @@ -131,9 +126,7 @@ class DatasetListApi(DatasetApiResource): parser.add_argument("retrieval_model", type=dict, required=False, nullable=True, location="json") parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json") parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json") - args = parser.parse_args() - if args.get("embedding_model_provider"): DatasetService.check_embedding_model_setting( tenant_id, args.get("embedding_model_provider"), args.get("embedding_model") @@ -148,7 +141,6 @@ class 
DatasetListApi(DatasetApiResource): args.get("retrieval_model").get("reranking_model").get("reranking_provider_name"), args.get("retrieval_model").get("reranking_model").get("reranking_model_name"), ) - try: dataset = DatasetService.create_empty_dataset( tenant_id=tenant_id, @@ -168,7 +160,6 @@ class DatasetListApi(DatasetApiResource): ) except services.errors.dataset.DatasetNameDuplicateError: raise DatasetNameDuplicateError() - return marshal(dataset, dataset_detail_fields), 200 @@ -188,17 +179,13 @@ class DatasetApi(DatasetApiResource): if data.get("permission") == "partial_members": part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str) data.update({"partial_member_list": part_users_list}) - # check embedding setting provider_manager = ProviderManager() configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id) - embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True) - model_names = [] for embedding_model in embedding_models: model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}") - if data["indexing_technique"] == "high_quality": item_model = f"{data['embedding_model']}:{data['embedding_model_provider']}" if item_model in model_names: @@ -207,11 +194,9 @@ class DatasetApi(DatasetApiResource): data["embedding_available"] = False else: data["embedding_available"] = True - if data.get("permission") == "partial_members": part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str) data.update({"partial_member_list": part_users_list}) - return data, 200 @cloud_edition_billing_rate_limit_check("knowledge", "dataset") @@ -220,7 +205,6 @@ class DatasetApi(DatasetApiResource): dataset = DatasetService.get_dataset(dataset_id_str) if dataset is None: raise NotFound("Dataset not found.") - parser = reqparse.RequestParser() parser.add_argument( "name", @@ -250,7 +234,6 @@ class 
DatasetApi(DatasetApiResource): ) parser.add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.") parser.add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.") - parser.add_argument( "external_retrieval_model", type=dict, @@ -259,7 +242,6 @@ class DatasetApi(DatasetApiResource): location="json", help="Invalid external retrieval model.", ) - parser.add_argument( "external_knowledge_id", type=str, @@ -268,7 +250,6 @@ class DatasetApi(DatasetApiResource): location="json", help="Invalid external knowledge id.", ) - parser.add_argument( "external_knowledge_api_id", type=str, @@ -279,7 +260,6 @@ class DatasetApi(DatasetApiResource): ) args = parser.parse_args() data = request.get_json() - # check embedding model setting if data.get("indexing_technique") == "high_quality" or data.get("embedding_model_provider"): DatasetService.check_embedding_model_setting( @@ -295,20 +275,15 @@ class DatasetApi(DatasetApiResource): data.get("retrieval_model").get("reranking_model").get("reranking_provider_name"), data.get("retrieval_model").get("reranking_model").get("reranking_model_name"), ) - # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator DatasetPermissionService.check_permission( current_user, dataset, data.get("permission"), data.get("partial_member_list") ) - dataset = DatasetService.update_dataset(dataset_id_str, args, current_user) - if dataset is None: raise NotFound("Dataset not found.") - result_data = marshal(dataset, dataset_detail_fields) tenant_id = current_user.current_tenant_id - if data.get("partial_member_list") and data.get("permission") == "partial_members": DatasetPermissionService.update_partial_member_list( tenant_id, dataset_id_str, data.get("partial_member_list") @@ -319,32 +294,25 @@ class DatasetApi(DatasetApiResource): or data.get("permission") == DatasetPermissionEnum.ALL_TEAM ): 
DatasetPermissionService.clear_partial_member_list(dataset_id_str) - partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str) result_data.update({"partial_member_list": partial_member_list}) - return result_data, 200 @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def delete(self, _, dataset_id): """ Deletes a dataset given its ID. - Args: _: ignore dataset_id (UUID): The ID of the dataset to be deleted. - Returns: dict: A dictionary with a key 'result' and a value 'success' if the dataset was successfully deleted. Omitted in HTTP response. int: HTTP status code 204 indicating that the operation was successful. - Raises: NotFound: If the dataset with the given ID does not exist. """ - dataset_id_str = str(dataset_id) - try: if DatasetService.delete_dataset(dataset_id_str, current_user): DatasetPermissionService.clear_partial_member_list(dataset_id_str) @@ -361,16 +329,13 @@ class DocumentStatusApi(DatasetApiResource): def patch(self, tenant_id, dataset_id, action): """ Batch update document status. - Args: tenant_id: tenant id dataset_id: dataset id action: action to perform (enable, disable, archive, un_archive) - Returns: dict: A dictionary with a key 'result' and a value 'success' int: HTTP status code 200 indicating that the operation was successful. - Raises: NotFound: If the dataset with the given ID does not exist. Forbidden: If the user does not have permission. 
@@ -378,30 +343,24 @@ class DocumentStatusApi(DatasetApiResource): """ dataset_id_str = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id_str) - if dataset is None: raise NotFound("Dataset not found.") - # Check user's permission try: DatasetService.check_dataset_permission(dataset, current_user) except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) - # Check dataset model setting DatasetService.check_dataset_model_setting(dataset) - # Get document IDs from request body data = request.get_json() document_ids = data.get("document_ids", []) - try: DocumentService.batch_update_document_status(dataset, document_ids, action, current_user) except services.errors.document.DocumentIndexingError as e: raise InvalidActionError(str(e)) except ValueError as e: raise InvalidActionError(str(e)) - return {"result": "success"}, 200 @@ -411,7 +370,6 @@ class DatasetTagsApi(DatasetApiResource): def get(self, _, dataset_id): """Get all knowledge type tags.""" tags = TagService.get_tags("knowledge", current_user.current_tenant_id) - return tags, 200 @validate_dataset_token @@ -419,7 +377,6 @@ class DatasetTagsApi(DatasetApiResource): """Add a knowledge type tag.""" if not (current_user.is_editor or current_user.is_dataset_editor): raise Forbidden() - parser = reqparse.RequestParser() parser.add_argument( "name", @@ -428,20 +385,16 @@ class DatasetTagsApi(DatasetApiResource): help="Name must be between 1 to 50 characters.", type=DatasetTagsApi._validate_tag_name, ) - args = parser.parse_args() args["type"] = "knowledge" tag = TagService.save_tags(args) - response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0} - return response, 200 @validate_dataset_token def patch(self, _, dataset_id): if not (current_user.is_editor or current_user.is_dataset_editor): raise Forbidden() - parser = reqparse.RequestParser() parser.add_argument( "name", @@ -454,11 +407,8 @@ class DatasetTagsApi(DatasetApiResource): args = parser.parse_args() 
args["type"] = "knowledge" tag = TagService.update_tags(args, args.get("tag_id")) - binding_count = TagService.get_tag_binding_count(args.get("tag_id")) - response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count} - return response, 200 @validate_dataset_token @@ -470,7 +420,6 @@ class DatasetTagsApi(DatasetApiResource): parser.add_argument("tag_id", nullable=False, required=True, help="Id of a tag.", type=str) args = parser.parse_args() TagService.delete_tag(args.get("tag_id")) - return 204 @staticmethod @@ -486,7 +435,6 @@ class DatasetTagBindingApi(DatasetApiResource): # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator if not (current_user.is_editor or current_user.is_dataset_editor): raise Forbidden() - parser = reqparse.RequestParser() parser.add_argument( "tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs is required." @@ -494,11 +442,9 @@ class DatasetTagBindingApi(DatasetApiResource): parser.add_argument( "target_id", type=str, nullable=False, required=True, location="json", help="Target Dataset ID is required." 
) - args = parser.parse_args() args["type"] = "knowledge" TagService.save_tag_binding(args) - return 204 @@ -508,15 +454,12 @@ class DatasetTagUnbindingApi(DatasetApiResource): # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator if not (current_user.is_editor or current_user.is_dataset_editor): raise Forbidden() - parser = reqparse.RequestParser() parser.add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.") parser.add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.") - args = parser.parse_args() args["type"] = "knowledge" TagService.delete_tag_binding(args) - return 204 diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index d571b21a0a..fe1d048cdb 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -58,24 +58,18 @@ class DocumentAddByTextApi(DatasetApiResource): parser.add_argument("retrieval_model", type=dict, required=False, nullable=True, location="json") parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json") parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json") - args = parser.parse_args() - dataset_id = str(dataset_id) tenant_id = str(tenant_id) dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() - if not dataset: raise ValueError("Dataset does not exist.") - if not dataset.indexing_technique and not args["indexing_technique"]: raise ValueError("indexing_technique is required.") - text = args.get("text") name = args.get("name") if text is None or name is None: raise ValueError("Both 'text' and 'name' must be non-null values.") - if args.get("embedding_model_provider"): DatasetService.check_embedding_model_setting( tenant_id, args.get("embedding_model_provider"), 
args.get("embedding_model") @@ -90,7 +84,6 @@ class DocumentAddByTextApi(DatasetApiResource): args.get("retrieval_model").get("reranking_model").get("reranking_provider_name"), args.get("retrieval_model").get("reranking_model").get("reranking_model_name"), ) - upload_file = FileService.upload_text(text=str(text), text_name=str(name)) data_source = { "type": "upload_file", @@ -100,7 +93,6 @@ class DocumentAddByTextApi(DatasetApiResource): knowledge_config = KnowledgeConfig(**args) # validate args DocumentService.document_create_args_validate(knowledge_config) - try: documents, batch = DocumentService.save_document_with_dataset_id( dataset=dataset, @@ -112,7 +104,6 @@ class DocumentAddByTextApi(DatasetApiResource): except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) document = documents[0] - documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch} return documents_and_batch_fields, 200 @@ -137,10 +128,8 @@ class DocumentUpdateByTextApi(DatasetApiResource): dataset_id = str(dataset_id) tenant_id = str(tenant_id) dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() - if not dataset: raise ValueError("Dataset does not exist.") - if ( args.get("retrieval_model") and args.get("retrieval_model").get("reranking_model") @@ -151,10 +140,8 @@ class DocumentUpdateByTextApi(DatasetApiResource): args.get("retrieval_model").get("reranking_model").get("reranking_provider_name"), args.get("retrieval_model").get("reranking_model").get("reranking_model_name"), ) - # indexing_technique is already set in dataset since this is an update args["indexing_technique"] = dataset.indexing_technique - if args["text"]: text = args.get("text") name = args.get("name") @@ -170,7 +157,6 @@ class DocumentUpdateByTextApi(DatasetApiResource): args["original_document_id"] = str(document_id) knowledge_config = KnowledgeConfig(**args) 
DocumentService.document_create_args_validate(knowledge_config) - try: documents, batch = DocumentService.save_document_with_dataset_id( dataset=dataset, @@ -182,7 +168,6 @@ class DocumentUpdateByTextApi(DatasetApiResource): except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) document = documents[0] - documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch} return documents_and_batch_fields, 200 @@ -202,23 +187,18 @@ class DocumentAddByFileApi(DatasetApiResource): args["doc_form"] = "text_model" if "doc_language" not in args: args["doc_language"] = "English" - # get dataset info dataset_id = str(dataset_id) tenant_id = str(tenant_id) dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() - if not dataset: raise ValueError("Dataset does not exist.") - if dataset.provider == "external": raise ValueError("External datasets are not supported.") - indexing_technique = args.get("indexing_technique") or dataset.indexing_technique if not indexing_technique: raise ValueError("indexing_technique is required.") args["indexing_technique"] = indexing_technique - if "embedding_model_provider" in args: DatasetService.check_embedding_model_setting( tenant_id, args["embedding_model_provider"], args["embedding_model"] @@ -233,19 +213,15 @@ class DocumentAddByFileApi(DatasetApiResource): args["retrieval_model"].get("reranking_model").get("reranking_provider_name"), args["retrieval_model"].get("reranking_model").get("reranking_model_name"), ) - # save file info file = request.files["file"] # check file if "file" not in request.files: raise NoFileUploadedError() - if len(request.files) > 1: raise TooManyFilesError() - if not file.filename: raise FilenameNotExistsError - upload_file = FileService.upload_file( filename=file.filename, content=file.read(), @@ -261,11 +237,9 @@ class DocumentAddByFileApi(DatasetApiResource): # validate args knowledge_config = 
KnowledgeConfig(**args) DocumentService.document_create_args_validate(knowledge_config) - dataset_process_rule = dataset.latest_process_rule if "process_rule" not in args else None if not knowledge_config.original_document_id and not dataset_process_rule and not knowledge_config.process_rule: raise ValueError("process_rule is required.") - try: documents, batch = DocumentService.save_document_with_dataset_id( dataset=dataset, @@ -295,31 +269,23 @@ class DocumentUpdateByFileApi(DatasetApiResource): args["doc_form"] = "text_model" if "doc_language" not in args: args["doc_language"] = "English" - # get dataset info dataset_id = str(dataset_id) tenant_id = str(tenant_id) dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() - if not dataset: raise ValueError("Dataset does not exist.") - if dataset.provider == "external": raise ValueError("External datasets are not supported.") - # indexing_technique is already set in dataset since this is an update args["indexing_technique"] = dataset.indexing_technique - if "file" in request.files: # save file info file = request.files["file"] - if len(request.files) > 1: raise TooManyFilesError() - if not file.filename: raise FilenameNotExistsError - try: upload_file = FileService.upload_file( filename=file.filename, @@ -339,10 +305,8 @@ class DocumentUpdateByFileApi(DatasetApiResource): args["data_source"] = data_source # validate args args["original_document_id"] = str(document_id) - knowledge_config = KnowledgeConfig(**args) DocumentService.document_create_args_validate(knowledge_config) - try: documents, batch = DocumentService.save_document_with_dataset_id( dataset=dataset, @@ -365,29 +329,22 @@ class DocumentDeleteApi(DatasetApiResource): document_id = str(document_id) dataset_id = str(dataset_id) tenant_id = str(tenant_id) - # get dataset info dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() - if not dataset: 
raise ValueError("Dataset does not exist.") - document = DocumentService.get_document(dataset.id, document_id) - # 404 if document not found if document is None: raise NotFound("Document Not Exists.") - # 403 if document is archived if DocumentService.check_archived(document): raise ArchivedDocumentImmutableError() - try: # delete document DocumentService.delete_document(document) except services.errors.document.DocumentIndexingError: raise DocumentIndexingError("Cannot delete document during indexing.") - return 204 @@ -401,18 +358,13 @@ class DocumentListApi(DatasetApiResource): dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: raise NotFound("Dataset not found.") - query = select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=tenant_id) - if search: search = f"%{search}%" query = query.filter(Document.name.like(search)) - query = query.order_by(desc(Document.created_at), desc(Document.position)) - paginated_documents = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False) documents = paginated_documents.items - response = { "data": marshal(documents, document_fields), "has_more": len(documents) == limit, @@ -420,7 +372,6 @@ class DocumentListApi(DatasetApiResource): "total": paginated_documents.total, "page": page, } - return response @@ -479,21 +430,15 @@ class DocumentDetailApi(DatasetApiResource): def get(self, tenant_id, dataset_id, document_id): dataset_id = str(dataset_id) document_id = str(document_id) - dataset = self.get_dataset(dataset_id, tenant_id) - document = DocumentService.get_document(dataset.id, document_id) - if not document: raise NotFound("Document not found.") - if document.tenant_id != str(tenant_id): raise Forbidden("No permission.") - metadata = request.args.get("metadata", "all") if metadata not in self.METADATA_CHOICES: raise InvalidMetadataError(f"Invalid metadata value: {metadata}") - if metadata == "only": response 
= {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata_details} elif metadata == "without": @@ -564,7 +509,6 @@ class DocumentDetailApi(DatasetApiResource): "doc_form": document.doc_form, "doc_language": document.doc_language, } - return response diff --git a/api/controllers/service_api/dataset/hit_testing.py b/api/controllers/service_api/dataset/hit_testing.py index 52e9bca5da..d40f0a8108 100644 --- a/api/controllers/service_api/dataset/hit_testing.py +++ b/api/controllers/service_api/dataset/hit_testing.py @@ -7,11 +7,9 @@ class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase): @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id, dataset_id): dataset_id_str = str(dataset_id) - dataset = self.get_and_validate_dataset(dataset_id_str) args = self.parse_args() self.hit_testing_args_check(args) - return self.perform_hit_testing(dataset, args) diff --git a/api/controllers/service_api/dataset/metadata.py b/api/controllers/service_api/dataset/metadata.py index 1968696ee5..83277538fd 100644 --- a/api/controllers/service_api/dataset/metadata.py +++ b/api/controllers/service_api/dataset/metadata.py @@ -21,13 +21,11 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource): parser.add_argument("name", type=str, required=True, nullable=True, location="json") args = parser.parse_args() metadata_args = MetadataArgs(**args) - dataset_id_str = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id_str) if dataset is None: raise NotFound("Dataset not found.") DatasetService.check_dataset_permission(dataset, current_user) - metadata = MetadataService.create_metadata(dataset_id_str, metadata_args) return marshal(metadata, dataset_metadata_fields), 201 @@ -45,14 +43,12 @@ class DatasetMetadataServiceApi(DatasetApiResource): parser = reqparse.RequestParser() parser.add_argument("name", type=str, required=True, nullable=True, location="json") args = parser.parse_args() - dataset_id_str = 
str(dataset_id) metadata_id_str = str(metadata_id) dataset = DatasetService.get_dataset(dataset_id_str) if dataset is None: raise NotFound("Dataset not found.") DatasetService.check_dataset_permission(dataset, current_user) - metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, args.get("name")) return marshal(metadata, dataset_metadata_fields), 200 @@ -64,7 +60,6 @@ class DatasetMetadataServiceApi(DatasetApiResource): if dataset is None: raise NotFound("Dataset not found.") DatasetService.check_dataset_permission(dataset, current_user) - MetadataService.delete_metadata(dataset_id_str, metadata_id_str) return 204 @@ -83,7 +78,6 @@ class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource): if dataset is None: raise NotFound("Dataset not found.") DatasetService.check_dataset_permission(dataset, current_user) - if action == "enable": MetadataService.enable_built_in_field(dataset) elif action == "disable": @@ -99,14 +93,11 @@ class DocumentMetadataEditServiceApi(DatasetApiResource): if dataset is None: raise NotFound("Dataset not found.") DatasetService.check_dataset_permission(dataset, current_user) - parser = reqparse.RequestParser() parser.add_argument("operation_data", type=list, required=True, nullable=True, location="json") args = parser.parse_args() metadata_args = MetadataOperationData(**args) - MetadataService.update_documents_metadata(dataset, metadata_args) - return 200 diff --git a/api/controllers/service_api/dataset/segment.py b/api/controllers/service_api/dataset/segment.py index 403b7f0a0c..5ced3d9acc 100644 --- a/api/controllers/service_api/dataset/segment.py +++ b/api/controllers/service_api/dataset/segment.py @@ -113,12 +113,10 @@ class SegmentApi(DatasetApiResource): ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) - parser = reqparse.RequestParser() parser.add_argument("status", type=str, action="append", default=[], location="args") parser.add_argument("keyword", 
type=str, default=None, location="args") args = parser.parse_args() - segments, total = SegmentService.get_segments( document_id=document_id, tenant_id=current_user.current_tenant_id, @@ -127,7 +125,6 @@ class SegmentApi(DatasetApiResource): page=page, limit=limit, ) - response = { "data": marshal(segments, segment_fields), "doc_form": document.doc_form, @@ -136,7 +133,6 @@ class SegmentApi(DatasetApiResource): "limit": limit, "page": page, } - return response, 200 @@ -201,12 +197,10 @@ class DatasetSegmentApi(DatasetApiResource): segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id) if not segment: raise NotFound("Segment not found.") - # validate args parser = reqparse.RequestParser() parser.add_argument("segment", type=dict, required=False, nullable=True, location="json") args = parser.parse_args() - updated_segment = SegmentService.update_segment( SegmentUpdateArgs(**args["segment"]), segment, document, dataset ) @@ -231,7 +225,6 @@ class DatasetSegmentApi(DatasetApiResource): segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id) if not segment: raise NotFound("Segment not found.") - return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200 @@ -249,19 +242,16 @@ class ChildChunkApi(DatasetApiResource): dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: raise NotFound("Dataset not found.") - # check document document_id = str(document_id) document = DocumentService.get_document(dataset.id, document_id) if not document: raise NotFound("Document not found.") - # check segment segment_id = str(segment_id) segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id) if not segment: raise NotFound("Segment not found.") - # check embedding model setting if dataset.indexing_technique == "high_quality": try: @@ 
-278,17 +268,14 @@ class ChildChunkApi(DatasetApiResource): ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) - # validate args parser = reqparse.RequestParser() parser.add_argument("content", type=str, required=True, nullable=False, location="json") args = parser.parse_args() - try: child_chunk = SegmentService.create_child_chunk(args.get("content"), segment, document, dataset) except ChildChunkIndexingServiceError as e: raise ChildChunkIndexingError(str(e)) - return {"data": marshal(child_chunk, child_chunk_fields)}, 200 def get(self, tenant_id, dataset_id, document_id, segment_id): @@ -299,31 +286,25 @@ class ChildChunkApi(DatasetApiResource): dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: raise NotFound("Dataset not found.") - # check document document_id = str(document_id) document = DocumentService.get_document(dataset.id, document_id) if not document: raise NotFound("Document not found.") - # check segment segment_id = str(segment_id) segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id) if not segment: raise NotFound("Segment not found.") - parser = reqparse.RequestParser() parser.add_argument("limit", type=int, default=20, location="args") parser.add_argument("keyword", type=str, default=None, location="args") parser.add_argument("page", type=int, default=1, location="args") args = parser.parse_args() - page = args["page"] limit = min(args["limit"], 100) keyword = args["keyword"] - child_chunks = SegmentService.get_child_chunks(segment_id, document_id, dataset_id, page, limit, keyword) - return { "data": marshal(child_chunks.items, child_chunk_fields), "total": child_chunks.total, @@ -346,19 +327,16 @@ class DatasetChildChunkApi(DatasetApiResource): dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: raise 
NotFound("Dataset not found.") - # check document document_id = str(document_id) document = DocumentService.get_document(dataset.id, document_id) if not document: raise NotFound("Document not found.") - # check segment segment_id = str(segment_id) segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id) if not segment: raise NotFound("Segment not found.") - # check child chunk child_chunk_id = str(child_chunk_id) child_chunk = SegmentService.get_child_chunk_by_id( @@ -366,12 +344,10 @@ class DatasetChildChunkApi(DatasetApiResource): ) if not child_chunk: raise NotFound("Child chunk not found.") - try: SegmentService.delete_child_chunk(child_chunk, dataset) except ChildChunkDeleteIndexServiceError as e: raise ChildChunkDeleteIndexError(str(e)) - return 204 @cloud_edition_billing_resource_check("vector_space", "dataset") @@ -385,36 +361,30 @@ class DatasetChildChunkApi(DatasetApiResource): dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: raise NotFound("Dataset not found.") - # get document document = DocumentService.get_document(dataset_id, document_id) if not document: raise NotFound("Document not found.") - # get segment segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id) if not segment: raise NotFound("Segment not found.") - # get child chunk child_chunk = SegmentService.get_child_chunk_by_id( child_chunk_id=child_chunk_id, tenant_id=current_user.current_tenant_id ) if not child_chunk: raise NotFound("Child chunk not found.") - # validate args parser = reqparse.RequestParser() parser.add_argument("content", type=str, required=True, nullable=False, location="json") args = parser.parse_args() - try: child_chunk = SegmentService.update_child_chunk( args.get("content"), child_chunk, segment, document, dataset ) except ChildChunkIndexingServiceError as e: raise 
ChildChunkIndexingError(str(e)) - return {"data": marshal(child_chunk, child_chunk_fields)}, 200 diff --git a/api/controllers/service_api/dataset/upload_file.py b/api/controllers/service_api/dataset/upload_file.py index 6382b63ea9..0ba8dde74e 100644 --- a/api/controllers/service_api/dataset/upload_file.py +++ b/api/controllers/service_api/dataset/upload_file.py @@ -36,7 +36,6 @@ class UploadFileApi(DatasetApiResource): raise NotFound("UploadFile not found.") else: raise ValueError("Upload file id not found in document data source info.") - url = file_helpers.get_signed_file_url(upload_file_id=upload_file.id) return { "id": upload_file.id, diff --git a/api/controllers/service_api/workspace/models.py b/api/controllers/service_api/workspace/models.py index 3f18474674..0df294f777 100644 --- a/api/controllers/service_api/workspace/models.py +++ b/api/controllers/service_api/workspace/models.py @@ -11,10 +11,8 @@ class ModelProviderAvailableModelApi(Resource): @validate_dataset_token def get(self, _, model_type): tenant_id = current_user.current_tenant_id - model_provider_service = ModelProviderService() models = model_provider_service.get_models_by_model_type(tenant_id=tenant_id, model_type=model_type) - return jsonable_encoder({"data": models}) diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py index 5b919a68d4..9970cbfdab 100644 --- a/api/controllers/service_api/wraps.py +++ b/api/controllers/service_api/wraps.py @@ -42,23 +42,18 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio @wraps(view_func) def decorated_view(*args, **kwargs): api_token = validate_and_get_api_token("app") - app_model = db.session.query(App).filter(App.id == api_token.app_id).first() if not app_model: raise Forbidden("The app no longer exists.") - if app_model.status != "normal": raise Forbidden("The app's status is abnormal.") - if not app_model.enable_api: raise Forbidden("The app's API service has been disabled.") - tenant 
= db.session.query(Tenant).filter(Tenant.id == app_model.tenant_id).first() if tenant is None: raise ValueError("Tenant does not exist.") if tenant.status == TenantStatus.ARCHIVE: raise Forbidden("The workspace's status is archived.") - tenant_account_join = ( db.session.query(Tenant, TenantAccountJoin) .filter(Tenant.id == api_token.tenant_id) @@ -79,9 +74,7 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio raise Unauthorized("Tenant owner account does not exist.") else: raise Unauthorized("Tenant does not exist.") - kwargs["app_model"] = app_model - if fetch_user_arg: if fetch_user_arg.fetch_from == WhereisUserArg.QUERY: user_id = request.args.get("user") @@ -92,20 +85,15 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio else: # use default-user user_id = None - if not user_id and fetch_user_arg.required: raise ValueError("Arg user must be provided.") - if user_id: user_id = str(user_id) - end_user = create_or_update_end_user_for_user_id(app_model, user_id) kwargs["end_user"] = end_user - # Set EndUser as current logged-in user for flask_login.current_user current_app.login_manager._update_request_context_with_user(end_user) # type: ignore user_logged_in.send(current_app._get_current_object(), user=end_user) # type: ignore - return view_func(*args, **kwargs) return decorated_view @@ -121,13 +109,11 @@ def cloud_edition_billing_resource_check(resource: str, api_token_type: str): def decorated(*args, **kwargs): api_token = validate_and_get_api_token(api_token_type) features = FeatureService.get_features(api_token.tenant_id) - if features.billing.enabled: members = features.members apps = features.apps vector_space = features.vector_space documents_upload_quota = features.documents_upload_quota - if resource == "members" and 0 < members.limit <= members.size: raise Forbidden("The number of members has reached the limit of your subscription.") elif resource == "apps" and 0 < apps.limit <= apps.size: 
@@ -138,7 +124,6 @@ def cloud_edition_billing_resource_check(resource: str, api_token_type: str): raise Forbidden("The number of documents has reached the limit of your subscription.") else: return view(*args, **kwargs) - return view(*args, **kwargs) return decorated @@ -160,7 +145,6 @@ def cloud_edition_billing_knowledge_limit_check(resource: str, api_token_type: s ) else: return view(*args, **kwargs) - return view(*args, **kwargs) return decorated @@ -173,19 +157,14 @@ def cloud_edition_billing_rate_limit_check(resource: str, api_token_type: str): @wraps(view) def decorated(*args, **kwargs): api_token = validate_and_get_api_token(api_token_type) - if resource == "knowledge": knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(api_token.tenant_id) if knowledge_rate_limit.enabled: current_time = int(time.time() * 1000) key = f"rate_limit_{api_token.tenant_id}" - redis_client.zadd(key, {current_time: current_time}) - redis_client.zremrangebyscore(key, 0, current_time - 60000) - request_count = redis_client.zcard(key) - if request_count > knowledge_rate_limit.limit: # add ratelimit record rate_limit_log = RateLimitLog( @@ -236,7 +215,6 @@ def validate_dataset_token(view=None): if view: return decorator(view) - # if view is None, it means that the decorator is used without parentheses # use the decorator as a function for method_decorators return decorator @@ -249,13 +227,10 @@ def validate_and_get_api_token(scope: str | None = None): auth_header = request.headers.get("Authorization") if auth_header is None or " " not in auth_header: raise Unauthorized("Authorization header must be provided and start with 'Bearer'") - auth_scheme, auth_token = auth_header.split(None, 1) auth_scheme = auth_scheme.lower() - if auth_scheme != "bearer": raise Unauthorized("Authorization scheme must be 'Bearer'") - current_time = datetime.now(UTC).replace(tzinfo=None) cutoff_time = current_time - timedelta(minutes=1) with Session(db.engine, expire_on_commit=False) as session: @@ 
-271,7 +246,6 @@ def validate_and_get_api_token(scope: str | None = None): ) result = session.execute(update_stmt) api_token = result.scalar_one_or_none() - if not api_token: stmt = select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope) api_token = session.scalar(stmt) @@ -279,7 +253,6 @@ def validate_and_get_api_token(scope: str | None = None): raise Unauthorized("Access token is invalid") else: session.commit() - return api_token @@ -289,7 +262,6 @@ def create_or_update_end_user_for_user_id(app_model: App, user_id: Optional[str] """ if not user_id: user_id = "DEFAULT-USER" - end_user = ( db.session.query(EndUser) .filter( @@ -300,7 +272,6 @@ def create_or_update_end_user_for_user_id(app_model: App, user_id: Optional[str] ) .first() ) - if end_user is None: end_user = EndUser( tenant_id=app_model.tenant_id, @@ -311,7 +282,6 @@ def create_or_update_end_user_for_user_id(app_model: App, user_id: Optional[str] ) db.session.add(end_user) db.session.commit() - return end_user @@ -320,8 +290,6 @@ class DatasetApiResource(Resource): def get_dataset(self, dataset_id: str, tenant_id: str) -> Dataset: dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id, Dataset.tenant_id == tenant_id).first() - if not dataset: raise NotFound("Dataset not found.") - return dataset diff --git a/api/controllers/web/__init__.py b/api/controllers/web/__init__.py index 56749a0e25..d737d4cc2e 100644 --- a/api/controllers/web/__init__.py +++ b/api/controllers/web/__init__.py @@ -7,14 +7,11 @@ from .remote_files import RemoteFileInfoApi, RemoteFileUploadApi bp = Blueprint("web", __name__, url_prefix="/api") api = ExternalApi(bp) - # Files api.add_resource(FileApi, "/files/upload") - # Remote files api.add_resource(RemoteFileInfoApi, "/remote-files/") api.add_resource(RemoteFileUploadApi, "/remote-files/upload") - from . 
import ( app, audio, diff --git a/api/controllers/web/app.py b/api/controllers/web/app.py index 94a525a75d..468053e921 100644 --- a/api/controllers/web/app.py +++ b/api/controllers/web/app.py @@ -24,18 +24,14 @@ class AppParameterApi(WebApiResource): workflow = app_model.workflow if workflow is None: raise AppUnavailableError() - features_dict = workflow.features_dict user_input_form = workflow.user_input_form(to_old_structure=True) else: app_model_config = app_model.app_model_config if app_model_config is None: raise AppUnavailableError() - features_dict = app_model_config.to_dict() - user_input_form = features_dict.get("user_input_form", []) - return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form) @@ -51,21 +47,16 @@ class AppAccessMode(Resource): parser.add_argument("appId", type=str, required=False, location="args") parser.add_argument("appCode", type=str, required=False, location="args") args = parser.parse_args() - features = FeatureService.get_system_features() if not features.webapp_auth.enabled: return {"accessMode": "public"} - app_id = args.get("appId") if args.get("appCode"): app_code = args["appCode"] app_id = AppService.get_app_id_by_code(app_code) - if not app_id: raise ValueError("appId or appCode must be provided") - res = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id) - return {"accessMode": res.access_mode} @@ -78,28 +69,22 @@ class AppWebAuthPermission(Resource): raise if " " not in auth_header: raise - auth_scheme, tk = auth_header.split(None, 1) auth_scheme = auth_scheme.lower() if auth_scheme != "bearer": raise - decoded = PassportService().verify(tk) user_id = decoded.get("user_id", "visitor") except Exception as e: pass - features = FeatureService.get_system_features() if not features.webapp_auth.enabled: return {"result": True} - parser = reqparse.RequestParser() parser.add_argument("appId", type=str, required=True, location="args") args = parser.parse_args() - app_id = 
args["appId"] app_code = AppService.get_app_code_by_id(app_id) - res = True if WebAppAuthService.is_app_require_permission_check(app_id=app_id): res = EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(str(user_id), app_code) diff --git a/api/controllers/web/audio.py b/api/controllers/web/audio.py index 2919ca9af4..9c82e70923 100644 --- a/api/controllers/web/audio.py +++ b/api/controllers/web/audio.py @@ -32,10 +32,8 @@ from services.errors.audio import ( class AudioApi(WebApiResource): def post(self, app_model: App, end_user): file = request.files["file"] - try: response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=end_user) - return response except services.errors.app_model_config.AppModelConfigBrokenError: logging.exception("App model config broken.") @@ -74,14 +72,12 @@ class TextApi(WebApiResource): parser.add_argument("text", type=str, location="json") parser.add_argument("streaming", type=bool, location="json") args = parser.parse_args() - message_id = args.get("message_id", None) text = args.get("text", None) voice = args.get("voice", None) response = AudioService.transcript_tts( app_model=app_model, text=text, voice=voice, end_user=end_user.external_user_id, message_id=message_id ) - return response except services.errors.app_model_config.AppModelConfigBrokenError: logging.exception("App model config broken.") diff --git a/api/controllers/web/completion.py b/api/controllers/web/completion.py index fd3b9aa804..8eb78ed8d0 100644 --- a/api/controllers/web/completion.py +++ b/api/controllers/web/completion.py @@ -37,24 +37,19 @@ class CompletionApi(WebApiResource): def post(self, app_model, end_user): if app_model.mode != "completion": raise NotCompletionAppError() - parser = reqparse.RequestParser() parser.add_argument("inputs", type=dict, required=True, location="json") parser.add_argument("query", type=str, location="json", default="") parser.add_argument("files", type=list, required=False, location="json") 
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") parser.add_argument("retriever_from", type=str, required=False, default="web_app", location="json") - args = parser.parse_args() - streaming = args["response_mode"] == "streaming" args["auto_generate_name"] = False - try: response = AppGenerateService.generate( app_model=app_model, user=end_user, args=args, invoke_from=InvokeFrom.WEB_APP, streaming=streaming ) - return helper.compact_generate_response(response) except services.errors.conversation.ConversationNotExistsError: raise NotFound("Conversation Not Exists.") @@ -82,9 +77,7 @@ class CompletionStopApi(WebApiResource): def post(self, app_model, end_user, task_id): if app_model.mode != "completion": raise NotCompletionAppError() - AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id) - return {"result": "success"}, 200 @@ -93,7 +86,6 @@ class ChatApi(WebApiResource): app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() - parser = reqparse.RequestParser() parser.add_argument("inputs", type=dict, required=True, location="json") parser.add_argument("query", type=str, required=True, location="json") @@ -102,17 +94,13 @@ class ChatApi(WebApiResource): parser.add_argument("conversation_id", type=uuid_value, location="json") parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json") parser.add_argument("retriever_from", type=str, required=False, default="web_app", location="json") - args = parser.parse_args() - streaming = args["response_mode"] == "streaming" args["auto_generate_name"] = False - try: response = AppGenerateService.generate( app_model=app_model, user=end_user, args=args, invoke_from=InvokeFrom.WEB_APP, streaming=streaming ) - return helper.compact_generate_response(response) except services.errors.conversation.ConversationNotExistsError: raise NotFound("Conversation 
Not Exists.") @@ -143,9 +131,7 @@ class ChatStopApi(WebApiResource): app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() - AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id) - return {"result": "success"}, 200 diff --git a/api/controllers/web/conversation.py b/api/controllers/web/conversation.py index 98cea3974f..36c5c8e013 100644 --- a/api/controllers/web/conversation.py +++ b/api/controllers/web/conversation.py @@ -22,7 +22,6 @@ class ConversationListApi(WebApiResource): app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() - parser = reqparse.RequestParser() parser.add_argument("last_id", type=uuid_value, location="args") parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") @@ -36,11 +35,9 @@ class ConversationListApi(WebApiResource): location="args", ) args = parser.parse_args() - pinned = None if "pinned" in args and args["pinned"] is not None: pinned = args["pinned"] == "true" - try: with Session(db.engine) as session: return WebConversationService.pagination_by_last_id( @@ -62,14 +59,12 @@ class ConversationApi(WebApiResource): app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() - conversation_id = str(c_id) try: ConversationService.delete(app_model, conversation_id, end_user) except ConversationNotExistsError: raise NotFound("Conversation Not Exists.") WebConversationService.unpin(app_model, conversation_id, end_user) - return {"result": "success"}, 204 @@ -79,14 +74,11 @@ class ConversationRenameApi(WebApiResource): app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() - conversation_id = str(c_id) - parser = 
reqparse.RequestParser() parser.add_argument("name", type=str, required=False, location="json") parser.add_argument("auto_generate", type=bool, required=False, default=False, location="json") args = parser.parse_args() - try: return ConversationService.rename(app_model, conversation_id, end_user, args["name"], args["auto_generate"]) except ConversationNotExistsError: @@ -98,14 +90,11 @@ class ConversationPinApi(WebApiResource): app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() - conversation_id = str(c_id) - try: WebConversationService.pin(app_model, conversation_id, end_user) except ConversationNotExistsError: raise NotFound("Conversation Not Exists.") - return {"result": "success"} @@ -114,10 +103,8 @@ class ConversationUnPinApi(WebApiResource): app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() - conversation_id = str(c_id) WebConversationService.unpin(app_model, conversation_id, end_user) - return {"result": "success"} diff --git a/api/controllers/web/files.py b/api/controllers/web/files.py index df06a73a85..71f1266588 100644 --- a/api/controllers/web/files.py +++ b/api/controllers/web/files.py @@ -14,19 +14,14 @@ class FileApi(WebApiResource): def post(self, app_model, end_user): file = request.files["file"] source = request.form.get("source") - if "file" not in request.files: raise NoFileUploadedError() - if len(request.files) > 1: raise TooManyFilesError() - if not file.filename: raise FilenameNotExistsError - if source not in ("datasets", None): source = None - try: upload_file = FileService.upload_file( filename=file.filename, @@ -39,5 +34,4 @@ class FileApi(WebApiResource): raise FileTooLargeError(file_too_large_error.description) except services.errors.file.UnsupportedFileTypeError: raise UnsupportedFileTypeError() - return upload_file, 201 diff --git 
a/api/controllers/web/forgot_password.py b/api/controllers/web/forgot_password.py index 0da8d65efc..f3d306ba26 100644 --- a/api/controllers/web/forgot_password.py +++ b/api/controllers/web/forgot_password.py @@ -32,16 +32,13 @@ class ForgotPasswordSendEmailApi(Resource): parser.add_argument("email", type=email, required=True, location="json") parser.add_argument("language", type=str, required=False, location="json") args = parser.parse_args() - ip_address = extract_remote_ip(request) if AccountService.is_email_send_ip_limit(ip_address): raise EmailSendIpLimitError() - if args["language"] is not None and args["language"] == "zh-Hans": language = "zh-Hans" else: language = "en-US" - with Session(db.engine) as session: account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none() token = None @@ -49,7 +46,6 @@ class ForgotPasswordSendEmailApi(Resource): raise AccountNotFound() else: token = AccountService.send_reset_password_email(account=account, email=args["email"], language=language) - return {"result": "success", "data": token} @@ -63,32 +59,24 @@ class ForgotPasswordCheckApi(Resource): parser.add_argument("code", type=str, required=True, location="json") parser.add_argument("token", type=str, required=True, nullable=False, location="json") args = parser.parse_args() - user_email = args["email"] - is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(args["email"]) if is_forgot_password_error_rate_limit: raise EmailPasswordResetLimitError() - token_data = AccountService.get_reset_password_data(args["token"]) if token_data is None: raise InvalidTokenError() - if user_email != token_data.get("email"): raise InvalidEmailError() - if args["code"] != token_data.get("code"): AccountService.add_forgot_password_error_rate_limit(args["email"]) raise EmailCodeError() - # Verified, revoke the first token AccountService.revoke_reset_password_token(args["token"]) - # Refresh token data by generating a new 
token _, new_token = AccountService.generate_reset_password_token( user_email, code=args["code"], additional_data={"phase": "reset"} ) - AccountService.reset_forgot_password_error_rate_limit(args["email"]) return {"is_valid": True, "email": token_data.get("email"), "token": new_token} @@ -103,11 +91,9 @@ class ForgotPasswordResetApi(Resource): parser.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json") parser.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json") args = parser.parse_args() - # Validate passwords match if args["new_password"] != args["password_confirm"]: raise PasswordMismatchError() - # Validate token and get reset data reset_data = AccountService.get_reset_password_data(args["token"]) if not reset_data: @@ -115,24 +101,18 @@ class ForgotPasswordResetApi(Resource): # Must use token in reset phase if reset_data.get("phase", "") != "reset": raise InvalidTokenError() - # Revoke token to prevent reuse AccountService.revoke_reset_password_token(args["token"]) - # Generate secure salt and hash password salt = secrets.token_bytes(16) password_hashed = hash_password(args["new_password"], salt) - email = reset_data.get("email", "") - with Session(db.engine) as session: account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none() - if account: self._update_existing_account(account, password_hashed, salt, session) else: raise AccountNotFound() - return {"result": "success"} def _update_existing_account(self, account, password_hashed, salt, session): diff --git a/api/controllers/web/login.py b/api/controllers/web/login.py index 01c4f4a262..d8db2a646e 100644 --- a/api/controllers/web/login.py +++ b/api/controllers/web/login.py @@ -23,7 +23,6 @@ class LoginApi(Resource): parser.add_argument("email", type=email, required=True, location="json") parser.add_argument("password", type=valid_password, required=True, location="json") args = 
parser.parse_args() - try: account = WebAppAuthService.authenticate(args["email"], args["password"]) except services.errors.account.AccountLoginError: @@ -32,7 +31,6 @@ class LoginApi(Resource): raise EmailOrPasswordMismatchError() except services.errors.account.AccountNotFoundError: raise AccountNotFound() - token = WebAppAuthService.login(account=account) return {"result": "success", "data": {"access_token": token}} @@ -45,8 +43,6 @@ class LoginApi(Resource): # return {"result": "success"} # flask_login.logout_user() # return {"result": "success"} - - class EmailCodeLoginSendEmailApi(Resource): @setup_required @only_edition_enterprise @@ -55,18 +51,15 @@ class EmailCodeLoginSendEmailApi(Resource): parser.add_argument("email", type=email, required=True, location="json") parser.add_argument("language", type=str, required=False, location="json") args = parser.parse_args() - if args["language"] is not None and args["language"] == "zh-Hans": language = "zh-Hans" else: language = "en-US" - account = WebAppAuthService.get_user_through_email(args["email"]) if account is None: raise AccountNotFound() else: token = WebAppAuthService.send_email_code_login_email(account=account, language=language) - return {"result": "success", "data": token} @@ -79,24 +72,18 @@ class EmailCodeLoginApi(Resource): parser.add_argument("code", type=str, required=True, location="json") parser.add_argument("token", type=str, required=True, location="json") args = parser.parse_args() - user_email = args["email"] - token_data = WebAppAuthService.get_email_code_login_data(args["token"]) if token_data is None: raise InvalidTokenError() - if token_data["email"] != args["email"]: raise InvalidEmailError() - if token_data["code"] != args["code"]: raise EmailCodeError() - WebAppAuthService.revoke_email_code_login_token(args["token"]) account = WebAppAuthService.get_user_through_email(user_email) if not account: raise AccountNotFound() - token = WebAppAuthService.login(account=account) 
AccountService.reset_login_error_rate_limit(args["email"]) return {"result": "success", "data": {"access_token": token}} diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py index f2e1873601..f3e60d6909 100644 --- a/api/controllers/web/message.py +++ b/api/controllers/web/message.py @@ -50,7 +50,6 @@ class MessageListApi(WebApiResource): "status": fields.String, "error": fields.String, } - message_infinite_scroll_pagination_fields = { "limit": fields.Integer, "has_more": fields.Boolean, @@ -62,13 +61,11 @@ class MessageListApi(WebApiResource): app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() - parser = reqparse.RequestParser() parser.add_argument("conversation_id", required=True, type=uuid_value, location="args") parser.add_argument("first_id", type=uuid_value, location="args") parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") args = parser.parse_args() - try: return MessageService.pagination_by_first_id( app_model, end_user, args["conversation_id"], args["first_id"], args["limit"] @@ -82,12 +79,10 @@ class MessageListApi(WebApiResource): class MessageFeedbackApi(WebApiResource): def post(self, app_model, end_user, message_id): message_id = str(message_id) - parser = reqparse.RequestParser() parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json") parser.add_argument("content", type=str, location="json", default=None) args = parser.parse_args() - try: MessageService.create_feedback( app_model=app_model, @@ -98,7 +93,6 @@ class MessageFeedbackApi(WebApiResource): ) except services.errors.message.MessageNotExistsError: raise NotFound("Message Not Exists.") - return {"result": "success"} @@ -106,17 +100,13 @@ class MessageMoreLikeThisApi(WebApiResource): def get(self, app_model, end_user, message_id): if app_model.mode != "completion": raise NotCompletionAppError() - 
message_id = str(message_id) - parser = reqparse.RequestParser() parser.add_argument( "response_mode", type=str, required=True, choices=["blocking", "streaming"], location="args" ) args = parser.parse_args() - streaming = args["response_mode"] == "streaming" - try: response = AppGenerateService.generate_more_like_this( app_model=app_model, @@ -125,7 +115,6 @@ class MessageMoreLikeThisApi(WebApiResource): invoke_from=InvokeFrom.WEB_APP, streaming=streaming, ) - return helper.compact_generate_response(response) except MessageNotExistsError: raise NotFound("Message Not Exists.") @@ -151,9 +140,7 @@ class MessageSuggestedQuestionApi(WebApiResource): app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotCompletionAppError() - message_id = str(message_id) - try: questions = MessageService.get_suggested_questions_after_answer( app_model=app_model, user=end_user, message_id=message_id, invoke_from=InvokeFrom.WEB_APP @@ -175,7 +162,6 @@ class MessageSuggestedQuestionApi(WebApiResource): except Exception: logging.exception("internal server error.") raise InternalServerError() - return {"data": questions} diff --git a/api/controllers/web/passport.py b/api/controllers/web/passport.py index 10c3cdcf0e..ecbbd197a3 100644 --- a/api/controllers/web/passport.py +++ b/api/controllers/web/passport.py @@ -24,10 +24,8 @@ class PassportResource(Resource): app_code = request.headers.get("X-App-Code") user_id = request.args.get("user_id") web_app_access_token = request.args.get("web_app_access_token") - if app_code is None: raise Unauthorized("X-App-Code header is missing.") - # exchange token for enterprise logined web user enterprise_user_decoded = decode_enterprise_webapp_user_id(web_app_access_token) if enterprise_user_decoded: @@ -35,12 +33,10 @@ class PassportResource(Resource): return exchange_token_for_existing_web_user( app_code=app_code, enterprise_user_decoded=enterprise_user_decoded ) - if 
system_features.webapp_auth.enabled: app_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=app_code) if not app_settings or not app_settings.access_mode == "public": raise WebAppAuthRequiredError() - # get site from db and check if it is normal site = db.session.query(Site).filter(Site.code == app_code, Site.status == "normal").first() if not site: @@ -49,12 +45,10 @@ class PassportResource(Resource): app_model = db.session.query(App).filter(App.id == site.app_id).first() if not app_model or app_model.status != "normal" or not app_model.enable_site: raise NotFound() - if user_id: end_user = ( db.session.query(EndUser).filter(EndUser.app_id == app_model.id, EndUser.session_id == user_id).first() ) - if end_user: pass else: @@ -77,7 +71,6 @@ class PassportResource(Resource): ) db.session.add(end_user) db.session.commit() - payload = { "iss": site.app_id, "sub": "Web API Passport", @@ -85,9 +78,7 @@ class PassportResource(Resource): "app_code": app_code, "end_user_id": end_user.id, } - tk = PassportService().issue(payload) - return { "access_token": tk, } @@ -102,7 +93,6 @@ def decode_enterprise_webapp_user_id(jwt_token: str | None): """ if not jwt_token: return None - decoded = PassportService().verify(jwt_token) source = decoded.get("token_source") if not source or source != "webapp_login_token": @@ -120,24 +110,19 @@ def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded: user_auth_type = enterprise_user_decoded.get("auth_type") if not user_auth_type: raise Unauthorized("Missing auth_type in the token.") - site = db.session.query(Site).filter(Site.code == app_code, Site.status == "normal").first() if not site: raise NotFound() - app_model = db.session.query(App).filter(App.id == site.app_id).first() if not app_model or app_model.status != "normal" or not app_model.enable_site: raise NotFound() - app_auth_type = WebAppAuthService.get_app_auth_type(app_code=app_code) - if app_auth_type == WebAppAuthType.PUBLIC: 
return _exchange_for_public_app_token(app_model, site, enterprise_user_decoded) elif app_auth_type == WebAppAuthType.EXTERNAL and user_auth_type != "external": raise WebAppAuthRequiredError("Please login as external user.") elif app_auth_type == WebAppAuthType.INTERNAL and user_auth_type != "internal": raise WebAppAuthRequiredError("Please login as internal user.") - end_user = None if end_user_id: end_user = db.session.query(EndUser).filter(EndUser.id == end_user_id).first() @@ -190,7 +175,6 @@ def _exchange_for_public_app_token(app_model, site, token_decoded): end_user = ( db.session.query(EndUser).filter(EndUser.app_id == app_model.id, EndUser.session_id == user_id).first() ) - if not end_user: end_user = EndUser( tenant_id=app_model.tenant_id, @@ -199,10 +183,8 @@ def _exchange_for_public_app_token(app_model, site, token_decoded): is_anonymous=True, session_id=generate_session_id(), ) - db.session.add(end_user) db.session.commit() - payload = { "iss": site.app_id, "sub": "Web API Passport", @@ -210,9 +192,7 @@ def _exchange_for_public_app_token(app_model, site, token_decoded): "app_code": site.code, "end_user_id": end_user.id, } - tk = PassportService().issue(payload) - return { "access_token": tk, } diff --git a/api/controllers/web/remote_files.py b/api/controllers/web/remote_files.py index ae68df6bdc..3ddd48ea09 100644 --- a/api/controllers/web/remote_files.py +++ b/api/controllers/web/remote_files.py @@ -36,9 +36,7 @@ class RemoteFileUploadApi(WebApiResource): parser = reqparse.RequestParser() parser.add_argument("url", type=str, required=True, help="URL is required") args = parser.parse_args() - url = args["url"] - try: resp = ssrf_proxy.head(url=url) if resp.status_code != httpx.codes.OK: @@ -47,14 +45,10 @@ class RemoteFileUploadApi(WebApiResource): raise RemoteFileUploadError(f"Failed to fetch file from {url}: {resp.text}") except httpx.RequestError as e: raise RemoteFileUploadError(f"Failed to fetch file from {url}: {str(e)}") - file_info = 
helpers.guess_file_info_from_response(resp) - if not FileService.is_file_size_within_limit(extension=file_info.extension, file_size=file_info.size): raise FileTooLargeError - content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content - try: upload_file = FileService.upload_file( filename=file_info.filename, @@ -67,7 +61,6 @@ class RemoteFileUploadApi(WebApiResource): raise FileTooLargeError(file_too_large_error.description) except services.errors.file.UnsupportedFileTypeError: raise UnsupportedFileTypeError - return { "id": upload_file.id, "name": upload_file.name, diff --git a/api/controllers/web/saved_message.py b/api/controllers/web/saved_message.py index d7188ef0b3..de28bcc839 100644 --- a/api/controllers/web/saved_message.py +++ b/api/controllers/web/saved_message.py @@ -11,7 +11,6 @@ from services.errors.message import MessageNotExistsError from services.saved_message_service import SavedMessageService feedback_fields = {"rating": fields.String} - message_fields = { "id": fields.String, "inputs": fields.Raw, @@ -34,39 +33,31 @@ class SavedMessageListApi(WebApiResource): def get(self, app_model, end_user): if app_model.mode != "completion": raise NotCompletionAppError() - parser = reqparse.RequestParser() parser.add_argument("last_id", type=uuid_value, location="args") parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") args = parser.parse_args() - return SavedMessageService.pagination_by_last_id(app_model, end_user, args["last_id"], args["limit"]) def post(self, app_model, end_user): if app_model.mode != "completion": raise NotCompletionAppError() - parser = reqparse.RequestParser() parser.add_argument("message_id", type=uuid_value, required=True, location="json") args = parser.parse_args() - try: SavedMessageService.save(app_model, end_user, args["message_id"]) except MessageNotExistsError: raise NotFound("Message Not Exists.") - return {"result": "success"} class 
SavedMessageApi(WebApiResource): def delete(self, app_model, end_user, message_id): message_id = str(message_id) - if app_model.mode != "completion": raise NotCompletionAppError() - SavedMessageService.delete(app_model, end_user, message_id) - return {"result": "success"}, 204 diff --git a/api/controllers/web/site.py b/api/controllers/web/site.py index 0564b15ea3..5a3c73cfa4 100644 --- a/api/controllers/web/site.py +++ b/api/controllers/web/site.py @@ -23,7 +23,6 @@ class AppSiteApi(WebApiResource): "user_input_form": fields.Raw(attribute="user_input_form_list"), "pre_prompt": fields.String, } - site_fields = { "title": fields.String, "chat_color_theme": fields.String, @@ -41,7 +40,6 @@ class AppSiteApi(WebApiResource): "show_workflow_steps": fields.Boolean, "use_icon_as_answer_icon": fields.Boolean, } - app_fields = { "app_id": fields.String, "end_user_id": fields.String, @@ -58,15 +56,11 @@ class AppSiteApi(WebApiResource): """Retrieve app site info.""" # get site site = db.session.query(Site).filter(Site.app_id == app_model.id).first() - if not site: raise Forbidden() - if app_model.tenant.status == TenantStatus.ARCHIVE: raise Forbidden() - can_replace_logo = FeatureService.get_features(app_model.tenant_id).can_replace_logo - return AppSiteInfo(app_model.tenant, app_model, site, end_user.id, can_replace_logo) @@ -85,7 +79,6 @@ class AppSiteInfo: self.model_config = None self.plan = tenant.plan self.can_replace_logo = can_replace_logo - if can_replace_logo: base_url = dify_config.FILES_URL remove_webapp_brand = tenant.custom_config_dict.get("remove_webapp_brand", False) diff --git a/api/controllers/web/workflow.py b/api/controllers/web/workflow.py index 590fd3f2c7..8b0a6df230 100644 --- a/api/controllers/web/workflow.py +++ b/api/controllers/web/workflow.py @@ -37,17 +37,14 @@ class WorkflowRunApi(WebApiResource): app_mode = AppMode.value_of(app_model.mode) if app_mode != AppMode.WORKFLOW: raise NotWorkflowAppError() - parser = reqparse.RequestParser() 
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") parser.add_argument("files", type=list, required=False, location="json") args = parser.parse_args() - try: response = AppGenerateService.generate( app_model=app_model, user=end_user, args=args, invoke_from=InvokeFrom.WEB_APP, streaming=True ) - return helper.compact_generate_response(response) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) @@ -74,9 +71,7 @@ class WorkflowTaskStopApi(WebApiResource): app_mode = AppMode.value_of(app_model.mode) if app_mode != AppMode.WORKFLOW: raise NotWorkflowAppError() - AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id) - return {"result": "success"} diff --git a/api/controllers/web/wraps.py b/api/controllers/web/wraps.py index 154bddfc5c..044fa8c60b 100644 --- a/api/controllers/web/wraps.py +++ b/api/controllers/web/wraps.py @@ -19,7 +19,6 @@ def validate_jwt_token(view=None): @wraps(view) def decorated(*args, **kwargs): app_model, end_user = decode_jwt_token() - return view(app_model, end_user, *args, **kwargs) return decorated @@ -36,13 +35,10 @@ def decode_jwt_token(): auth_header = request.headers.get("Authorization") if auth_header is None: raise Unauthorized("Authorization header is missing.") - if " " not in auth_header: raise Unauthorized("Invalid Authorization header format. Expected 'Bearer ' format.") - auth_scheme, tk = auth_header.split(None, 1) auth_scheme = auth_scheme.lower() - if auth_scheme != "bearer": raise Unauthorized("Invalid Authorization header format. 
Expected 'Bearer ' format.") decoded = PassportService().verify(tk) @@ -60,7 +56,6 @@ def decode_jwt_token(): end_user = db.session.query(EndUser).filter(EndUser.id == end_user_id).first() if not end_user: raise NotFound() - # for enterprise webapp auth app_web_auth_enabled = False webapp_settings = None @@ -69,12 +64,10 @@ def decode_jwt_token(): if not webapp_settings: raise NotFound("Web app settings not found.") app_web_auth_enabled = webapp_settings.access_mode != "public" - _validate_webapp_token(decoded, app_web_auth_enabled, system_features.webapp_auth.enabled) _validate_user_accessibility( decoded, app_code, app_web_auth_enabled, system_features.webapp_auth.enabled, webapp_settings ) - return app_model, end_user except Unauthorized as e: if system_features.webapp_auth.enabled: @@ -85,7 +78,6 @@ def decode_jwt_token(): ) if app_web_auth_enabled: raise WebAppAuthRequiredError() - raise Unauthorized(e.description) @@ -96,7 +88,6 @@ def _validate_webapp_token(decoded, app_web_auth_enabled: bool, system_webapp_au source = decoded.get("token_source") if not source or source != "webapp": raise WebAppAuthRequiredError() - # Check if authentication is not enforced for web, and if the token source is webapp, # raise an error and redirect to normal passport login if not system_webapp_auth_enabled or not app_web_auth_enabled: @@ -117,14 +108,11 @@ def _validate_user_accessibility( user_id = decoded.get("user_id") if not user_id: raise WebAppAuthRequiredError() - if not webapp_settings: raise WebAppAuthRequiredError("Web app settings not found.") - if WebAppAuthService.is_app_require_permission_check(access_mode=webapp_settings.access_mode): if not EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(user_id, app_code=app_code): raise WebAppAuthAccessDeniedError() - auth_type = decoded.get("auth_type") granted_at = decoded.get("granted_at") if not auth_type: diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index 
6998e4d29a..49fc07d9f0 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -73,7 +73,6 @@ class BaseAgentRunner(AppRunner): self.memory = memory self.history_prompt_messages = self.organize_agent_history(prompt_messages=prompt_messages or []) self.model_instance = model_instance - # init callback self.agent_callback = DifyAgentCallbackHandler() # init dataset tools @@ -103,7 +102,6 @@ class BaseAgentRunner(AppRunner): .count() ) db.session.close() - # check if model supports stream tool call llm_model = cast(LargeLanguageModel, model_instance.model_type_instance) model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials) @@ -121,7 +119,6 @@ class BaseAgentRunner(AppRunner): """ if app_generate_entity.app_config.prompt_template.simple_prompt_template is None: app_generate_entity.app_config.prompt_template.simple_prompt_template = "" - return app_generate_entity def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[PromptMessageTool, Tool]: @@ -144,12 +141,10 @@ class BaseAgentRunner(AppRunner): "required": [], }, ) - parameters = tool_entity.get_merged_runtime_parameters() for parameter in parameters: if parameter.form != ToolParameter.ToolParameterForm.LLM: continue - parameter_type = parameter.type.as_normal_type() if parameter.type in { ToolParameter.ToolParameterType.SYSTEM_FILES, @@ -160,18 +155,14 @@ class BaseAgentRunner(AppRunner): enum = [] if parameter.type == ToolParameter.ToolParameterType.SELECT: enum = [option.value for option in parameter.options] if parameter.options else [] - message_tool.parameters["properties"][parameter.name] = { "type": parameter_type, "description": parameter.llm_description or "", } - if len(enum) > 0: message_tool.parameters["properties"][parameter.name]["enum"] = enum - if parameter.required: message_tool.parameters["required"].append(parameter.name) - return message_tool, tool_entity def 
_convert_dataset_retriever_tool_to_prompt_message_tool(self, tool: DatasetRetrieverTool) -> PromptMessageTool: @@ -179,7 +170,6 @@ class BaseAgentRunner(AppRunner): convert dataset retriever tool to prompt message tool """ assert tool.entity.description - prompt_tool = PromptMessageTool( name=tool.entity.identity.name, description=tool.entity.description.llm, @@ -189,19 +179,15 @@ class BaseAgentRunner(AppRunner): "required": [], }, ) - for parameter in tool.get_runtime_parameters(): parameter_type = "string" - prompt_tool.parameters["properties"][parameter.name] = { "type": parameter_type, "description": parameter.llm_description or "", } - if parameter.required: if parameter.name not in prompt_tool.parameters["required"]: prompt_tool.parameters["required"].append(parameter.name) - return prompt_tool def _init_prompt_tools(self) -> tuple[dict[str, Tool], list[PromptMessageTool]]: @@ -210,7 +196,6 @@ class BaseAgentRunner(AppRunner): """ tool_instances = {} prompt_messages_tools = [] - for tool in self.app_config.agent.tools or [] if self.app_config.agent else []: try: prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool) @@ -221,7 +206,6 @@ class BaseAgentRunner(AppRunner): tool_instances[tool.tool_name] = tool_entity # save prompt tool prompt_messages_tools.append(prompt_tool) - # convert dataset tools into ModelRuntime Tool format for dataset_tool in self.dataset_tools: prompt_tool = self._convert_dataset_retriever_tool_to_prompt_message_tool(dataset_tool) @@ -229,7 +213,6 @@ class BaseAgentRunner(AppRunner): prompt_messages_tools.append(prompt_tool) # save tool entity tool_instances[dataset_tool.entity.identity.name] = dataset_tool - return tool_instances, prompt_messages_tools def update_prompt_message_tool(self, tool: Tool, prompt_tool: PromptMessageTool) -> PromptMessageTool: @@ -238,11 +221,9 @@ class BaseAgentRunner(AppRunner): """ # try to get tool runtime parameters tool_runtime_parameters = tool.get_runtime_parameters() - for 
parameter in tool_runtime_parameters: if parameter.form != ToolParameter.ToolParameterForm.LLM: continue - parameter_type = parameter.type.as_normal_type() if parameter.type in { ToolParameter.ToolParameterType.SYSTEM_FILES, @@ -253,19 +234,15 @@ class BaseAgentRunner(AppRunner): enum = [] if parameter.type == ToolParameter.ToolParameterType.SELECT: enum = [option.value for option in parameter.options] if parameter.options else [] - prompt_tool.parameters["properties"][parameter.name] = { "type": parameter_type, "description": parameter.llm_description or "", } - if len(enum) > 0: prompt_tool.parameters["properties"][parameter.name]["enum"] = enum - if parameter.required: if parameter.name not in prompt_tool.parameters["required"]: prompt_tool.parameters["required"].append(parameter.name) - return prompt_tool def create_agent_thought( @@ -300,14 +277,11 @@ class BaseAgentRunner(AppRunner): created_by_role="account", created_by=self.user_id, ) - db.session.add(thought) db.session.commit() db.session.refresh(thought) db.session.close() - self.agent_thought_count += 1 - return thought def save_agent_thought( @@ -331,37 +305,28 @@ class BaseAgentRunner(AppRunner): if not updated_agent_thought: raise ValueError("agent thought not found") agent_thought = updated_agent_thought - if thought: agent_thought.thought += thought - if tool_name: agent_thought.tool = tool_name - if tool_input: if isinstance(tool_input, dict): try: tool_input = json.dumps(tool_input, ensure_ascii=False) except Exception: tool_input = json.dumps(tool_input) - updated_agent_thought.tool_input = tool_input - if observation: if isinstance(observation, dict): try: observation = json.dumps(observation, ensure_ascii=False) except Exception: observation = json.dumps(observation) - updated_agent_thought.observation = observation - if answer: agent_thought.answer = answer - if messages_ids is not None and len(messages_ids) > 0: updated_agent_thought.message_files = json.dumps(messages_ids) - if llm_usage: 
updated_agent_thought.message_token = llm_usage.prompt_tokens updated_agent_thought.message_price_unit = llm_usage.prompt_price_unit @@ -371,7 +336,6 @@ class BaseAgentRunner(AppRunner): updated_agent_thought.answer_unit_price = llm_usage.completion_unit_price updated_agent_thought.tokens = llm_usage.total_tokens updated_agent_thought.total_price = llm_usage.total_price - # check if tool labels is not empty labels = updated_agent_thought.tool_labels or {} tools = updated_agent_thought.tool.split(";") if updated_agent_thought.tool else [] @@ -384,18 +348,14 @@ class BaseAgentRunner(AppRunner): labels[tool] = tool_label.to_dict() else: labels[tool] = {"en_US": tool, "zh_Hans": tool} - updated_agent_thought.tool_labels_str = json.dumps(labels) - if tool_invoke_meta is not None: if isinstance(tool_invoke_meta, dict): try: tool_invoke_meta = json.dumps(tool_invoke_meta, ensure_ascii=False) except Exception: tool_invoke_meta = json.dumps(tool_invoke_meta) - updated_agent_thought.tool_meta_str = tool_invoke_meta - db.session.commit() db.session.close() @@ -408,7 +368,6 @@ class BaseAgentRunner(AppRunner): for prompt_message in prompt_messages: if isinstance(prompt_message, SystemPromptMessage): result.append(prompt_message) - messages: list[Message] = ( db.session.query(Message) .filter( @@ -417,13 +376,10 @@ class BaseAgentRunner(AppRunner): .order_by(Message.created_at.desc()) .all() ) - messages = list(reversed(extract_thread_messages(messages))) - for message in messages: if message.id == self.message.id: continue - result.append(self.organize_agent_user_prompt(message)) agent_thoughts: list[MessageAgentThought] = message.agent_thoughts if agent_thoughts: @@ -441,7 +397,6 @@ class BaseAgentRunner(AppRunner): tool_responses = json.loads(agent_thought.observation) except Exception: tool_responses = dict.fromkeys(tools, agent_thought.observation) - for tool in tools: # generate a uuid for tool call tool_call_id = str(uuid.uuid4()) @@ -462,7 +417,6 @@ class 
BaseAgentRunner(AppRunner): tool_call_id=tool_call_id, ) ) - result.extend( [ AssistantPromptMessage( @@ -477,9 +431,7 @@ class BaseAgentRunner(AppRunner): else: if message.answer: result.append(AssistantPromptMessage(content=message.answer)) - db.session.close() - return result def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage: @@ -490,13 +442,10 @@ class BaseAgentRunner(AppRunner): file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict()) else: file_extra_config = None - if not file_extra_config: return UserPromptMessage(content=message.query) - image_detail_config = file_extra_config.image_config.detail if file_extra_config.image_config else None image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW - file_objs = file_factory.build_from_message_files( message_files=files, tenant_id=self.tenant_id, config=file_extra_config ) diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py index 4979f63432..5c2dab4d01 100644 --- a/api/core/agent/cot_agent_runner.py +++ b/api/core/agent/cot_agent_runner.py @@ -42,33 +42,25 @@ class CotAgentRunner(BaseAgentRunner, ABC): """ Run Cot agent application """ - app_generate_entity = self.application_generate_entity self._repack_app_generate_entity(app_generate_entity) self._init_react_state(query) - trace_manager = app_generate_entity.trace_manager - # check model mode if "Observation" not in app_generate_entity.model_conf.stop: if app_generate_entity.model_conf.provider not in self._ignore_observation_providers: app_generate_entity.model_conf.stop.append("Observation") - app_config = self.app_config assert app_config.agent - # init instruction inputs = inputs or {} instruction = app_config.prompt_template.simple_prompt_template or "" self._instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs) - iteration_step = 1 max_iteration_steps = min(app_config.agent.max_iteration, 99) + 1 - # convert 
tools into ModelRuntime Tool format tool_instances, prompt_messages_tools = self._init_prompt_tools() self._prompt_messages_tools = prompt_messages_tools - function_call_state = True llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None} final_answer = "" @@ -86,26 +78,20 @@ class CotAgentRunner(BaseAgentRunner, ABC): llm_usage.total_price += usage.total_price model_instance = self.model_instance - while function_call_state and iteration_step <= max_iteration_steps: # continue to run until there is not any tool call function_call_state = False - if iteration_step == max_iteration_steps: # the last iteration, remove all tools self._prompt_messages_tools = [] - message_file_ids: list[str] = [] - agent_thought = self.create_agent_thought( message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids ) - if iteration_step > 1: self.queue_manager.publish( QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER ) - # recalc llm max tokens prompt_messages = self._organize_prompt_messages() self.recalc_llm_max_tokens(self.model_config, prompt_messages) @@ -119,7 +105,6 @@ class CotAgentRunner(BaseAgentRunner, ABC): user=self.user_id, callbacks=[], ) - usage_dict: dict[str, Optional[LLMUsage]] = {} react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict) scratchpad = AgentScratchpadUnit( @@ -129,13 +114,11 @@ class CotAgentRunner(BaseAgentRunner, ABC): observation="", action=None, ) - # publish agent thought if it's first iteration if iteration_step == 1: self.queue_manager.publish( QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER ) - for chunk in react_chunks: if isinstance(chunk, AgentScratchpadUnit.Action): action = chunk @@ -155,18 +138,15 @@ class CotAgentRunner(BaseAgentRunner, ABC): system_fingerprint="", delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=chunk), usage=None), ) - assert scratchpad.thought is 
not None scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you" self._agent_scratchpad.append(scratchpad) - # get llm usage if "usage" in usage_dict: if usage_dict["usage"] is not None: increase_usage(llm_usage, usage_dict["usage"]) else: usage_dict["usage"] = LLMUsage.empty_usage() - self.save_agent_thought( agent_thought=agent_thought, tool_name=(scratchpad.action.action_name if scratchpad.action and not scratchpad.is_final() else ""), @@ -178,12 +158,10 @@ class CotAgentRunner(BaseAgentRunner, ABC): messages_ids=[], llm_usage=usage_dict["usage"], ) - if not scratchpad.is_final(): self.queue_manager.publish( QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER ) - if not scratchpad.action: # failed to extract action, return final answer directly final_answer = "" @@ -210,7 +188,6 @@ class CotAgentRunner(BaseAgentRunner, ABC): ) scratchpad.observation = tool_invoke_response scratchpad.agent_response = tool_invoke_response - self.save_agent_thought( agent_thought=agent_thought, tool_name=scratchpad.action.action_name, @@ -222,17 +199,13 @@ class CotAgentRunner(BaseAgentRunner, ABC): messages_ids=message_file_ids, llm_usage=usage_dict["usage"], ) - self.queue_manager.publish( QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER ) - # update prompt tool message for prompt_tool in self._prompt_messages_tools: self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool) - iteration_step += 1 - yield LLMResultChunk( model=model_instance.model, prompt_messages=prompt_messages, @@ -241,7 +214,6 @@ class CotAgentRunner(BaseAgentRunner, ABC): ), system_fingerprint="", ) - # save agent thought self.save_agent_thought( agent_thought=agent_thought, @@ -286,17 +258,14 @@ class CotAgentRunner(BaseAgentRunner, ABC): tool_call_name = action.action_name tool_call_args = action.action_input tool_instance = tool_instances.get(tool_call_name) - if not 
tool_instance: answer = f"there is not a tool named {tool_call_name}" return answer, ToolInvokeMeta.error_instance(answer) - if isinstance(tool_call_args, str): try: tool_call_args = json.loads(tool_call_args) except json.JSONDecodeError: pass - # invoke tool tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke( tool=tool_instance, @@ -308,7 +277,6 @@ class CotAgentRunner(BaseAgentRunner, ABC): agent_tool_callback=self.agent_callback, trace_manager=trace_manager, ) - # publish files for message_file_id in message_files: # publish message file @@ -317,7 +285,6 @@ class CotAgentRunner(BaseAgentRunner, ABC): ) # add message file ids message_file_ids.append(message_file_id) - return tool_invoke_response, tool_invoke_meta def _convert_dict_to_action(self, action: dict) -> AgentScratchpadUnit.Action: @@ -335,7 +302,6 @@ class CotAgentRunner(BaseAgentRunner, ABC): instruction = instruction.replace(f"{{{{{key}}}}}", str(value)) except Exception: continue - return instruction def _init_react_state(self, query) -> None: @@ -366,7 +332,6 @@ class CotAgentRunner(BaseAgentRunner, ABC): message += f"Action: {scratchpad.action_str}\n\n" if scratchpad.observation: message += f"Observation: {scratchpad.observation}\n\n" - return message def _organize_historic_prompt_messages( @@ -378,7 +343,6 @@ class CotAgentRunner(BaseAgentRunner, ABC): result: list[PromptMessage] = [] scratchpads: list[AgentScratchpadUnit] = [] current_scratchpad: AgentScratchpadUnit | None = None - for message in self.history_prompt_messages: if isinstance(message, AssistantPromptMessage): if not current_scratchpad: @@ -411,12 +375,9 @@ class CotAgentRunner(BaseAgentRunner, ABC): result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads))) scratchpads = [] current_scratchpad = None - result.append(message) - if scratchpads: result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads))) - historic_prompts = 
AgentHistoryPromptTransform( model_config=self.model_config, prompt_messages=current_session_messages or [], diff --git a/api/core/agent/cot_chat_agent_runner.py b/api/core/agent/cot_chat_agent_runner.py index 5ff89bdacb..57e669b65e 100644 --- a/api/core/agent/cot_chat_agent_runner.py +++ b/api/core/agent/cot_chat_agent_runner.py @@ -20,18 +20,15 @@ class CotChatAgentRunner(CotAgentRunner): """ assert self.app_config.agent assert self.app_config.agent.prompt - prompt_entity = self.app_config.agent.prompt if not prompt_entity: raise ValueError("Agent prompt configuration is not set") first_prompt = prompt_entity.first_prompt - system_prompt = ( first_prompt.replace("{{instruction}}", self._instruction) .replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools))) .replace("{{tool_names}}", ", ".join([tool.name for tool in self._prompt_messages_tools])) ) - return SystemPromptMessage(content=system_prompt) def _organize_user_query(self, query, prompt_messages: list[PromptMessage]) -> list[PromptMessage]: @@ -41,7 +38,6 @@ class CotChatAgentRunner(CotAgentRunner): if self.files: prompt_message_contents: list[PromptMessageContentUnionTypes] = [] prompt_message_contents.append(TextPromptMessageContent(data=query)) - # get image detail config image_detail_config = ( self.application_generate_entity.file_upload_config.image_config.detail @@ -59,11 +55,9 @@ class CotChatAgentRunner(CotAgentRunner): image_detail_config=image_detail_config, ) ) - prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) else: prompt_messages.append(UserPromptMessage(content=query)) - return prompt_messages def _organize_prompt_messages(self) -> list[PromptMessage]: @@ -72,7 +66,6 @@ class CotChatAgentRunner(CotAgentRunner): """ # organize system prompt system_message = self._organize_system_prompt() - # organize current assistant messages agent_scratchpad = self._agent_scratchpad if not agent_scratchpad: @@ -91,12 +84,9 @@ class 
CotChatAgentRunner(CotAgentRunner): assistant_message.content += f"Action: {unit.action_str}\n\n" if unit.observation: assistant_message.content += f"Observation: {unit.observation}\n\n" - assistant_messages = [assistant_message] - # query messages query_messages = self._organize_user_query(self._query, []) - if assistant_messages: # organize historic prompt messages historic_messages = self._organize_historic_prompt_messages( @@ -113,6 +103,5 @@ class CotChatAgentRunner(CotAgentRunner): # organize historic prompt messages historic_messages = self._organize_historic_prompt_messages([system_message, *query_messages]) messages = [system_message, *historic_messages, *query_messages] - # join all messages return messages diff --git a/api/core/agent/cot_completion_agent_runner.py b/api/core/agent/cot_completion_agent_runner.py index 3a4d31e047..5b89f97840 100644 --- a/api/core/agent/cot_completion_agent_runner.py +++ b/api/core/agent/cot_completion_agent_runner.py @@ -22,13 +22,11 @@ class CotCompletionAgentRunner(CotAgentRunner): if prompt_entity is None: raise ValueError("prompt entity is not set") first_prompt = prompt_entity.first_prompt - system_prompt = ( first_prompt.replace("{{instruction}}", self._instruction) .replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools))) .replace("{{tool_names}}", ", ".join([tool.name for tool in self._prompt_messages_tools])) ) - return system_prompt def _organize_historic_prompt(self, current_session_messages: Optional[list[PromptMessage]] = None) -> str: @@ -37,7 +35,6 @@ class CotCompletionAgentRunner(CotAgentRunner): """ historic_prompt_messages = self._organize_historic_prompt_messages(current_session_messages) historic_prompt = "" - for message in historic_prompt_messages: if isinstance(message, UserPromptMessage): historic_prompt += f"Question: {message.content}\n\n" @@ -49,7 +46,6 @@ class CotCompletionAgentRunner(CotAgentRunner): if not isinstance(content, TextPromptMessageContent): continue 
historic_prompt += content.data - return historic_prompt def _organize_prompt_messages(self) -> list[PromptMessage]: @@ -58,10 +54,8 @@ class CotCompletionAgentRunner(CotAgentRunner): """ # organize system prompt system_prompt = self._organize_instruction_prompt() - # organize historic prompt messages historic_prompt = self._organize_historic_prompt() - # organize current assistant messages agent_scratchpad = self._agent_scratchpad assistant_prompt = "" @@ -74,15 +68,12 @@ class CotCompletionAgentRunner(CotAgentRunner): assistant_prompt += f"Action: {unit.action_str}\n\n" if unit.observation: assistant_prompt += f"Observation: {unit.observation}\n\n" - # query messages query_prompt = f"Question: {self._query}" - # join all messages prompt = ( system_prompt.replace("{{historic_messages}}", historic_prompt) .replace("{{agent_scratchpad}}", assistant_prompt) .replace("{{query}}", query_prompt) ) - return [UserPromptMessage(content=prompt)] diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py index 5491689ece..dafed9e4a3 100644 --- a/api/core/agent/fc_agent_runner.py +++ b/api/core/agent/fc_agent_runner.py @@ -37,24 +37,18 @@ class FunctionCallAgentRunner(BaseAgentRunner): """ self.query = query app_generate_entity = self.application_generate_entity - app_config = self.app_config assert app_config is not None, "app_config is required" assert app_config.agent is not None, "app_config.agent is required" - # convert tools into ModelRuntime Tool format tool_instances, prompt_messages_tools = self._init_prompt_tools() - assert app_config.agent - iteration_step = 1 max_iteration_steps = min(app_config.agent.max_iteration, 99) + 1 - # continue to run until there is not any tool call function_call_state = True llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None} final_answer = "" - # get tracing instance trace_manager = app_generate_entity.trace_manager @@ -71,19 +65,15 @@ class FunctionCallAgentRunner(BaseAgentRunner): llm_usage.total_price 
+= usage.total_price model_instance = self.model_instance - while function_call_state and iteration_step <= max_iteration_steps: function_call_state = False - if iteration_step == max_iteration_steps: # the last iteration, remove all tools prompt_messages_tools = [] - message_file_ids: list[str] = [] agent_thought = self.create_agent_thought( message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids ) - # recalc llm max tokens prompt_messages = self._organize_prompt_messages() self.recalc_llm_max_tokens(self.model_config, prompt_messages) @@ -97,18 +87,13 @@ class FunctionCallAgentRunner(BaseAgentRunner): user=self.user_id, callbacks=[], ) - tool_calls: list[tuple[str, str, dict[str, Any]]] = [] - # save full response response = "" - # save tool call names and inputs tool_call_names = "" tool_call_inputs = "" - current_llm_usage = None - if isinstance(chunks, Generator): is_first_chunk = True for chunk in chunks: @@ -129,18 +114,15 @@ class FunctionCallAgentRunner(BaseAgentRunner): except json.JSONDecodeError: # ensure ascii to avoid encoding error tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls}) - if chunk.delta.message and chunk.delta.message.content: if isinstance(chunk.delta.message.content, list): for content in chunk.delta.message.content: response += content.data else: response += str(chunk.delta.message.content) - if chunk.delta.usage: increase_usage(llm_usage, chunk.delta.usage) current_llm_usage = chunk.delta.usage - yield chunk else: result = chunks @@ -156,25 +138,20 @@ class FunctionCallAgentRunner(BaseAgentRunner): except json.JSONDecodeError: # ensure ascii to avoid encoding error tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls}) - if result.usage: increase_usage(llm_usage, result.usage) current_llm_usage = result.usage - if result.message and result.message.content: if isinstance(result.message.content, list): for content in 
result.message.content: response += content.data else: response += str(result.message.content) - if not result.message.content: result.message.content = "" - self.queue_manager.publish( QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER ) - yield LLMResultChunk( model=model_instance.model, prompt_messages=result.prompt_messages, @@ -185,7 +162,6 @@ class FunctionCallAgentRunner(BaseAgentRunner): usage=result.usage, ), ) - assistant_message = AssistantPromptMessage(content="", tool_calls=[]) if tool_calls: assistant_message.tool_calls = [ @@ -200,9 +176,7 @@ class FunctionCallAgentRunner(BaseAgentRunner): ] else: assistant_message.content = response - self._current_thoughts.append(assistant_message) - # save thought self.save_agent_thought( agent_thought=agent_thought, @@ -218,9 +192,7 @@ class FunctionCallAgentRunner(BaseAgentRunner): self.queue_manager.publish( QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER ) - final_answer += response + "\n" - # call tools tool_responses = [] for tool_call_id, tool_call_name, tool_call_args in tool_calls: @@ -255,14 +227,12 @@ class FunctionCallAgentRunner(BaseAgentRunner): ) # add message file ids message_file_ids.append(message_file_id) - tool_response = { "tool_call_id": tool_call_id, "tool_call_name": tool_call_name, "tool_response": tool_invoke_response, "meta": tool_invoke_meta.to_dict(), } - tool_responses.append(tool_response) if tool_response["tool_response"] is not None: self._current_thoughts.append( @@ -272,7 +242,6 @@ class FunctionCallAgentRunner(BaseAgentRunner): name=tool_call_name, ) ) - if len(tool_responses) > 0: # save agent thought self.save_agent_thought( @@ -293,13 +262,10 @@ class FunctionCallAgentRunner(BaseAgentRunner): self.queue_manager.publish( QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER ) - # update prompt tool for prompt_tool in prompt_messages_tools: 
self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool) - iteration_step += 1 - # publish end event self.queue_manager.publish( QueueMessageEndEvent( @@ -333,7 +299,6 @@ class FunctionCallAgentRunner(BaseAgentRunner): def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> list[tuple[str, str, dict[str, Any]]]: """ Extract tool calls from llm result chunk - Returns: List[Tuple[str, str, Dict[str, Any]]]: [(tool_call_id, tool_call_name, tool_call_args)] """ @@ -342,7 +307,6 @@ class FunctionCallAgentRunner(BaseAgentRunner): args = {} if prompt_message.function.arguments != "": args = json.loads(prompt_message.function.arguments) - tool_calls.append( ( prompt_message.id, @@ -350,13 +314,11 @@ class FunctionCallAgentRunner(BaseAgentRunner): args, ) ) - return tool_calls def extract_blocking_tool_calls(self, llm_result: LLMResult) -> list[tuple[str, str, dict[str, Any]]]: """ Extract blocking tool calls from llm result - Returns: List[Tuple[str, str, Dict[str, Any]]]: [(tool_call_id, tool_call_name, tool_call_args)] """ @@ -365,7 +327,6 @@ class FunctionCallAgentRunner(BaseAgentRunner): args = {} if prompt_message.function.arguments != "": args = json.loads(prompt_message.function.arguments) - tool_calls.append( ( prompt_message.id, @@ -373,7 +334,6 @@ class FunctionCallAgentRunner(BaseAgentRunner): args, ) ) - return tool_calls def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]: @@ -384,10 +344,8 @@ class FunctionCallAgentRunner(BaseAgentRunner): return [ SystemPromptMessage(content=prompt_template), ] - if prompt_messages and not isinstance(prompt_messages[0], SystemPromptMessage) and prompt_template: prompt_messages.insert(0, SystemPromptMessage(content=prompt_template)) - return prompt_messages or [] def _organize_user_query(self, query: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]: @@ -397,7 +355,6 @@ class FunctionCallAgentRunner(BaseAgentRunner): 
if self.files: prompt_message_contents: list[PromptMessageContentUnionTypes] = [] prompt_message_contents.append(TextPromptMessageContent(data=query)) - # get image detail config image_detail_config = ( self.application_generate_entity.file_upload_config.image_config.detail @@ -415,11 +372,9 @@ class FunctionCallAgentRunner(BaseAgentRunner): image_detail_config=image_detail_config, ) ) - prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) else: prompt_messages.append(UserPromptMessage(content=query)) - return prompt_messages def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]: @@ -428,7 +383,6 @@ class FunctionCallAgentRunner(BaseAgentRunner): We need to remove the image messages from the prompt messages at the first iteration. """ prompt_messages = deepcopy(prompt_messages) - for prompt_message in prompt_messages: if isinstance(prompt_message, UserPromptMessage): if isinstance(prompt_message.content, list): @@ -442,21 +396,18 @@ class FunctionCallAgentRunner(BaseAgentRunner): for content in prompt_message.content ] ) - return prompt_messages def _organize_prompt_messages(self): prompt_template = self.app_config.prompt_template.simple_prompt_template or "" self.history_prompt_messages = self._init_system_message(prompt_template, self.history_prompt_messages) query_prompt_messages = self._organize_user_query(self.query or "", []) - self.history_prompt_messages = AgentHistoryPromptTransform( model_config=self.model_config, prompt_messages=[*query_prompt_messages, *self._current_thoughts], history_messages=self.history_prompt_messages, memory=self.memory, ).get_prompt() - prompt_messages = [*self.history_prompt_messages, *query_prompt_messages, *self._current_thoughts] if len(self._current_thoughts) != 0: # clear messages after the first iteration diff --git a/api/core/agent/output_parser/cot_output_parser.py b/api/core/agent/output_parser/cot_output_parser.py index 7c8f09e6b9..ba3ff7a5ab 
100644 --- a/api/core/agent/output_parser/cot_output_parser.py +++ b/api/core/agent/output_parser/cot_output_parser.py @@ -20,17 +20,14 @@ class CotAgentOutputParser: action = json.loads(action, strict=False) except json.JSONDecodeError: return action or "" - # cohere always returns a list if isinstance(action, list) and len(action) == 1: action = action[0] - for key, value in action.items(): if "input" in key.lower(): action_input = value else: action_name = value - if action_name is not None and action_input is not None: return AgentScratchpadUnit.Action( action_name=action_name, @@ -59,31 +56,25 @@ class CotAgentOutputParser: json_quote_count = 0 in_json = False got_json = False - action_cache = "" action_str = "action:" action_idx = 0 - thought_cache = "" thought_str = "thought:" thought_idx = 0 - last_character = "" - for response in llm_response: if response.delta.usage: usage_dict["usage"] = response.delta.usage response_content = response.delta.message.content if not isinstance(response_content, str): continue - # stream index = 0 while index < len(response_content): steps = 1 delta = response_content[index : index + steps] yield_delta = False - if not in_json and delta == "`": last_character = delta code_block_cache += delta @@ -98,7 +89,6 @@ class CotAgentOutputParser: last_character = delta code_block_cache += delta code_block_delimiter_count = 0 - if not in_code_block and not in_json: if delta.lower() == action_str[action_idx] and action_idx == 0: if last_character not in {"\n", " ", ""}: @@ -127,7 +117,6 @@ class CotAgentOutputParser: yield action_cache action_cache = "" action_idx = 0 - if delta.lower() == thought_str[thought_idx] and thought_idx == 0: if last_character not in {"\n", " ", ""}: yield_delta = True @@ -155,13 +144,11 @@ class CotAgentOutputParser: yield thought_cache thought_cache = "" thought_idx = 0 - if yield_delta: index += steps last_character = delta yield delta continue - if code_block_delimiter_count == 3: if in_code_block: 
last_character = delta @@ -173,10 +160,8 @@ class CotAgentOutputParser: else: index += steps continue - in_code_block = not in_code_block code_block_delimiter_count = 0 - if not in_code_block: # handle single json if delta == "{": @@ -198,7 +183,6 @@ class CotAgentOutputParser: if in_json: last_character = delta json_cache += delta - if got_json: got_json = False last_character = delta @@ -206,15 +190,11 @@ class CotAgentOutputParser: json_cache = "" json_quote_count = 0 in_json = False - if not in_code_block and not in_json: last_character = delta yield delta.replace("`", "") - index += steps - if code_block_cache: yield code_block_cache - if json_cache: yield parse_action(json_cache) diff --git a/api/core/agent/plugin_entities.py b/api/core/agent/plugin_entities.py index 9c722baa23..e5ea488620 100644 --- a/api/core/agent/plugin_entities.py +++ b/api/core/agent/plugin_entities.py @@ -41,7 +41,6 @@ class AgentStrategyParameter(PluginParameter): APP_SELECTOR = CommonParameterType.APP_SELECTOR.value MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR.value - # deprecated, should not use. SYSTEM_FILES = CommonParameterType.SYSTEM_FILES.value @@ -85,7 +84,6 @@ class AgentStrategyEntity(BaseModel): description: I18nObject = Field(..., description="The description of the agent strategy") output_schema: Optional[dict] = None features: Optional[list[AgentFeature]] = None - # pydantic configs model_config = ConfigDict(protected_namespaces=()) diff --git a/api/core/agent/prompt/template.py b/api/core/agent/prompt/template.py index f5ba2119f4..dc4b50d4e3 100644 --- a/api/core/agent/prompt/template.py +++ b/api/core/agent/prompt/template.py @@ -1,25 +1,17 @@ ENGLISH_REACT_COMPLETION_PROMPT_TEMPLATES = """Respond to the human as helpfully and accurately as possible. 
- {{instruction}} - You have access to the following tools: - {{tools}} - Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input). Valid "action" values: "Final Answer" or {{tool_names}} - Provide only ONE action per $JSON_BLOB, as shown: - ``` { "action": $TOOL_NAME, "action_input": $ACTION_INPUT } ``` - Follow this format: - Question: input question to answer Thought: consider previous and subsequent steps Action: @@ -36,39 +28,27 @@ Action: "action_input": "Final response to human" } ``` - Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:. {{historic_messages}} Question: {{query}} {{agent_scratchpad}} Thought:""" # noqa: E501 - - ENGLISH_REACT_COMPLETION_AGENT_SCRATCHPAD_TEMPLATES = """Observation: {{observation}} Thought:""" - ENGLISH_REACT_CHAT_PROMPT_TEMPLATES = """Respond to the human as helpfully and accurately as possible. - {{instruction}} - You have access to the following tools: - {{tools}} - Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input). Valid "action" values: "Final Answer" or {{tool_names}} - Provide only ONE action per $JSON_BLOB, as shown: - ``` { "action": $TOOL_NAME, "action_input": $ACTION_INPUT } ``` - Follow this format: - Question: input question to answer Thought: consider previous and subsequent steps Action: @@ -85,13 +65,9 @@ Action: "action_input": "Final response to human" } ``` - Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:. 
""" # noqa: E501 - - ENGLISH_REACT_CHAT_AGENT_SCRATCHPAD_TEMPLATES = "" - REACT_PROMPT_TEMPLATES = { "english": { "chat": { diff --git a/api/core/agent/strategy/plugin.py b/api/core/agent/strategy/plugin.py index 79b074cf95..fbe942dcf5 100644 --- a/api/core/agent/strategy/plugin.py +++ b/api/core/agent/strategy/plugin.py @@ -43,10 +43,8 @@ class PluginAgentStrategy(BaseAgentStrategy): Invoke the agent strategy. """ manager = PluginAgentClient() - initialized_params = self.initialize_parameters(params) params = convert_parameters_to_plugin_format(initialized_params) - yield from manager.invoke( tenant_id=self.tenant_id, user_id=user_id, diff --git a/api/core/file/file_manager.py b/api/core/file/file_manager.py index ada19ef8ce..5b7224bad4 100644 --- a/api/core/file/file_manager.py +++ b/api/core/file/file_manager.py @@ -48,7 +48,6 @@ def to_prompt_message_content( raise ValueError("Missing file extension") if f.mime_type is None: raise ValueError("Missing file mime_type") - params = { "base64_data": _get_encoded_string(f) if dify_config.MULTIMODAL_SEND_FORMAT == "base64" else "", "url": _to_url(f) if dify_config.MULTIMODAL_SEND_FORMAT == "url" else "", @@ -57,14 +56,12 @@ def to_prompt_message_content( } if f.type == FileType.IMAGE: params["detail"] = image_detail_config or ImagePromptMessageContent.DETAIL.LOW - prompt_class_map: Mapping[FileType, type[PromptMessageContentUnionTypes]] = { FileType.IMAGE: ImagePromptMessageContent, FileType.AUDIO: AudioPromptMessageContent, FileType.VIDEO: VideoPromptMessageContent, FileType.DOCUMENT: DocumentPromptMessageContent, } - try: return prompt_class_map[f.type].model_validate(params) except KeyError: @@ -84,15 +81,11 @@ def download(f: File, /): def _download_file_content(path: str, /): """ Download and return the contents of a file as bytes. - This function loads the file from storage and ensures it's in bytes format. - Args: path (str): The path to the file in storage. 
- Returns: bytes: The contents of the file as a bytes object. - Raises: ValueError: If the loaded file is not a bytes object. """ @@ -112,7 +105,6 @@ def _get_encoded_string(f: File, /): data = _download_file_content(f._storage_key) case FileTransferMethod.TOOL_FILE: data = _download_file_content(f._storage_key) - encoded_string = base64.b64encode(data).decode("utf-8") return encoded_string diff --git a/api/core/helper/code_executor/code_executor.py b/api/core/helper/code_executor/code_executor.py index 2b580cb373..880fbb41b3 100644 --- a/api/core/helper/code_executor/code_executor.py +++ b/api/core/helper/code_executor/code_executor.py @@ -41,19 +41,16 @@ class CodeLanguage(StrEnum): class CodeExecutor: dependencies_cache: dict[str, str] = {} dependencies_cache_lock = Lock() - code_template_transformers: dict[CodeLanguage, type[TemplateTransformer]] = { CodeLanguage.PYTHON3: Python3TemplateTransformer, CodeLanguage.JINJA2: Jinja2TemplateTransformer, CodeLanguage.JAVASCRIPT: NodeJsTemplateTransformer, } - code_language_to_running_language = { CodeLanguage.JAVASCRIPT: "nodejs", CodeLanguage.JINJA2: CodeLanguage.PYTHON3, CodeLanguage.PYTHON3: CodeLanguage.PYTHON3, } - supported_dependencies_languages: set[CodeLanguage] = {CodeLanguage.PYTHON3} @classmethod @@ -66,16 +63,13 @@ class CodeExecutor: :return: """ url = code_execution_endpoint_url / "v1" / "sandbox" / "run" - headers = {"X-Api-Key": dify_config.CODE_EXECUTION_API_KEY} - data = { "language": cls.code_language_to_running_language.get(language), "code": code, "preload": preload, "enable_network": True, } - try: response = post( str(url), @@ -103,20 +97,15 @@ class CodeExecutor: " please check if the sandbox service is running." f" ( Error: {str(e)} )" ) - try: response_data = response.json() except: raise CodeExecutionError("Failed to parse response") - if (code := response_data.get("code")) != 0: raise CodeExecutionError(f"Got error code: {code}. 
Got error msg: {response_data.get('message')}") - response_code = CodeExecutionResponse(**response_data) - if response_code.data.error: raise CodeExecutionError(response_code.data.error) - return response_code.data.stdout or "" @classmethod @@ -131,12 +120,9 @@ class CodeExecutor: template_transformer = cls.code_template_transformers.get(language) if not template_transformer: raise CodeExecutionError(f"Unsupported language {language}") - runner, preload = template_transformer.transform_caller(code, inputs) - try: response = cls.execute_code(language, preload, runner) except CodeExecutionError as e: raise e - return template_transformer.transform_response(response) diff --git a/api/core/helper/code_executor/javascript/javascript_transformer.py b/api/core/helper/code_executor/javascript/javascript_transformer.py index 62489cdf29..6d4b551ae9 100644 --- a/api/core/helper/code_executor/javascript/javascript_transformer.py +++ b/api/core/helper/code_executor/javascript/javascript_transformer.py @@ -10,13 +10,10 @@ class NodeJsTemplateTransformer(TemplateTransformer): f""" // declare main function {cls._code_placeholder} - // decode and prepare input object var inputs_obj = JSON.parse(Buffer.from('{cls._inputs_placeholder}', 'base64').toString('utf-8')) - // execute main function var output_obj = main(inputs_obj) - // convert output to json and print var output_json = JSON.stringify(output_obj) var result = `<>${{output_json}}<>` diff --git a/api/core/helper/code_executor/jinja2/jinja2_transformer.py b/api/core/helper/code_executor/jinja2/jinja2_transformer.py index 54c78cdf92..ac0187708d 100644 --- a/api/core/helper/code_executor/jinja2/jinja2_transformer.py +++ b/api/core/helper/code_executor/jinja2/jinja2_transformer.py @@ -21,20 +21,15 @@ class Jinja2TemplateTransformer(TemplateTransformer): import jinja2 template = jinja2.Template('''{cls._code_placeholder}''') return template.render(**inputs) - import json from base64 import b64decode - # decode and prepare input 
dict inputs_obj = json.loads(b64decode('{cls._inputs_placeholder}').decode('utf-8')) - # execute main function output = main(**inputs_obj) - # convert output and print result = f'''<>{{output}}<>''' print(result) - """) return runner_script @@ -43,15 +38,11 @@ class Jinja2TemplateTransformer(TemplateTransformer): preload_script = dedent(""" import jinja2 from base64 import b64decode - def _jinja2_preload_(): # prepare jinja2 environment, load template and render before to avoid sandbox issue template = jinja2.Template('{{s}}') template.render(s='a') - if __name__ == '__main__': _jinja2_preload_() - """) - return preload_script diff --git a/api/core/helper/code_executor/python3/python3_transformer.py b/api/core/helper/code_executor/python3/python3_transformer.py index 836fd273ae..89613b2ce9 100644 --- a/api/core/helper/code_executor/python3/python3_transformer.py +++ b/api/core/helper/code_executor/python3/python3_transformer.py @@ -9,16 +9,12 @@ class Python3TemplateTransformer(TemplateTransformer): runner_script = dedent(f""" # declare main function {cls._code_placeholder} - import json from base64 import b64decode - # decode and prepare input dict inputs_obj = json.loads(b64decode('{cls._inputs_placeholder}').decode('utf-8')) - # execute main function output_obj = main(**inputs_obj) - # convert output to json and print output_json = json.dumps(output_obj, indent=4) result = f'''<>{{output_json}}<>''' diff --git a/api/core/helper/code_executor/template_transformer.py b/api/core/helper/code_executor/template_transformer.py index baa792b5bc..8a594bd84f 100644 --- a/api/core/helper/code_executor/template_transformer.py +++ b/api/core/helper/code_executor/template_transformer.py @@ -21,7 +21,6 @@ class TemplateTransformer(ABC): """ runner_script = cls.assemble_runner_script(code, inputs) preload_script = cls.get_preload_script() - return runner_script, preload_script @classmethod diff --git a/api/core/helper/download.py b/api/core/helper/download.py index 
96400e8ba5..d20a9860d1 100644 --- a/api/core/helper/download.py +++ b/api/core/helper/download.py @@ -5,7 +5,6 @@ def download_with_size_limit(url, max_download_size: int, **kwargs): response = ssrf_proxy.get(url, follow_redirects=True, **kwargs) if response.status_code == 404: raise ValueError("file not found") - total_size = 0 chunks = [] for chunk in response.iter_bytes(): diff --git a/api/core/helper/encrypter.py b/api/core/helper/encrypter.py index 744fce1cf9..6eea1f6613 100644 --- a/api/core/helper/encrypter.py +++ b/api/core/helper/encrypter.py @@ -27,7 +27,6 @@ def decrypt_token(tenant_id: str, token: str): def batch_decrypt_token(tenant_id: str, tokens: list[str]): rsa_key, cipher_rsa = rsa.get_decrypt_decoding(tenant_id) - return [rsa.decrypt_token_with_decoding(base64.b64decode(token), rsa_key, cipher_rsa) for token in tokens] diff --git a/api/core/helper/marketplace.py b/api/core/helper/marketplace.py index 65bf4fc1db..58804cbb20 100644 --- a/api/core/helper/marketplace.py +++ b/api/core/helper/marketplace.py @@ -21,7 +21,6 @@ def download_plugin_pkg(plugin_unique_identifier: str): def batch_fetch_plugin_manifests(plugin_ids: list[str]) -> Sequence[MarketplacePluginDeclaration]: if len(plugin_ids) == 0: return [] - url = str(marketplace_api_url / "api/v1/plugins/batch") response = requests.post(url, json={"plugin_ids": plugin_ids}) response.raise_for_status() diff --git a/api/core/helper/model_provider_cache.py b/api/core/helper/model_provider_cache.py index 35349210bd..38c181cf52 100644 --- a/api/core/helper/model_provider_cache.py +++ b/api/core/helper/model_provider_cache.py @@ -19,7 +19,6 @@ class ProviderCredentialsCache: def get(self) -> Optional[dict]: """ Get cached model provider credentials. 
- :return: """ cached_provider_credentials = redis_client.get(self.cache_key) @@ -29,7 +28,6 @@ class ProviderCredentialsCache: cached_provider_credentials = json.loads(cached_provider_credentials) except JSONDecodeError: return None - return dict(cached_provider_credentials) else: return None @@ -37,7 +35,6 @@ class ProviderCredentialsCache: def set(self, credentials: dict) -> None: """ Cache model provider credentials. - :param credentials: provider credentials :return: """ @@ -46,7 +43,6 @@ class ProviderCredentialsCache: def delete(self) -> None: """ Delete cached model provider credentials. - :return: """ redis_client.delete(self.cache_key) diff --git a/api/core/helper/moderation.py b/api/core/helper/moderation.py index a324ac2767..dee09a1782 100644 --- a/api/core/helper/moderation.py +++ b/api/core/helper/moderation.py @@ -27,22 +27,16 @@ def check_moderation(tenant_id: str, model_config: ModelConfigWithCredentialsEnt provider_name = model_config.provider if using_provider_type == ProviderType.SYSTEM and provider_name in moderation_config.providers: hosting_openai_config = hosting_configuration.provider_map[openai_provider_name] - if hosting_openai_config.credentials is None: return False - # 2000 text per chunk length = 2000 text_chunks = [text[i : i + length] for i in range(0, len(text), length)] - if len(text_chunks) == 0: return True - text_chunk = secrets.choice(text_chunks) - try: model_provider_factory = ModelProviderFactory(tenant_id) - # Get model instance of LLM model_type_instance = model_provider_factory.get_model_type_instance( provider=openai_provider_name, model_type=ModelType.MODERATION @@ -51,11 +45,9 @@ def check_moderation(tenant_id: str, model_config: ModelConfigWithCredentialsEnt moderation_result = model_type_instance.invoke( model="omni-moderation-latest", credentials=hosting_openai_config.credentials, text=text_chunk ) - if moderation_result is True: return True except Exception: logger.exception(f"Fails to check moderation, 
provider_name: {provider_name}") raise InvokeBadRequestError("Rate limit exceeded, please try again later.") - return False diff --git a/api/core/helper/position_helper.py b/api/core/helper/position_helper.py index 8def6fe4ed..5bcb9ffde7 100644 --- a/api/core/helper/position_helper.py +++ b/api/core/helper/position_helper.py @@ -28,7 +28,6 @@ def get_tool_position_map(folder_path: str, file_name: str = "_position.yaml") - :return: a dict with name as key and index as value """ position_map = get_position_map(folder_path, file_name=file_name) - return pin_position_map( position_map, pin_list=dify_config.POSITION_TOOL_PINS_LIST, @@ -58,17 +57,14 @@ def pin_position_map(original_position_map: dict[str, int], pin_list: list[str]) :return: the sorted position map """ positions = sorted(original_position_map.keys(), key=lambda x: original_position_map[x]) - # Add pins to position map position_map = {name: idx for idx, name in enumerate(pin_list)} - # Add remaining positions to position map start_idx = len(position_map) for name in positions: if name not in position_map: position_map[name] = start_idx start_idx += 1 - return position_map @@ -91,9 +87,7 @@ def is_filtered( return False if not include_set and not exclude_set: return False - name = name_func(data) - if name in exclude_set: # exclude_set is prioritized return True if include_set and name not in include_set: # filter out only if include_set is not empty @@ -116,7 +110,6 @@ def sort_by_position_map( """ if not position_map or not data: return data - return sorted(data, key=lambda x: position_map.get(name_func(x), float("inf"))) diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py index 11f245812e..1cd0b625e3 100644 --- a/api/core/helper/ssrf_proxy.py +++ b/api/core/helper/ssrf_proxy.py @@ -10,7 +10,6 @@ import httpx from configs import dify_config SSRF_DEFAULT_MAX_RETRIES = dify_config.SSRF_DEFAULT_MAX_RETRIES - HTTP_REQUEST_NODE_SSL_VERIFY = True # Default value for 
HTTP_REQUEST_NODE_SSL_VERIFY is True try: HTTP_REQUEST_NODE_SSL_VERIFY = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY @@ -23,7 +22,6 @@ try: raise ValueError("Invalid value. HTTP_REQUEST_NODE_SSL_VERIFY should be 'True' or 'False'") except NameError: HTTP_REQUEST_NODE_SSL_VERIFY = True - BACKOFF_FACTOR = 0.5 STATUS_FORCELIST = [429, 500, 502, 503, 504] @@ -39,7 +37,6 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): allow_redirects = kwargs.pop("allow_redirects") if "follow_redirects" not in kwargs: kwargs["follow_redirects"] = allow_redirects - if "timeout" not in kwargs: kwargs["timeout"] = httpx.Timeout( timeout=dify_config.SSRF_DEFAULT_TIME_OUT, @@ -47,12 +44,9 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): read=dify_config.SSRF_DEFAULT_READ_TIME_OUT, write=dify_config.SSRF_DEFAULT_WRITE_TIME_OUT, ) - if "ssl_verify" not in kwargs: kwargs["ssl_verify"] = HTTP_REQUEST_NODE_SSL_VERIFY - ssl_verify = kwargs.pop("ssl_verify") - retries = 0 while retries <= max_retries: try: @@ -69,17 +63,14 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): else: with httpx.Client(verify=ssl_verify) as client: response = client.request(method=method, url=url, **kwargs) - if response.status_code not in STATUS_FORCELIST: return response else: logging.warning(f"Received status code {response.status_code} for URL {url} which is in the force list") - except httpx.RequestError as e: logging.warning(f"Request to URL {url} failed on attempt {retries + 1}: {e}") if max_retries == 0: raise - retries += 1 if retries <= max_retries: time.sleep(BACKOFF_FACTOR * (2 ** (retries - 1))) diff --git a/api/core/helper/tool_parameter_cache.py b/api/core/helper/tool_parameter_cache.py index 918b3e9eee..90a2ccf508 100644 --- a/api/core/helper/tool_parameter_cache.py +++ b/api/core/helper/tool_parameter_cache.py @@ -22,7 +22,6 @@ class ToolParameterCache: def get(self) -> Optional[dict]: """ Get cached model 
provider credentials. - :return: """ cached_tool_parameter = redis_client.get(self.cache_key) @@ -32,7 +31,6 @@ class ToolParameterCache: cached_tool_parameter = json.loads(cached_tool_parameter) except JSONDecodeError: return None - return dict(cached_tool_parameter) else: return None @@ -44,7 +42,6 @@ class ToolParameterCache: def delete(self) -> None: """ Delete cached model provider credentials. - :return: """ redis_client.delete(self.cache_key) diff --git a/api/core/helper/tool_provider_cache.py b/api/core/helper/tool_provider_cache.py index 2e4a04c579..1281c4ab6c 100644 --- a/api/core/helper/tool_provider_cache.py +++ b/api/core/helper/tool_provider_cache.py @@ -18,7 +18,6 @@ class ToolProviderCredentialsCache: def get(self) -> Optional[dict]: """ Get cached model provider credentials. - :return: """ cached_provider_credentials = redis_client.get(self.cache_key) @@ -28,7 +27,6 @@ class ToolProviderCredentialsCache: cached_provider_credentials = json.loads(cached_provider_credentials) except JSONDecodeError: return None - return dict(cached_provider_credentials) else: return None @@ -36,7 +34,6 @@ class ToolProviderCredentialsCache: def set(self, credentials: dict) -> None: """ Cache model provider credentials. - :param credentials: provider credentials :return: """ @@ -45,7 +42,6 @@ class ToolProviderCredentialsCache: def delete(self) -> None: """ Delete cached model provider credentials. 
- :return: """ redis_client.delete(self.cache_key) diff --git a/api/core/helper/url_signer.py b/api/core/helper/url_signer.py index dfb143f4c4..db9a4d0ba0 100644 --- a/api/core/helper/url_signer.py +++ b/api/core/helper/url_signer.py @@ -30,23 +30,19 @@ class UrlSigner: timestamp = str(int(time.time())) nonce = os.urandom(16).hex() sign = cls._sign(sign_key, timestamp, nonce, prefix) - return SignedUrlParams(sign_key=sign_key, timestamp=timestamp, nonce=nonce, sign=sign) @classmethod def verify(cls, sign_key: str, timestamp: str, nonce: str, sign: str, prefix: str) -> bool: recalculated_sign = cls._sign(sign_key, timestamp, nonce, prefix) - return sign == recalculated_sign @classmethod def _sign(cls, sign_key: str, timestamp: str, nonce: str, prefix: str) -> str: if not dify_config.SECRET_KEY: raise Exception("SECRET_KEY is not set") - data_to_sign = f"{prefix}|{sign_key}|{timestamp}|{nonce}" secret_key = dify_config.SECRET_KEY.encode() sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() encoded_sign = base64.urlsafe_b64encode(sign).decode() - return encoded_sign diff --git a/api/core/hosting_configuration.py b/api/core/hosting_configuration.py index 20d98562de..f75fb2197f 100644 --- a/api/core/hosting_configuration.py +++ b/api/core/hosting_configuration.py @@ -51,14 +51,12 @@ class HostingConfiguration: def init_app(self, app: Flask) -> None: if dify_config.EDITION != "CLOUD": return - self.provider_map[f"{DEFAULT_PLUGIN_ID}/azure_openai/azure_openai"] = self.init_azure_openai() self.provider_map[f"{DEFAULT_PLUGIN_ID}/openai/openai"] = self.init_openai() self.provider_map[f"{DEFAULT_PLUGIN_ID}/anthropic/anthropic"] = self.init_anthropic() self.provider_map[f"{DEFAULT_PLUGIN_ID}/minimax/minimax"] = self.init_minimax() self.provider_map[f"{DEFAULT_PLUGIN_ID}/spark/spark"] = self.init_spark() self.provider_map[f"{DEFAULT_PLUGIN_ID}/zhipuai/zhipuai"] = self.init_zhipuai() - self.moderation_config = self.init_moderation_config() @staticmethod 
@@ -70,7 +68,6 @@ class HostingConfiguration: "openai_api_base": dify_config.HOSTED_AZURE_OPENAI_API_BASE, "base_model_name": "gpt-35-turbo", } - quotas: list[HostingQuota] = [] hosted_quota_limit = dify_config.HOSTED_AZURE_OPENAI_QUOTA_LIMIT trial_quota = TrialHostingQuota( @@ -117,9 +114,7 @@ class HostingConfiguration: ], ) quotas.append(trial_quota) - return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas) - return HostingProvider( enabled=False, quota_unit=quota_unit, @@ -128,31 +123,24 @@ class HostingConfiguration: def init_openai(self) -> HostingProvider: quota_unit = QuotaUnit.CREDITS quotas: list[HostingQuota] = [] - if dify_config.HOSTED_OPENAI_TRIAL_ENABLED: hosted_quota_limit = dify_config.HOSTED_OPENAI_QUOTA_LIMIT trial_models = self.parse_restrict_models_from_env("HOSTED_OPENAI_TRIAL_MODELS") trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trial_models) quotas.append(trial_quota) - if dify_config.HOSTED_OPENAI_PAID_ENABLED: paid_models = self.parse_restrict_models_from_env("HOSTED_OPENAI_PAID_MODELS") paid_quota = PaidHostingQuota(restrict_models=paid_models) quotas.append(paid_quota) - if len(quotas) > 0: credentials = { "openai_api_key": dify_config.HOSTED_OPENAI_API_KEY, } - if dify_config.HOSTED_OPENAI_API_BASE: credentials["openai_api_base"] = dify_config.HOSTED_OPENAI_API_BASE - if dify_config.HOSTED_OPENAI_API_ORGANIZATION: credentials["openai_organization"] = dify_config.HOSTED_OPENAI_API_ORGANIZATION - return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas) - return HostingProvider( enabled=False, quota_unit=quota_unit, @@ -162,26 +150,20 @@ class HostingConfiguration: def init_anthropic() -> HostingProvider: quota_unit = QuotaUnit.TOKENS quotas: list[HostingQuota] = [] - if dify_config.HOSTED_ANTHROPIC_TRIAL_ENABLED: hosted_quota_limit = dify_config.HOSTED_ANTHROPIC_QUOTA_LIMIT trial_quota = 
TrialHostingQuota(quota_limit=hosted_quota_limit) quotas.append(trial_quota) - if dify_config.HOSTED_ANTHROPIC_PAID_ENABLED: paid_quota = PaidHostingQuota() quotas.append(paid_quota) - if len(quotas) > 0: credentials = { "anthropic_api_key": dify_config.HOSTED_ANTHROPIC_API_KEY, } - if dify_config.HOSTED_ANTHROPIC_API_BASE: credentials["anthropic_api_url"] = dify_config.HOSTED_ANTHROPIC_API_BASE - return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas) - return HostingProvider( enabled=False, quota_unit=quota_unit, @@ -192,14 +174,12 @@ class HostingConfiguration: quota_unit = QuotaUnit.TOKENS if dify_config.HOSTED_MINIMAX_ENABLED: quotas: list[HostingQuota] = [FreeHostingQuota()] - return HostingProvider( enabled=True, credentials=None, # use credentials from the provider quota_unit=quota_unit, quotas=quotas, ) - return HostingProvider( enabled=False, quota_unit=quota_unit, @@ -210,14 +190,12 @@ class HostingConfiguration: quota_unit = QuotaUnit.TOKENS if dify_config.HOSTED_SPARK_ENABLED: quotas: list[HostingQuota] = [FreeHostingQuota()] - return HostingProvider( enabled=True, credentials=None, # use credentials from the provider quota_unit=quota_unit, quotas=quotas, ) - return HostingProvider( enabled=False, quota_unit=quota_unit, @@ -228,14 +206,12 @@ class HostingConfiguration: quota_unit = QuotaUnit.TOKENS if dify_config.HOSTED_ZHIPUAI_ENABLED: quotas: list[HostingQuota] = [FreeHostingQuota()] - return HostingProvider( enabled=True, credentials=None, # use credentials from the provider quota_unit=quota_unit, quotas=quotas, ) - return HostingProvider( enabled=False, quota_unit=quota_unit, @@ -250,9 +226,7 @@ class HostingConfiguration: if "/" not in provider: provider = f"{DEFAULT_PLUGIN_ID}/{provider}/{provider}" hosted_providers.append(provider) - return HostedModerationConfig(enabled=True, providers=hosted_providers) - return HostedModerationConfig(enabled=False) @staticmethod diff --git 
a/api/core/indexing_runner.py b/api/core/indexing_runner.py index f2fe306179..76024a964f 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -52,10 +52,8 @@ class IndexingRunner: try: # get dataset dataset = db.session.query(Dataset).filter_by(id=dataset_document.dataset_id).first() - if not dataset: raise ValueError("no dataset found") - # get the process rule processing_rule = ( db.session.query(DatasetProcessRule) @@ -68,14 +66,12 @@ class IndexingRunner: index_processor = IndexProcessorFactory(index_type).init_index_processor() # extract text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict()) - # transform documents = self._transform( index_processor, dataset, text_docs, dataset_document.doc_language, processing_rule.to_dict() ) # save segment self._load_segments(dataset, dataset_document, documents) - # load self._load( index_processor=index_processor, @@ -104,17 +100,14 @@ class IndexingRunner: try: # get dataset dataset = db.session.query(Dataset).filter_by(id=dataset_document.dataset_id).first() - if not dataset: raise ValueError("no dataset found") - # get exist document_segment list and delete document_segments = ( db.session.query(DocumentSegment) .filter_by(dataset_id=dataset.id, document_id=dataset_document.id) .all() ) - for document_segment in document_segments: db.session.delete(document_segment) if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: @@ -129,19 +122,16 @@ class IndexingRunner: ) if not processing_rule: raise ValueError("no process rule found") - index_type = dataset_document.doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() # extract text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict()) - # transform documents = self._transform( index_processor, dataset, text_docs, dataset_document.doc_language, processing_rule.to_dict() ) # save segment self._load_segments(dataset, dataset_document, documents) - # 
load self._load( index_processor=index_processor, dataset=dataset, dataset_document=dataset_document, documents=documents @@ -165,17 +155,14 @@ class IndexingRunner: try: # get dataset dataset = db.session.query(Dataset).filter_by(id=dataset_document.dataset_id).first() - if not dataset: raise ValueError("no dataset found") - # get exist document_segment list and delete document_segments = ( db.session.query(DocumentSegment) .filter_by(dataset_id=dataset.id, document_id=dataset_document.id) .all() ) - documents = [] if document_segments: for document_segment in document_segments: @@ -207,7 +194,6 @@ class IndexingRunner: child_documents.append(child_document) document.children = child_documents documents.append(document) - # build index # get the process rule processing_rule = ( @@ -215,7 +201,6 @@ class IndexingRunner: .filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id) .first() ) - index_type = dataset_document.doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() self._load( @@ -255,7 +240,6 @@ class IndexingRunner: batch_upload_limit = dify_config.BATCH_UPLOAD_LIMIT if count > batch_upload_limit: raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.") - embedding_model_instance = None if dataset_id: dataset = db.session.query(Dataset).filter_by(id=dataset_id).first() @@ -281,7 +265,6 @@ class IndexingRunner: model_type=ModelType.TEXT_EMBEDDING, ) preview_texts = [] # type: ignore - total_segments = 0 index_type = doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() @@ -312,7 +295,6 @@ class IndexingRunner: if document.children: preview_detail.child_chunks = [child.page_content for child in document.children] # type: ignore preview_texts.append(preview_detail) - # delete image files and related db records image_upload_file_ids = get_image_upload_file_ids(document.page_content) for upload_file_id in image_upload_file_ids: @@ -326,7 +308,6 @@ class 
IndexingRunner: image_upload_file_is: {}".format(upload_file_id) ) db.session.delete(image_file) - if doc_form and doc_form == "qa_model": return IndexingEstimate(total_segments=total_segments * 20, qa_preview=preview_texts, preview=[]) return IndexingEstimate(total_segments=total_segments, preview=preview_texts) # type: ignore @@ -337,17 +318,14 @@ class IndexingRunner: # load file if dataset_document.data_source_type not in {"upload_file", "notion_import", "website_crawl"}: return [] - data_source_info = dataset_document.data_source_info_dict text_docs = [] if dataset_document.data_source_type == "upload_file": if not data_source_info or "upload_file_id" not in data_source_info: raise ValueError("no upload file found") - file_detail = ( db.session.query(UploadFile).filter(UploadFile.id == data_source_info["upload_file_id"]).one_or_none() ) - if file_detail: extract_setting = ExtractSetting( datasource_type="upload_file", upload_file=file_detail, document_model=dataset_document.doc_form @@ -402,14 +380,12 @@ class IndexingRunner: DatasetDocument.parsing_completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None), }, ) - # replace doc id to document model id text_docs = cast(list[Document], text_docs) for text_doc in text_docs: if text_doc.metadata is not None: text_doc.metadata["document_id"] = dataset_document.id text_doc.metadata["dataset_id"] = dataset_document.dataset_id - return text_docs @staticmethod @@ -437,10 +413,8 @@ class IndexingRunner: max_segmentation_tokens_length = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH if max_tokens < 50 or max_tokens > max_segmentation_tokens_length: raise ValueError(f"Custom segment length should be between 50 and {max_segmentation_tokens_length}.") - if separator: separator = separator.replace("\\n", "\n") - character_splitter = FixedRecursiveCharacterTextSplitter.from_encoder( chunk_size=max_tokens, chunk_overlap=chunk_overlap, @@ -457,7 +431,6 @@ class IndexingRunner: separators=["\n\n", "。", ". 
", " ", ""], embedding_model_instance=embedding_model_instance, ) - return character_splitter # type: ignore def _split_to_documents_for_estimate( @@ -471,10 +444,8 @@ class IndexingRunner: # document clean document_text = self._document_clean(text_doc.page_content, processing_rule) text_doc.page_content = document_text - # parse document to nodes documents = splitter.split_documents([text_doc]) - split_documents = [] for document in documents: if document.page_content is None or not document.page_content.strip(): @@ -484,11 +455,8 @@ class IndexingRunner: hash = helper.generate_text_hash(document.page_content) document.metadata["doc_id"] = doc_id document.metadata["doc_hash"] = hash - split_documents.append(document) - all_documents.extend(split_documents) - return all_documents @staticmethod @@ -501,14 +469,12 @@ class IndexingRunner: else: rules = json.loads(processing_rule.rules) if processing_rule.rules else {} document_text = CleanProcessor.clean(text, {"rules": rules}) - return document_text @staticmethod def format_split_text(text: str) -> list[QAPreviewDetail]: regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q\d+:|$)" matches = re.findall(regex, text, re.UNICODE) - return [QAPreviewDetail(question=q, answer=re.sub(r"\n\s*", "\n", a.strip())) for q, a in matches if q and a] def _load( @@ -521,7 +487,6 @@ class IndexingRunner: """ insert index and update document/segment status to completed """ - embedding_model_instance = None if dataset.indexing_technique == "high_quality": embedding_model_instance = self.model_manager.get_model_instance( @@ -530,7 +495,6 @@ class IndexingRunner: model_type=ModelType.TEXT_EMBEDDING, model=dataset.embedding_model, ) - # chunk nodes by chunk size indexing_start_at = time.perf_counter() tokens = 0 @@ -541,12 +505,10 @@ class IndexingRunner: args=(current_app._get_current_object(), dataset.id, dataset_document.id, documents), # type: ignore ) create_keyword_thread.start() - max_workers = 10 if dataset.indexing_technique == 
"high_quality": with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: futures = [] - # Distribute documents into multiple groups based on the hash values of page_content # This is done to prevent multiple threads from processing the same document, # Thereby avoiding potential database insertion deadlocks @@ -569,13 +531,11 @@ class IndexingRunner: embedding_model_instance, ) ) - for future in futures: tokens += future.result() if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX and dataset.indexing_technique == "economy": create_keyword_thread.join() indexing_end_at = time.perf_counter() - # update document status to completed self._update_document_index_status( document_id=dataset_document.id, @@ -610,7 +570,6 @@ class IndexingRunner: DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None), } ) - db.session.commit() def _process_chunk( @@ -619,15 +578,12 @@ class IndexingRunner: with flask_app.app_context(): # check document is paused self._check_document_paused_status(dataset_document.id) - tokens = 0 if embedding_model_instance: page_content_list = [document.page_content for document in chunk_documents] tokens += sum(embedding_model_instance.get_text_embedding_num_tokens(page_content_list)) - # load index index_processor.load(dataset, chunk_documents, with_keywords=False) - document_ids = [document.metadata["doc_id"] for document in chunk_documents] db.session.query(DocumentSegment).filter( DocumentSegment.document_id == dataset_document.id, @@ -641,9 +597,7 @@ class IndexingRunner: DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None), } ) - db.session.commit() - return tokens @staticmethod @@ -666,12 +620,9 @@ class IndexingRunner: document = db.session.query(DatasetDocument).filter_by(id=document_id).first() if not document: raise DocumentIsDeletedPausedError() - update_params = {DatasetDocument.indexing_status: after_indexing_status} - if extra_update_params: 
update_params.update(extra_update_params) - db.session.query(DatasetDocument).filter_by(id=document_id).update(update_params) db.session.commit() @@ -706,7 +657,6 @@ class IndexingRunner: tenant_id=dataset.tenant_id, model_type=ModelType.TEXT_EMBEDDING, ) - documents = index_processor.transform( text_docs, embedding_model_instance=embedding_model_instance, @@ -714,7 +664,6 @@ class IndexingRunner: tenant_id=dataset.tenant_id, doc_language=doc_language, ) - return documents def _load_segments(self, dataset, dataset_document, documents): @@ -722,10 +671,8 @@ class IndexingRunner: doc_store = DatasetDocumentStore( dataset=dataset, user_id=dataset_document.created_by, document_id=dataset_document.id ) - # add document segments doc_store.add_documents(docs=documents, save_child=dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX) - # update document status to indexing cur_time = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) self._update_document_index_status( @@ -736,7 +683,6 @@ class IndexingRunner: DatasetDocument.splitting_completed_at: cur_time, }, ) - # update segment status to indexing self._update_segments_by_document( dataset_document_id=dataset_document.id, diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py index e01896a491..5356e2d806 100644 --- a/api/core/llm_generator/llm_generator.py +++ b/api/core/llm_generator/llm_generator.py @@ -32,21 +32,16 @@ class LLMGenerator: cls, tenant_id: str, query, conversation_id: Optional[str] = None, app_id: Optional[str] = None ): prompt = CONVERSATION_TITLE_PROMPT - if len(query) > 2000: query = query[:300] + "...[TRUNCATED]..." 
+ query[-300:] - query = query.replace("\n", " ") - prompt += query + "\n" - model_manager = ModelManager() model_instance = model_manager.get_default_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, ) prompts = [UserPromptMessage(content=prompt)] - with measure_time() as timer: response = cast( LLMResult, @@ -65,10 +60,8 @@ class LLMGenerator: logging.exception("Failed to generate name after answer, use query instead") answer = query name = answer.strip() - if len(name) > 75: name = name[:75] + "..." - # get tracing instance trace_manager = TraceQueueManager(app_id=app_id) trace_manager.add_trace_task( @@ -81,18 +74,14 @@ class LLMGenerator: tenant_id=tenant_id, ) ) - return name @classmethod def generate_suggested_questions_after_answer(cls, tenant_id: str, histories: str): output_parser = SuggestedQuestionsAfterAnswerOutputParser() format_instructions = output_parser.get_format_instructions() - prompt_template = PromptTemplateParser(template="{{histories}}\n{{format_instructions}}\nquestions:\n") - prompt = prompt_template.format({"histories": histories, "format_instructions": format_instructions}) - try: model_manager = ModelManager() model_instance = model_manager.get_default_model_instance( @@ -101,9 +90,7 @@ class LLMGenerator: ) except InvokeAuthorizationError: return [] - prompt_messages = [UserPromptMessage(content=prompt)] - try: response = cast( LLMResult, @@ -113,14 +100,12 @@ class LLMGenerator: stream=False, ), ) - questions = output_parser.parse(cast(str, response.message.content)) except InvokeError: questions = [] except Exception: logging.exception("Failed to generate suggested questions after answer") questions = [] - return questions @classmethod @@ -128,31 +113,24 @@ class LLMGenerator: cls, tenant_id: str, instruction: str, model_config: dict, no_variable: bool, rule_config_max_tokens: int = 512 ) -> dict: output_parser = RuleConfigGeneratorOutputParser() - error = "" error_step = "" rule_config = {"prompt": "", "variables": [], 
"opening_statement": "", "error": ""} model_parameters = {"max_tokens": rule_config_max_tokens, "temperature": 0.01} - if no_variable: prompt_template = PromptTemplateParser(WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE) - prompt_generate = prompt_template.format( inputs={ "TASK_DESCRIPTION": instruction, }, remove_template_variables=False, ) - prompt_messages = [UserPromptMessage(content=prompt_generate)] - model_manager = ModelManager() - model_instance = model_manager.get_default_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, ) - try: response = cast( LLMResult, @@ -160,29 +138,20 @@ class LLMGenerator: prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False ), ) - rule_config["prompt"] = cast(str, response.message.content) - except InvokeError as e: error = str(e) error_step = "generate rule config" except Exception as e: logging.exception(f"Failed to generate rule config, model: {model_config.get('name')}") rule_config["error"] = str(e) - rule_config["error"] = f"Failed to {error_step}. 
Error: {error}" if error else "" - return rule_config - # get rule config prompt, parameter and statement prompt_generate, parameter_generate, statement_generate = output_parser.get_format_instructions() - prompt_template = PromptTemplateParser(prompt_generate) - parameter_template = PromptTemplateParser(parameter_generate) - statement_template = PromptTemplateParser(statement_generate) - # format the prompt_generate_prompt prompt_generate_prompt = prompt_template.format( inputs={ @@ -191,7 +160,6 @@ class LLMGenerator: remove_template_variables=False, ) prompt_messages = [UserPromptMessage(content=prompt_generate_prompt)] - # get model instance model_manager = ModelManager() model_instance = model_manager.get_model_instance( @@ -200,7 +168,6 @@ class LLMGenerator: provider=model_config.get("provider", ""), model=model_config.get("name", ""), ) - try: try: # the first step to generate the task prompt @@ -214,11 +181,8 @@ class LLMGenerator: error = str(e) error_step = "generate prefix prompt" rule_config["error"] = f"Failed to {error_step}. 
Error: {error}" if error else "" - return rule_config - rule_config["prompt"] = cast(str, prompt_content.message.content) - if not isinstance(prompt_content.message.content, str): raise NotImplementedError("prompt content is not a string") parameter_generate_prompt = parameter_template.format( @@ -228,7 +192,6 @@ class LLMGenerator: remove_template_variables=False, ) parameter_messages = [UserPromptMessage(content=parameter_generate_prompt)] - # the second step to generate the task_parameter and task_statement statement_generate_prompt = statement_template.format( inputs={ @@ -238,7 +201,6 @@ class LLMGenerator: remove_template_variables=False, ) statement_messages = [UserPromptMessage(content=statement_generate_prompt)] - try: parameter_content = cast( LLMResult, @@ -250,7 +212,6 @@ class LLMGenerator: except InvokeError as e: error = str(e) error_step = "generate variables" - try: statement_content = cast( LLMResult, @@ -262,13 +223,10 @@ class LLMGenerator: except InvokeError as e: error = str(e) error_step = "generate conversation opener" - except Exception as e: logging.exception(f"Failed to generate rule config, model: {model_config.get('name')}") rule_config["error"] = str(e) - rule_config["error"] = f"Failed to {error_step}. 
Error: {error}" if error else "" - return rule_config @classmethod @@ -284,7 +242,6 @@ class LLMGenerator: prompt_template = PromptTemplateParser(PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE) else: prompt_template = PromptTemplateParser(JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE) - prompt = prompt_template.format( inputs={ "INSTRUCTION": instruction, @@ -292,7 +249,6 @@ class LLMGenerator: }, remove_template_variables=False, ) - model_manager = ModelManager() model_instance = model_manager.get_model_instance( tenant_id=tenant_id, @@ -300,10 +256,8 @@ class LLMGenerator: provider=model_config.get("provider", ""), model=model_config.get("name", ""), ) - prompt_messages = [UserPromptMessage(content=prompt)] model_parameters = {"max_tokens": max_tokens, "temperature": 0.01} - try: response = cast( LLMResult, @@ -311,10 +265,8 @@ class LLMGenerator: prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False ), ) - generated_code = cast(str, response.message.content) return {"code": generated_code, "language": code_language, "error": ""} - except InvokeError as e: error = str(e) return {"code": "", "language": code_language, "error": f"Failed to generate code. 
Error: {error}"} @@ -327,15 +279,12 @@ class LLMGenerator: @classmethod def generate_qa_document(cls, tenant_id: str, query, document_language: str): prompt = GENERATOR_QA_PROMPT.format(language=document_language) - model_manager = ModelManager() model_instance = model_manager.get_default_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, ) - prompt_messages = [SystemPromptMessage(content=prompt), UserPromptMessage(content=query)] - response = cast( LLMResult, model_instance.invoke_llm( @@ -344,7 +293,6 @@ class LLMGenerator: stream=False, ), ) - answer = cast(str, response.message.content) return answer.strip() @@ -357,13 +305,11 @@ class LLMGenerator: provider=model_config.get("provider", ""), model=model_config.get("name", ""), ) - prompt_messages = [ SystemPromptMessage(content=SYSTEM_STRUCTURED_OUTPUT_GENERATE), UserPromptMessage(content=instruction), ] model_parameters = model_config.get("model_parameters", {}) - try: response = cast( LLMResult, @@ -371,23 +317,17 @@ class LLMGenerator: prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False ), ) - raw_content = response.message.content - if not isinstance(raw_content, str): raise ValueError(f"LLM response content must be a string, got: {type(raw_content)}") - try: parsed_content = json.loads(raw_content) except json.JSONDecodeError: parsed_content = json_repair.loads(raw_content) - if not isinstance(parsed_content, dict | list): raise ValueError(f"Failed to parse structured output from llm: {raw_content}") - generated_json_schema = json.dumps(parsed_content, indent=2, ensure_ascii=False) return {"output": generated_json_schema, "error": ""} - except InvokeError as e: error = str(e) return {"output": "", "error": f"Failed to generate JSON Schema. 
Error: {error}"} diff --git a/api/core/llm_generator/output_parser/structured_output.py b/api/core/llm_generator/output_parser/structured_output.py index 151cef1bc3..0412bd1fb4 100644 --- a/api/core/llm_generator/output_parser/structured_output.py +++ b/api/core/llm_generator/output_parser/structured_output.py @@ -57,8 +57,6 @@ def invoke_llm_with_structured_output( user: Optional[str] = None, callbacks: Optional[list[Callback]] = None, ) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ... - - @overload def invoke_llm_with_structured_output( provider: str, @@ -73,8 +71,6 @@ def invoke_llm_with_structured_output( user: Optional[str] = None, callbacks: Optional[list[Callback]] = None, ) -> LLMResultWithStructuredOutput: ... - - @overload def invoke_llm_with_structured_output( provider: str, @@ -89,8 +85,6 @@ def invoke_llm_with_structured_output( user: Optional[str] = None, callbacks: Optional[list[Callback]] = None, ) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]: ... - - def invoke_llm_with_structured_output( provider: str, model_schema: AIModelEntity, @@ -108,7 +102,6 @@ def invoke_llm_with_structured_output( Invoke large language model with structured output 1. This method invokes model_instance.invoke_llm with json_schema 2. 
Try to parse the result as structured output - :param prompt_messages: prompt messages :param json_schema: json schema :param model_parameters: model parameters @@ -119,12 +112,10 @@ def invoke_llm_with_structured_output( :param callbacks: callbacks :return: full response or stream response chunk generator result """ - # handle native json schema model_parameters_with_json_schema: dict[str, Any] = { **(model_parameters or {}), } - if model_schema.support_structure_output: model_parameters = _handle_native_json_schema( provider, model_schema, json_schema, model_parameters_with_json_schema, model_schema.parameter_rules @@ -132,13 +123,11 @@ def invoke_llm_with_structured_output( else: # Set appropriate response format based on model capabilities _set_response_format(model_parameters_with_json_schema, model_schema.parameter_rules) - # handle prompt based schema prompt_messages = _handle_prompt_based_schema( prompt_messages=prompt_messages, structured_output_schema=json_schema, ) - llm_result = model_instance.invoke_llm( prompt_messages=list(prompt_messages), model_parameters=model_parameters_with_json_schema, @@ -148,13 +137,11 @@ def invoke_llm_with_structured_output( user=user, callbacks=callbacks, ) - if isinstance(llm_result, LLMResult): if not isinstance(llm_result.message.content, str): raise OutputParserError( f"Failed to parse structured output, LLM result is not a string: {llm_result.message.content}" ) - return LLMResultWithStructuredOutput( structured_output=_parse_structured_output(llm_result.message.content), model=llm_result.model, @@ -173,21 +160,18 @@ def invoke_llm_with_structured_output( if isinstance(event, LLMResultChunk): prompt_messages = event.prompt_messages system_fingerprint = event.system_fingerprint - if isinstance(event.delta.message.content, str): result_text += event.delta.message.content elif isinstance(event.delta.message.content, list): for item in event.delta.message.content: if isinstance(item, TextPromptMessageContent): result_text 
+= item.data - yield LLMResultChunkWithStructuredOutput( model=model_schema.model, prompt_messages=prompt_messages, system_fingerprint=system_fingerprint, delta=event.delta, ) - yield LLMResultChunkWithStructuredOutput( structured_output=_parse_structured_output(result_text), model=model_schema.model, @@ -213,29 +197,24 @@ def _handle_native_json_schema( ) -> dict: """ Handle structured output for models with native JSON schema support. - :param model_parameters: Model parameters to update :param rules: Model parameter rules :return: Updated model parameters with JSON schema configuration """ # Process schema according to model requirements schema_json = _prepare_schema_for_model(provider, model_schema, structured_output_schema) - # Set JSON schema in parameters model_parameters["json_schema"] = json.dumps(schema_json, ensure_ascii=False) - # Set appropriate response format if required by the model for rule in rules: if rule.name == "response_format" and ResponseFormat.JSON_SCHEMA.value in rule.options: model_parameters["response_format"] = ResponseFormat.JSON_SCHEMA.value - return model_parameters def _set_response_format(model_parameters: dict, rules: list) -> None: """ Set the appropriate response format parameter based on model rules. - :param model_parameters: Model parameters to update :param rules: Model parameter rules """ @@ -253,16 +232,13 @@ def _handle_prompt_based_schema( """ Handle structured output for models without native JSON schema support. This function modifies the prompt messages to include schema-based output requirements. 
- Args: prompt_messages: Original sequence of prompt messages - Returns: list[PromptMessage]: Updated prompt messages with structured output requirements """ # Convert schema to string format schema_str = json.dumps(structured_output_schema, ensure_ascii=False) - # Find existing system prompt with schema placeholder system_prompt = next( (prompt for prompt in prompt_messages if isinstance(prompt, SystemPromptMessage)), @@ -276,12 +252,9 @@ def _handle_prompt_based_schema( else structured_output_prompt ) system_prompt = SystemPromptMessage(content=system_prompt_content) - # Extract content from the last user message - filtered_prompts = [prompt for prompt in prompt_messages if not isinstance(prompt, SystemPromptMessage)] updated_prompt = [system_prompt] + filtered_prompts - return updated_prompt @@ -309,20 +282,15 @@ def _parse_structured_output(result_text: str) -> Mapping[str, Any]: def _prepare_schema_for_model(provider: str, model_schema: AIModelEntity, schema: Mapping) -> dict: """ Prepare JSON schema based on model requirements. - Different models have different requirements for JSON schema formatting. This function handles these differences. - :param schema: The original JSON schema :return: Processed schema compatible with the current model """ - # Deep copy to avoid modifying the original schema processed_schema = dict(deepcopy(schema)) - # Convert boolean types to string types (common requirement) convert_boolean_to_string(processed_schema) - # Apply model-specific transformations if SpecialModelType.GEMINI in model_schema.model: remove_additional_properties(processed_schema) @@ -338,15 +306,12 @@ def remove_additional_properties(schema: dict) -> None: """ Remove additionalProperties fields from JSON schema. Used for models like Gemini that don't support this property. 
- :param schema: JSON schema to modify in-place """ if not isinstance(schema, dict): return - # Remove additionalProperties at current level schema.pop("additionalProperties", None) - # Process nested structures recursively for value in schema.values(): if isinstance(value, dict): @@ -360,16 +325,13 @@ def remove_additional_properties(schema: dict) -> None: def convert_boolean_to_string(schema: dict) -> None: """ Convert boolean type specifications to string in JSON schema. - :param schema: JSON schema to modify in-place """ if not isinstance(schema, dict): return - # Check for boolean type at current level if schema.get("type") == "boolean": schema["type"] = "string" - # Process nested dictionaries and lists recursively for value in schema.values(): if isinstance(value, dict): diff --git a/api/core/llm_generator/output_parser/suggested_questions_after_answer.py b/api/core/llm_generator/output_parser/suggested_questions_after_answer.py index c451bf514c..98cdc4c8b7 100644 --- a/api/core/llm_generator/output_parser/suggested_questions_after_answer.py +++ b/api/core/llm_generator/output_parser/suggested_questions_after_answer.py @@ -15,5 +15,4 @@ class SuggestedQuestionsAfterAnswerOutputParser: json_obj = json.loads(action_match.group(0).strip()) else: json_obj = [] - return json_obj diff --git a/api/core/llm_generator/prompts.py b/api/core/llm_generator/prompts.py index ef81e38dc5..b9aa86fc94 100644 --- a/api/core/llm_generator/prompts.py +++ b/api/core/llm_generator/prompts.py @@ -1,14 +1,11 @@ # Written by YORKI MINAKO🤡, Edited by Xiaoyi, Edited by yasu-oh CONVERSATION_TITLE_PROMPT = """You are asked to generate a concise chat title by decomposing the user’s input into two parts: “Intention” and “Subject”. - 1. Detect Input Language Automatically identify the language of the user’s input (e.g. English, Chinese, Italian, Español, Arabic, Japanese, French, and etc.). - 2. Generate Title - Combine Intention + Subject into a single, as-short-as-possible phrase. 
- The title must be natural, friendly, and in the same language as the input. - If the input is a direct question to the model, you may add an emoji at the end. - 3. Output Format Return **only** a valid JSON object with these exact keys and no additional text: { @@ -16,10 +13,8 @@ Return **only** a valid JSON object with these exact keys and no additional text "Your Reasoning": "", "Your Output": "" } - User Input: """ # noqa: E501 - PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE = ( "You are an expert programmer. Generate code based on the following instructions:\n\n" "Instructions: {{INSTRUCTION}}\n\n" @@ -92,8 +87,6 @@ JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE = ( "- The code should be complete, functional, and follow best practices for {{CODE_LANGUAGE}}.\n\n" "Generated Code:\n" ) - - SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = ( "Please help me predict the three most likely questions that human would ask, " "and keep each question under 20 characters.\n" @@ -101,7 +94,6 @@ SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = ( "The output must be an array in JSON format following the specified schema:\n" '["question1","question2","question3"]\n' ) - GENERATOR_QA_PROMPT = ( " The user will send a long text. Generate a Question and Answer pairs only using the knowledge" " in the long text. Please think step by step." @@ -115,7 +107,6 @@ GENERATOR_QA_PROMPT = ( " Use the following format: Q1:\nA1:\nQ2:\nA2:...\n" "" ) - WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE = """ Here is a task description for which I would like you to create a high-quality prompt template for: @@ -130,7 +121,6 @@ Based on task description, please create a well-structured prompt template that - Output in ``` xml ``` and start with Please generate the full prompt template with at least 300 words and output only the prompt template. 
""" # noqa: E501 - RULE_CONFIG_PROMPT_GENERATE_TEMPLATE = """ Here is a task description for which I would like you to create a high-quality prompt template for: @@ -145,32 +135,26 @@ Based on task description, please create a well-structured prompt template that - Output in ``` xml ``` and start with Please generate the full prompt template and output only the prompt template. """ # noqa: E501 - RULE_CONFIG_PARAMETER_GENERATE_TEMPLATE = """ I need to extract the following information from the input text. The tag specifies the 'type', 'description' and 'required' of the information to be extracted. variables name bounded two double curly brackets. Variable name has to be composed of number, english alphabets and underline and nothing else. - Step 1: Carefully read the input and understand the structure of the expected output. Step 2: Extract relevant parameters from the provided text based on the name and description of object. Step 3: Structure the extracted parameters to JSON object as specified in . Step 4: Ensure that the list of variable_names is properly formatted and valid. The output should not contain any XML tags. Output an empty list if there is no valid variable name in input text. - ### Structure Here is the structure of the expected output, I should always follow the output structure. ["variable_name_1", "variable_name_2"] - ### Input Text Inside XML tags, there is a text that I should extract parameters and convert to a JSON object. {{INPUT_TEXT}} - ### Answer I should always output a valid list. Output nothing other than the list of variable_name. Output an empty list if there is no variable name in input text. """ # noqa: E501 - RULE_CONFIG_STATEMENT_GENERATE_TEMPLATE = """ Step 1: Identify the purpose of the chatbot from the variable {{TASK_DESCRIPTION}} and infer chatbot's tone (e.g., friendly, professional, etc.) to add personality traits. @@ -183,15 +167,11 @@ Example Output: Welcome! 
I'm here to assist you with any questions or issues you might have with your shopping experience. Whether you're looking for product information, need help with your order, or have any other inquiries, feel free to ask. I'm friendly, helpful, and ready to support you in any way I can. Here is the task description: {{INPUT_TEXT}} - You just need to generate the output """ # noqa: E501 - SYSTEM_STRUCTURED_OUTPUT_GENERATE = """ Your task is to convert simple user descriptions into properly formatted JSON Schema definitions. When a user describes data fields they need, generate a complete, valid JSON Schema that accurately represents those fields with appropriate types and requirements. - ## Instructions: - 1. Analyze the user's description of their data needs 2. Identify each property that should be included in the schema 3. Determine the appropriate data type for each property @@ -200,9 +180,7 @@ Your task is to convert simple user descriptions into properly formatted JSON Sc 6. Include appropriate constraints when specified (min/max values, patterns, formats) 7. Provide ONLY the JSON Schema without any additional explanations, comments, or markdown formatting. 8. DO NOT use markdown code blocks (``` or ``` json). Return the raw JSON Schema directly. 
- ## Examples: - ### Example 1: **User Input:** I need name and age **JSON Schema Output:** @@ -214,7 +192,6 @@ Your task is to convert simple user descriptions into properly formatted JSON Sc }, "required": ["name", "age"] } - ### Example 2: **User Input:** I want to store information about books including title, author, publication year and optional page count **JSON Schema Output:** @@ -228,7 +205,6 @@ Your task is to convert simple user descriptions into properly formatted JSON Sc }, "required": ["title", "author", "publicationYear"] } - ### Example 3: **User Input:** Create a schema for user profiles with email, password, and age (must be at least 18) **JSON Schema Output:** @@ -250,7 +226,6 @@ Your task is to convert simple user descriptions into properly formatted JSON Sc }, "required": ["email", "password", "age"] } - ### Example 4: **User Input:** I need album schema, the ablum has songs, and each song has name, duration, and artist. **JSON Schema Output:** @@ -288,10 +263,8 @@ Your task is to convert simple user descriptions into properly formatted JSON Sc "songs" ] } - Now, generate a JSON Schema based on my description """ # noqa: E501 - STRUCTURED_OUTPUT_PROMPT = """You’re a helpful AI assistant. You could answer questions and output in JSON format. constraints: - You must output in JSON format. @@ -300,10 +273,8 @@ constraints: eg: Here is the JSON schema: {"additionalProperties": false, "properties": {"age": {"type": "number"}, "name": {"type": "string"}}, "required": ["name", "age"], "type": "object"} - Here is the user's question: My name is John Doe and I am 30 years old. 
- output: {"name": "John Doe", "age": 30} Here is the JSON schema: diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py index 2254b3d4d5..77fc5ec498 100644 --- a/api/core/memory/token_buffer_memory.py +++ b/api/core/memory/token_buffer_memory.py @@ -34,7 +34,6 @@ class TokenBufferMemory: :param message_limit: message limit """ app_record = self.conversation.app - # fetch limited messages, and return reversed query = ( db.session.query( @@ -51,24 +50,18 @@ class TokenBufferMemory: ) .order_by(Message.created_at.desc()) ) - if message_limit and message_limit > 0: message_limit = min(message_limit, 500) else: message_limit = 500 - messages = query.limit(message_limit).all() - # instead of all messages from the conversation, we only need to extract messages # that belong to the thread of last message thread_messages = extract_thread_messages(messages) - # for newly created message, its answer is temporarily empty, we don't need to add it to memory if thread_messages and not thread_messages[0].answer and thread_messages[0].answer_tokens == 0: thread_messages.pop(0) - messages = list(reversed(thread_messages)) - prompt_messages: list[PromptMessage] = [] for message in messages: files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all() @@ -81,12 +74,10 @@ class TokenBufferMemory: workflow_run = ( db.session.query(WorkflowRun).filter(WorkflowRun.id == message.workflow_run_id).first() ) - if workflow_run and workflow_run.workflow: file_extra_config = FileUploadConfigManager.convert( workflow_run.workflow.features_dict, is_vision=False ) - detail = ImagePromptMessageContent.DETAIL.LOW if file_extra_config and app_record: file_objs = file_factory.build_from_message_files( @@ -96,7 +87,6 @@ class TokenBufferMemory: detail = file_extra_config.image_config.detail else: file_objs = [] - if not file_objs: prompt_messages.append(UserPromptMessage(content=message.query)) else: @@ -108,26 +98,19 @@ class 
TokenBufferMemory: image_detail_config=detail, ) prompt_message_contents.append(prompt_message) - prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) - else: prompt_messages.append(UserPromptMessage(content=message.query)) - prompt_messages.append(AssistantPromptMessage(content=message.answer)) - if not prompt_messages: return [] - # prune the chat message if it exceeds the max token limit curr_message_tokens = self.model_instance.get_llm_num_tokens(prompt_messages) - if curr_message_tokens > max_token_limit: pruned_memory = [] while curr_message_tokens > max_token_limit and len(prompt_messages) > 1: pruned_memory.append(prompt_messages.pop(0)) curr_message_tokens = self.model_instance.get_llm_num_tokens(prompt_messages) - return prompt_messages def get_history_prompt_text( @@ -146,7 +129,6 @@ class TokenBufferMemory: :return: """ prompt_messages = self.get_history_prompt_messages(max_token_limit=max_token_limit, message_limit=message_limit) - string_messages = [] for m in prompt_messages: if m.role == PromptMessageRole.USER: @@ -155,7 +137,6 @@ class TokenBufferMemory: role = ai_prefix else: continue - if isinstance(m.content, list): inner_msg = "" for content in m.content: @@ -163,10 +144,8 @@ class TokenBufferMemory: inner_msg += f"{content.data}\n" elif isinstance(content, ImagePromptMessageContent): inner_msg += "[image]\n" - string_messages.append(f"{role}: {inner_msg.strip()}") else: message = f"{role}: {m.content}" string_messages.append(message) - return "\n".join(string_messages) diff --git a/api/core/model_runtime/callbacks/base_callback.py b/api/core/model_runtime/callbacks/base_callback.py index 57cad17285..0bcc3e73e2 100644 --- a/api/core/model_runtime/callbacks/base_callback.py +++ b/api/core/model_runtime/callbacks/base_callback.py @@ -38,7 +38,6 @@ class Callback(ABC): ) -> None: """ Before invoke callback - :param llm_instance: LLM instance :param model: model name :param credentials: model credentials @@ -67,7 +66,6 @@ 
class Callback(ABC): ): """ On new chunk callback - :param llm_instance: LLM instance :param chunk: chunk :param model: model name @@ -97,7 +95,6 @@ class Callback(ABC): ) -> None: """ After invoke callback - :param llm_instance: LLM instance :param result: result :param model: model name @@ -127,7 +124,6 @@ class Callback(ABC): ) -> None: """ Invoke error callback - :param llm_instance: LLM instance :param ex: exception :param model: model name diff --git a/api/core/model_runtime/callbacks/logging_callback.py b/api/core/model_runtime/callbacks/logging_callback.py index 899f08195d..debeeb5435 100644 --- a/api/core/model_runtime/callbacks/logging_callback.py +++ b/api/core/model_runtime/callbacks/logging_callback.py @@ -27,7 +27,6 @@ class LoggingCallback(Callback): ) -> None: """ Before invoke callback - :param llm_instance: LLM instance :param model: model name :param credentials: model credentials @@ -43,28 +42,21 @@ class LoggingCallback(Callback): self.print_text("Parameters:\n", color="blue") for key, value in model_parameters.items(): self.print_text(f"\t{key}: {value}\n", color="blue") - if stop: self.print_text(f"\tstop: {stop}\n", color="blue") - if tools: self.print_text("\tTools:\n", color="blue") for tool in tools: self.print_text(f"\t\t{tool.name}\n", color="blue") - self.print_text(f"Stream: {stream}\n", color="blue") - if user: self.print_text(f"User: {user}\n", color="blue") - self.print_text("Prompt messages:\n", color="blue") for prompt_message in prompt_messages: if prompt_message.name: self.print_text(f"\tname: {prompt_message.name}\n", color="blue") - self.print_text(f"\trole: {prompt_message.role.value}\n", color="blue") self.print_text(f"\tcontent: {prompt_message.content}\n", color="blue") - if stream: self.print_text("\n[on_llm_new_chunk]") @@ -83,7 +75,6 @@ class LoggingCallback(Callback): ): """ On new chunk callback - :param llm_instance: LLM instance :param chunk: chunk :param model: model name @@ -113,7 +104,6 @@ class 
LoggingCallback(Callback): ) -> None: """ After invoke callback - :param llm_instance: LLM instance :param result: result :param model: model name @@ -127,14 +117,12 @@ class LoggingCallback(Callback): """ self.print_text("\n[on_llm_after_invoke]\n", color="yellow") self.print_text(f"Content: {result.message.content}\n", color="yellow") - if result.message.tool_calls: self.print_text("Tool calls:\n", color="yellow") for tool_call in result.message.tool_calls: self.print_text(f"\t{tool_call.id}\n", color="yellow") self.print_text(f"\t{tool_call.function.name}\n", color="yellow") self.print_text(f"\t{json.dumps(tool_call.function.arguments)}\n", color="yellow") - self.print_text(f"Model: {result.model}\n", color="yellow") self.print_text(f"Usage: {result.usage}\n", color="yellow") self.print_text(f"System Fingerprint: {result.system_fingerprint}\n", color="yellow") @@ -154,7 +142,6 @@ class LoggingCallback(Callback): ) -> None: """ Invoke error callback - :param llm_instance: LLM instance :param ex: exception :param model: model name diff --git a/api/core/model_runtime/entities/llm_entities.py b/api/core/model_runtime/entities/llm_entities.py index ace2c1f770..a5b7f9a389 100644 --- a/api/core/model_runtime/entities/llm_entities.py +++ b/api/core/model_runtime/entities/llm_entities.py @@ -57,10 +57,8 @@ class LLMUsage(ModelUsage): def from_metadata(cls, metadata: dict) -> "LLMUsage": """ Create LLMUsage instance from metadata dictionary with default values. 
- Args: metadata: Dictionary containing usage metadata - Returns: LLMUsage instance with values from metadata or defaults """ @@ -68,7 +66,6 @@ class LLMUsage(ModelUsage): completion_tokens = metadata.get("completion_tokens", 0) if total_tokens > 0 and completion_tokens == 0: completion_tokens = total_tokens - return cls( prompt_tokens=metadata.get("prompt_tokens", 0), completion_tokens=completion_tokens, @@ -87,7 +84,6 @@ class LLMUsage(ModelUsage): def plus(self, other: "LLMUsage") -> "LLMUsage": """ Add two LLMUsage instances together. - :param other: Another LLMUsage instance to add :return: A new LLMUsage instance with summed values """ @@ -112,7 +108,6 @@ class LLMUsage(ModelUsage): def __add__(self, other: "LLMUsage") -> "LLMUsage": """ Overload the + operator to add two LLMUsage instances. - :param other: Another LLMUsage instance to add :return: A new LLMUsage instance with summed values """ diff --git a/api/core/model_runtime/entities/message_entities.py b/api/core/model_runtime/entities/message_entities.py index 9d010ae28d..5ad46e85b3 100644 --- a/api/core/model_runtime/entities/message_entities.py +++ b/api/core/model_runtime/entities/message_entities.py @@ -20,7 +20,6 @@ class PromptMessageRole(Enum): def value_of(cls, value: str) -> "PromptMessageRole": """ Get value of given mode. - :param value: mode value :return: mode """ @@ -128,8 +127,6 @@ PromptMessageContentUnionTypes = Annotated[ ], Field(discriminator="type"), ] - - CONTENT_TYPE_MAPPING: Mapping[PromptMessageContentType, type[PromptMessageContent]] = { PromptMessageContentType.TEXT: TextPromptMessageContent, PromptMessageContentType.IMAGE: ImagePromptMessageContent, @@ -151,7 +148,6 @@ class PromptMessage(ABC, BaseModel): def is_empty(self) -> bool: """ Check if prompt message is empty. 
- :return: True if prompt message is empty, False otherwise """ return not self.content @@ -228,12 +224,10 @@ class AssistantPromptMessage(PromptMessage): def is_empty(self) -> bool: """ Check if prompt message is empty. - :return: True if prompt message is empty, False otherwise """ if not super().is_empty() and not self.tool_calls: return False - return True @@ -256,10 +250,8 @@ class ToolPromptMessage(PromptMessage): def is_empty(self) -> bool: """ Check if prompt message is empty. - :return: True if prompt message is empty, False otherwise """ if not super().is_empty() and not self.tool_call_id: return False - return True diff --git a/api/core/model_runtime/entities/model_entities.py b/api/core/model_runtime/entities/model_entities.py index 568149cc37..6f0db55a19 100644 --- a/api/core/model_runtime/entities/model_entities.py +++ b/api/core/model_runtime/entities/model_entities.py @@ -23,7 +23,6 @@ class ModelType(Enum): def value_of(cls, origin_model_type: str) -> "ModelType": """ Get model type from origin model type. - :return: model type """ if origin_model_type in {"text-generation", cls.LLM.value}: @@ -44,7 +43,6 @@ class ModelType(Enum): def to_origin_model_type(self) -> str: """ Get origin model type from model type. - :return: origin model type """ if self == self.LLM: @@ -106,7 +104,6 @@ class DefaultParameterName(StrEnum): def value_of(cls, value: Any) -> "DefaultParameterName": """ Get parameter name from value. 
- :param value: parameter value :return: parameter name """ diff --git a/api/core/model_runtime/entities/provider_entities.py b/api/core/model_runtime/entities/provider_entities.py index d0f9ee13e5..b0f38ad633 100644 --- a/api/core/model_runtime/entities/provider_entities.py +++ b/api/core/model_runtime/entities/provider_entities.py @@ -130,10 +130,8 @@ class ProviderEntity(BaseModel): models: list[AIModelEntity] = Field(default_factory=list) provider_credential_schema: Optional[ProviderCredentialSchema] = None model_credential_schema: Optional[ModelCredentialSchema] = None - # pydantic configs model_config = ConfigDict(protected_namespaces=()) - # position from plugin _position.yaml position: Optional[dict[str, list[str]]] = {} @@ -148,7 +146,6 @@ class ProviderEntity(BaseModel): def to_simple_provider(self) -> SimpleProviderEntity: """ Convert to simple provider. - :return: simple provider """ return SimpleProviderEntity( diff --git a/api/core/model_runtime/model_providers/__base/ai_model.py b/api/core/model_runtime/model_providers/__base/ai_model.py index 7d5ce1e47e..2aaff435ca 100644 --- a/api/core/model_runtime/model_providers/__base/ai_model.py +++ b/api/core/model_runtime/model_providers/__base/ai_model.py @@ -39,7 +39,6 @@ class AIModel(BaseModel): provider_name: str = Field(description="Provider") plugin_model_provider: PluginModelProviderEntity = Field(description="Plugin model provider") started_at: float = Field(description="Invoke start time", default=0) - # pydantic configs model_config = ConfigDict(protected_namespaces=()) @@ -50,7 +49,6 @@ class AIModel(BaseModel): The key is the error type thrown to the caller The value is the error type thrown by the model, which needs to be converted into a unified error type for the caller. 
- :return: Invoke error mapping """ return { @@ -66,7 +64,6 @@ class AIModel(BaseModel): def _transform_invoke_error(self, error: Exception) -> Exception: """ Transform invoke error to unified error - :param error: model invoke error :return: unified error """ @@ -82,13 +79,11 @@ class AIModel(BaseModel): return InvokeError(description=f"[{self.provider_name}] {invoke_error.description}, {str(error)}") else: return error - return InvokeError(description=f"[{self.provider_name}] Error: {str(error)}") def get_price(self, model: str, credentials: dict, price_type: PriceType, tokens: int) -> PriceInfo: """ Get price for given model and tokens - :param model: model name :param credentials: model credentials :param price_type: price type @@ -97,12 +92,10 @@ class AIModel(BaseModel): """ # get model schema model_schema = self.get_model_schema(model, credentials) - # get price info from predefined model schema price_config: Optional[PriceConfig] = None if model_schema and model_schema.pricing: price_config = model_schema.pricing - # get unit price unit_price = None if price_config: @@ -110,7 +103,6 @@ class AIModel(BaseModel): unit_price = price_config.input elif price_type == PriceType.OUTPUT and price_config.output is not None: unit_price = price_config.output - if unit_price is None: return PriceInfo( unit_price=decimal.Decimal("0.0"), @@ -118,13 +110,11 @@ class AIModel(BaseModel): total_amount=decimal.Decimal("0.0"), currency="USD", ) - # calculate total amount if not price_config: raise ValueError(f"Price config not found for model {model}") total_amount = tokens * unit_price * price_config.unit total_amount = total_amount.quantize(decimal.Decimal("0.0000001"), rounding=decimal.ROUND_HALF_UP) - return PriceInfo( unit_price=unit_price, unit=price_config.unit, @@ -135,7 +125,6 @@ class AIModel(BaseModel): def get_model_schema(self, model: str, credentials: Optional[dict] = None) -> Optional[AIModelEntity]: """ Get model schema by model name and credentials - :param 
model: model name :param credentials: model credentials :return: model schema @@ -145,17 +134,14 @@ class AIModel(BaseModel): # sort credentials sorted_credentials = sorted(credentials.items()) if credentials else [] cache_key += ":".join([hashlib.md5(f"{k}:{v}".encode()).hexdigest() for k, v in sorted_credentials]) - try: contexts.plugin_model_schemas.get() except LookupError: contexts.plugin_model_schemas.set({}) contexts.plugin_model_schema_lock.set(Lock()) - with contexts.plugin_model_schema_lock.get(): if cache_key in contexts.plugin_model_schemas.get(): return contexts.plugin_model_schemas.get()[cache_key] - schema = plugin_model_manager.get_model_schema( tenant_id=self.tenant_id, user_id="unknown", @@ -165,26 +151,21 @@ class AIModel(BaseModel): model=model, credentials=credentials or {}, ) - if schema: contexts.plugin_model_schemas.get()[cache_key] = schema - return schema def get_customizable_model_schema_from_credentials(self, model: str, credentials: dict) -> Optional[AIModelEntity]: """ Get customizable model schema from credentials - :param model: model name :param credentials: model credentials :return: model schema """ - # get customizable model schema schema = self.get_customizable_model_schema(model, credentials) if not schema: return None - # fill in the template new_parameter_rules = [] for parameter_rule in schema.parameter_rules: @@ -222,17 +203,13 @@ class AIModel(BaseModel): ) except ValueError: pass - new_parameter_rules.append(parameter_rule) - schema.parameter_rules = new_parameter_rules - return schema def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: """ Get customizable model schema - :param model: model name :param credentials: model credentials :return: model schema @@ -242,13 +219,10 @@ class AIModel(BaseModel): def _get_default_parameter_rule_variable_map(self, name: DefaultParameterName) -> dict: """ Get default parameter rule for given name - :param name: parameter name :return: 
parameter rule """ default_parameter_rule = PARAMETER_RULE_TEMPLATE.get(name) - if not default_parameter_rule: raise Exception(f"Invalid model parameter rule name {name}") - return default_parameter_rule diff --git a/api/core/model_runtime/model_providers/__base/large_language_model.py b/api/core/model_runtime/model_providers/__base/large_language_model.py index e2cc576f83..950daa5ad0 100644 --- a/api/core/model_runtime/model_providers/__base/large_language_model.py +++ b/api/core/model_runtime/model_providers/__base/large_language_model.py @@ -36,7 +36,6 @@ def _increase_tool_call( ): """ Merge incremental tool call updates into existing tool calls. - :param new_tool_calls: List of new tool call deltas to be merged. :param existing_tools_calls: List of existing tool calls to be modified IN-PLACE. """ @@ -44,13 +43,11 @@ def _increase_tool_call( def get_tool_call(tool_call_id: str): """ Get or create a tool call by ID - :param tool_call_id: tool call ID :return: existing or new tool call """ if not tool_call_id: return existing_tools_calls[-1] - _tool_call = next((_tool_call for _tool_call in existing_tools_calls if _tool_call.id == tool_call_id), None) if _tool_call is None: _tool_call = AssistantPromptMessage.ToolCall( @@ -59,7 +56,6 @@ def _increase_tool_call( function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments=""), ) existing_tools_calls.append(_tool_call) - return _tool_call for new_tool_call in new_tool_calls: @@ -85,7 +81,6 @@ class LargeLanguageModel(AIModel): """ model_type: ModelType = ModelType.LLM - # pydantic configs model_config = ConfigDict(protected_namespaces=()) @@ -103,7 +98,6 @@ class LargeLanguageModel(AIModel): ) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]: """ Invoke large language model - :param model: model name :param credentials: model credentials :param prompt_messages: prompt messages @@ -118,14 +112,10 @@ class LargeLanguageModel(AIModel): # validate and filter model parameters if 
model_parameters is None: model_parameters = {} - self.started_at = time.perf_counter() - callbacks = callbacks or [] - if dify_config.DEBUG: callbacks.append(LoggingCallback()) - # trigger before invoke callbacks self._trigger_before_invoke_callbacks( model=model, @@ -138,9 +128,7 @@ class LargeLanguageModel(AIModel): user=user, callbacks=callbacks, ) - result: Union[LLMResult, Generator[LLMResultChunk, None, None]] - try: plugin_model_manager = PluginModelClient() result = plugin_model_manager.invoke_llm( @@ -156,14 +144,12 @@ class LargeLanguageModel(AIModel): stop=list(stop) if stop else None, stream=stream, ) - if not stream: content = "" content_list = [] usage = LLMUsage.empty_usage() system_fingerprint = None tools_calls: list[AssistantPromptMessage.ToolCall] = [] - for chunk in result: if isinstance(chunk.delta.message.content, str): content += chunk.delta.message.content @@ -171,11 +157,9 @@ class LargeLanguageModel(AIModel): content_list.extend(chunk.delta.message.content) if chunk.delta.message.tool_calls: _increase_tool_call(chunk.delta.message.tool_calls, tools_calls) - usage = chunk.delta.usage or LLMUsage.empty_usage() system_fingerprint = chunk.system_fingerprint break - result = LLMResult( model=model, prompt_messages=prompt_messages, @@ -199,10 +183,8 @@ class LargeLanguageModel(AIModel): user=user, callbacks=callbacks, ) - # TODO raise self._transform_invoke_error(e) - if stream and isinstance(result, Generator): return self._invoke_result_generator( model=model, @@ -251,7 +233,6 @@ class LargeLanguageModel(AIModel): ) -> Generator[LLMResultChunk, None, None]: """ Invoke result generator - :param result: result generator :return: result generator """ @@ -278,7 +259,6 @@ class LargeLanguageModel(AIModel): # To ensure compatibility, we add the prompt_messages back here. 
chunk.prompt_messages = prompt_messages yield chunk - self._trigger_new_chunk_callbacks( chunk=chunk, model=model, @@ -291,18 +271,14 @@ class LargeLanguageModel(AIModel): user=user, callbacks=callbacks, ) - _update_message_content(chunk.delta.message.content) - real_model = chunk.model if chunk.delta.usage: usage = chunk.delta.usage - if chunk.system_fingerprint: system_fingerprint = chunk.system_fingerprint except Exception as e: raise self._transform_invoke_error(e) - assistant_message = AssistantPromptMessage(content=message_content) self._trigger_after_invoke_callbacks( model=model, @@ -332,7 +308,6 @@ class LargeLanguageModel(AIModel): ) -> int: """ Get number of tokens for given prompt messages - :param model: model name :param credentials: model credentials :param prompt_messages: prompt messages @@ -359,7 +334,6 @@ class LargeLanguageModel(AIModel): ) -> LLMUsage: """ Calculate response usage - :param model: model name :param credentials: model credentials :param prompt_tokens: prompt tokens @@ -373,12 +347,10 @@ class LargeLanguageModel(AIModel): price_type=PriceType.INPUT, tokens=prompt_tokens, ) - # get completion price info completion_price_info = self.get_price( model=model, credentials=credentials, price_type=PriceType.OUTPUT, tokens=completion_tokens ) - # transform usage usage = LLMUsage( prompt_tokens=prompt_tokens, @@ -394,7 +366,6 @@ class LargeLanguageModel(AIModel): currency=prompt_price_info.currency, latency=time.perf_counter() - self.started_at, ) - return usage def _trigger_before_invoke_callbacks( @@ -411,7 +382,6 @@ class LargeLanguageModel(AIModel): ) -> None: """ Trigger before invoke callbacks - :param model: model name :param credentials: model credentials :param prompt_messages: prompt messages @@ -457,7 +427,6 @@ class LargeLanguageModel(AIModel): ) -> None: """ Trigger new chunk callbacks - :param chunk: chunk :param model: model name :param credentials: model credentials @@ -504,7 +473,6 @@ class LargeLanguageModel(AIModel): ) -> 
None: """ Trigger after invoke callbacks - :param model: model name :param result: result :param credentials: model credentials @@ -552,7 +520,6 @@ class LargeLanguageModel(AIModel): ) -> None: """ Trigger invoke error callbacks - :param model: model name :param ex: exception :param credentials: model credentials diff --git a/api/core/model_runtime/model_providers/__base/moderation_model.py b/api/core/model_runtime/model_providers/__base/moderation_model.py index 19dc1d599a..ae81e837db 100644 --- a/api/core/model_runtime/model_providers/__base/moderation_model.py +++ b/api/core/model_runtime/model_providers/__base/moderation_model.py @@ -14,14 +14,12 @@ class ModerationModel(AIModel): """ model_type: ModelType = ModelType.MODERATION - # pydantic configs model_config = ConfigDict(protected_namespaces=()) def invoke(self, model: str, credentials: dict, text: str, user: Optional[str] = None) -> bool: """ Invoke moderation model - :param model: model name :param credentials: model credentials :param text: text to moderate @@ -29,7 +27,6 @@ class ModerationModel(AIModel): :return: false if text is safe, true otherwise """ self.started_at = time.perf_counter() - try: plugin_model_manager = PluginModelClient() return plugin_model_manager.invoke_moderation( diff --git a/api/core/model_runtime/model_providers/__base/rerank_model.py b/api/core/model_runtime/model_providers/__base/rerank_model.py index 569e756a3b..c6e75a94fa 100644 --- a/api/core/model_runtime/model_providers/__base/rerank_model.py +++ b/api/core/model_runtime/model_providers/__base/rerank_model.py @@ -25,7 +25,6 @@ class RerankModel(AIModel): ) -> RerankResult: """ Invoke rerank model - :param model: model name :param credentials: model credentials :param query: search query diff --git a/api/core/model_runtime/model_providers/__base/speech2text_model.py b/api/core/model_runtime/model_providers/__base/speech2text_model.py index c69f65b681..a70a9b2089 100644 --- 
a/api/core/model_runtime/model_providers/__base/speech2text_model.py +++ b/api/core/model_runtime/model_providers/__base/speech2text_model.py @@ -13,14 +13,12 @@ class Speech2TextModel(AIModel): """ model_type: ModelType = ModelType.SPEECH2TEXT - # pydantic configs model_config = ConfigDict(protected_namespaces=()) def invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str: """ Invoke speech to text model - :param model: model name :param credentials: model credentials :param file: audio file diff --git a/api/core/model_runtime/model_providers/__base/text_embedding_model.py b/api/core/model_runtime/model_providers/__base/text_embedding_model.py index f7bba0eba1..3fbee3c281 100644 --- a/api/core/model_runtime/model_providers/__base/text_embedding_model.py +++ b/api/core/model_runtime/model_providers/__base/text_embedding_model.py @@ -15,7 +15,6 @@ class TextEmbeddingModel(AIModel): """ model_type: ModelType = ModelType.TEXT_EMBEDDING - # pydantic configs model_config = ConfigDict(protected_namespaces=()) @@ -29,7 +28,6 @@ class TextEmbeddingModel(AIModel): ) -> TextEmbeddingResult: """ Invoke text embedding model - :param model: model name :param credentials: model credentials :param texts: texts to embed @@ -55,7 +53,6 @@ class TextEmbeddingModel(AIModel): def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> list[int]: """ Get number of tokens for given prompt messages - :param model: model name :param credentials: model credentials :param texts: texts to embed @@ -75,31 +72,25 @@ class TextEmbeddingModel(AIModel): def _get_context_size(self, model: str, credentials: dict) -> int: """ Get context size for given embedding model - :param model: model name :param credentials: model credentials :return: context size """ model_schema = self.get_model_schema(model, credentials) - if model_schema and ModelPropertyKey.CONTEXT_SIZE in model_schema.model_properties: content_size: int = 
model_schema.model_properties[ModelPropertyKey.CONTEXT_SIZE] return content_size - return 1000 def _get_max_chunks(self, model: str, credentials: dict) -> int: """ Get max chunks for given embedding model - :param model: model name :param credentials: model credentials :return: max chunks """ model_schema = self.get_model_schema(model, credentials) - if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties: max_chunks: int = model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS] return max_chunks - return 1 diff --git a/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py b/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py index b7db0b78bc..6f77f57515 100644 --- a/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py +++ b/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py @@ -3,7 +3,6 @@ from threading import Lock from typing import Any logger = logging.getLogger(__name__) - _tokenizer: Any = None _lock = Lock() @@ -49,5 +48,4 @@ class GPT2Tokenizer: gpt2_tokenizer_path = join(dirname(base_path), "gpt2") _tokenizer = TransformerGPT2Tokenizer.from_pretrained(gpt2_tokenizer_path) logger.info("Fallback to Transformers' GPT-2 tokenizer from tiktoken") - return _tokenizer diff --git a/api/core/model_runtime/model_providers/__base/tts_model.py b/api/core/model_runtime/model_providers/__base/tts_model.py index d51831900c..cc2a12965c 100644 --- a/api/core/model_runtime/model_providers/__base/tts_model.py +++ b/api/core/model_runtime/model_providers/__base/tts_model.py @@ -17,7 +17,6 @@ class TTSModel(AIModel): """ model_type: ModelType = ModelType.TTS - # pydantic configs model_config = ConfigDict(protected_namespaces=()) @@ -32,7 +31,6 @@ class TTSModel(AIModel): ) -> Iterable[bytes]: """ Invoke large language model - :param model: model name :param tenant_id: user tenant id :param credentials: model credentials @@ -59,7 +57,6 @@ class 
TTSModel(AIModel): def get_tts_model_voices(self, model: str, credentials: dict, language: Optional[str] = None) -> list[dict]: """ Retrieves the list of voices supported by a given text-to-speech (TTS) model. - :param language: The language for which the voices are requested. :param model: The name of the TTS model. :param credentials: The credentials required to access the TTS model. diff --git a/api/core/model_runtime/model_providers/model_provider_factory.py b/api/core/model_runtime/model_providers/model_provider_factory.py index ad46f64ec3..1bbf4f090f 100644 --- a/api/core/model_runtime/model_providers/model_provider_factory.py +++ b/api/core/model_runtime/model_providers/model_provider_factory.py @@ -38,15 +38,12 @@ class ModelProviderFactory: def __init__(self, tenant_id: str) -> None: self.provider_position_map = {} - self.tenant_id = tenant_id self.plugin_model_manager = PluginModelClient() - if not self.provider_position_map: # get the path of current classes current_path = os.path.abspath(__file__) model_providers_path = os.path.dirname(current_path) - # get _position.yaml file path self.provider_position_map = get_provider_position_map(model_providers_path) @@ -57,18 +54,15 @@ class ModelProviderFactory: """ # Fetch plugin model providers plugin_providers = self.get_plugin_model_providers() - # Convert PluginModelProviderEntity to ModelProviderExtension model_provider_extensions = [] for provider in plugin_providers: model_provider_extensions.append(ModelProviderExtension(plugin_model_provider_entity=provider)) - sorted_extensions = sort_to_dict_by_position_map( position_map=self.provider_position_map, data=model_provider_extensions, name_func=lambda x: x.plugin_model_provider_entity.declaration.provider, ) - return [extension.plugin_model_provider_entity.declaration for extension in sorted_extensions.values()] def get_plugin_model_providers(self) -> Sequence[PluginModelProviderEntity]: @@ -82,22 +76,17 @@ class ModelProviderFactory: except LookupError: 
contexts.plugin_model_providers.set(None) contexts.plugin_model_providers_lock.set(Lock()) - with contexts.plugin_model_providers_lock.get(): plugin_model_providers = contexts.plugin_model_providers.get() if plugin_model_providers is not None: return plugin_model_providers - plugin_model_providers = [] contexts.plugin_model_providers.set(plugin_model_providers) - # Fetch plugin model providers plugin_providers = self.plugin_model_manager.fetch_model_providers(self.tenant_id) - for provider in plugin_providers: provider.declaration.provider = provider.plugin_id + "/" + provider.declaration.provider plugin_model_providers.append(provider) - return plugin_model_providers def get_provider_schema(self, provider: str) -> ProviderEntity: @@ -117,41 +106,33 @@ class ModelProviderFactory: """ if "/" not in provider: provider = str(ModelProviderID(provider)) - # fetch plugin model providers plugin_model_provider_entities = self.get_plugin_model_providers() - # get the provider plugin_model_provider_entity = next( (p for p in plugin_model_provider_entities if p.declaration.provider == provider), None, ) - if not plugin_model_provider_entity: raise ValueError(f"Invalid provider: {provider}") - return plugin_model_provider_entity def provider_credentials_validate(self, *, provider: str, credentials: dict) -> dict: """ Validate provider credentials - :param provider: provider name :param credentials: provider credentials, credentials form defined in `provider_credential_schema`. 
:return: """ # fetch plugin model provider plugin_model_provider_entity = self.get_plugin_model_provider(provider=provider) - # get provider_credential_schema and validate credentials according to the rules provider_credential_schema = plugin_model_provider_entity.declaration.provider_credential_schema if not provider_credential_schema: raise ValueError(f"Provider {provider} does not have provider_credential_schema") - # validate provider credential schema validator = ProviderCredentialSchemaValidator(provider_credential_schema) filtered_credentials = validator.validate_and_filter(credentials) - # validate the credentials, raise exception if validation failed self.plugin_model_manager.validate_provider_credentials( tenant_id=self.tenant_id, @@ -160,7 +141,6 @@ class ModelProviderFactory: provider=plugin_model_provider_entity.provider, credentials=filtered_credentials, ) - return filtered_credentials def model_credentials_validate( @@ -168,7 +148,6 @@ class ModelProviderFactory: ) -> dict: """ Validate model credentials - :param provider: provider name :param model_type: model type :param model: model name @@ -177,16 +156,13 @@ class ModelProviderFactory: """ # fetch plugin model provider plugin_model_provider_entity = self.get_plugin_model_provider(provider=provider) - # get model_credential_schema and validate credentials according to the rules model_credential_schema = plugin_model_provider_entity.declaration.model_credential_schema if not model_credential_schema: raise ValueError(f"Provider {provider} does not have model_credential_schema") - # validate model credential schema validator = ModelCredentialSchemaValidator(model_type, model_credential_schema) filtered_credentials = validator.validate_and_filter(credentials) - # call validate_credentials method of model type to validate credentials, raise exception if validation failed self.plugin_model_manager.validate_model_credentials( tenant_id=self.tenant_id, @@ -197,7 +173,6 @@ class ModelProviderFactory: 
model=model, credentials=filtered_credentials, ) - return filtered_credentials def get_model_schema( @@ -211,17 +186,14 @@ class ModelProviderFactory: # sort credentials sorted_credentials = sorted(credentials.items()) if credentials else [] cache_key += ":".join([hashlib.md5(f"{k}:{v}".encode()).hexdigest() for k, v in sorted_credentials]) - try: contexts.plugin_model_schemas.get() except LookupError: contexts.plugin_model_schemas.set({}) contexts.plugin_model_schema_lock.set(Lock()) - with contexts.plugin_model_schema_lock.get(): if cache_key in contexts.plugin_model_schemas.get(): return contexts.plugin_model_schemas.get()[cache_key] - schema = self.plugin_model_manager.get_model_schema( tenant_id=self.tenant_id, user_id="unknown", @@ -231,10 +203,8 @@ class ModelProviderFactory: model=model, credentials=credentials or {}, ) - if schema: contexts.plugin_model_schemas.get()[cache_key] = schema - return schema def get_models( @@ -246,51 +216,39 @@ class ModelProviderFactory: ) -> list[SimpleProviderEntity]: """ Get all models for given model type - :param provider: provider name :param model_type: model type :param provider_configs: list of provider configs :return: list of models """ provider_configs = provider_configs or [] - # scan all providers plugin_model_provider_entities = self.get_plugin_model_providers() - # convert provider_configs to dict provider_credentials_dict = {} for provider_config in provider_configs: provider_credentials_dict[provider_config.provider] = provider_config.credentials - # traverse all model_provider_extensions providers = [] for plugin_model_provider_entity in plugin_model_provider_entities: # filter by provider if provider is present if provider and plugin_model_provider_entity.declaration.provider != provider: continue - # get provider schema provider_schema = plugin_model_provider_entity.declaration - model_types = provider_schema.supported_model_types if model_type: if model_type not in model_types: continue - model_types = 
[model_type] - all_model_type_models = [] for model_schema in provider_schema.models: if model_schema.model_type != model_type: continue - all_model_type_models.append(model_schema) - simple_provider_schema = provider_schema.to_simple_provider() simple_provider_schema.models.extend(all_model_type_models) - providers.append(simple_provider_schema) - return providers def get_model_type_instance(self, provider: str, model_type: ModelType) -> AIModel: @@ -307,7 +265,6 @@ class ModelProviderFactory: "provider_name": provider_name, "plugin_model_provider": self.get_plugin_model_provider(provider), } - if model_type == ModelType.LLM: return LargeLanguageModel(**init_params) # type: ignore elif model_type == ModelType.TEXT_EMBEDDING: @@ -331,11 +288,9 @@ class ModelProviderFactory: """ # get the provider schema provider_schema = self.get_provider_schema(provider) - if icon_type.lower() == "icon_small": if not provider_schema.icon_small: raise ValueError(f"Provider {provider} does not have small icon.") - if lang.lower() == "zh_hans": file_name = provider_schema.icon_small.zh_Hans else: @@ -343,15 +298,12 @@ class ModelProviderFactory: else: if not provider_schema.icon_large: raise ValueError(f"Provider {provider} does not have large icon.") - if lang.lower() == "zh_hans": file_name = provider_schema.icon_large.zh_Hans else: file_name = provider_schema.icon_large.en_US - if not file_name: raise ValueError(f"Provider {provider} does not have icon.") - image_mime_types = { "jpg": "image/jpeg", "jpeg": "image/jpeg", @@ -366,10 +318,8 @@ class ModelProviderFactory: "heif": "image/heif", "heic": "image/heic", } - extension = file_name.split(".")[-1] mime_type = image_mime_types.get(extension, "image/png") - # get icon bytes from plugin asset manager plugin_asset_manager = PluginAssetManager() return plugin_asset_manager.fetch_asset(tenant_id=self.tenant_id, id=file_name), mime_type diff --git a/api/core/model_runtime/schema_validators/common_validator.py 
b/api/core/model_runtime/schema_validators/common_validator.py index 810a7c4c44..0ba35e47f0 100644 --- a/api/core/model_runtime/schema_validators/common_validator.py +++ b/api/core/model_runtime/schema_validators/common_validator.py @@ -12,20 +12,16 @@ class CommonValidator: if not credential_form_schema.show_on: need_validate_credential_form_schema_map[credential_form_schema.variable] = credential_form_schema continue - all_show_on_match = True for show_on_object in credential_form_schema.show_on: if show_on_object.variable not in credentials: all_show_on_match = False break - if credentials[show_on_object.variable] != show_on_object.value: all_show_on_match = False break - if all_show_on_match: need_validate_credential_form_schema_map[credential_form_schema.variable] = credential_form_schema - # Iterate over the remaining credential_form_schemas, verify each credential_form_schema validated_credentials = {} for credential_form_schema in need_validate_credential_form_schema_map.values(): @@ -33,7 +29,6 @@ class CommonValidator: result = self._validate_credential_form_schema(credential_form_schema, credentials) if result: validated_credentials[credential_form_schema.variable] = result - return validated_credentials def _validate_credential_form_schema( @@ -41,7 +36,6 @@ class CommonValidator: ) -> Union[str, bool, None]: """ Validate credential form schema - :param credential_form_schema: credential form schema :param credentials: credentials :return: validated credential form schema value @@ -60,10 +54,8 @@ class CommonValidator: else: # If default does not exist, skip return None - # Get the value corresponding to the variable from credentials value = cast(str, credentials[credential_form_schema.variable]) - # If max_length=0, no validation is performed if credential_form_schema.max_length: if len(value) > credential_form_schema.max_length: @@ -71,22 +63,17 @@ class CommonValidator: f"Variable {credential_form_schema.variable} length should not" f" greater than 
{credential_form_schema.max_length}" ) - # check the type of value if not isinstance(value, str): raise ValueError(f"Variable {credential_form_schema.variable} should be string") - if credential_form_schema.type in {FormType.SELECT, FormType.RADIO}: # If the value is in options, no validation is performed if credential_form_schema.options: if value not in [option.value for option in credential_form_schema.options]: raise ValueError(f"Variable {credential_form_schema.variable} is not in options") - if credential_form_schema.type == FormType.SWITCH: # If the value is not in ['true', 'false'], an exception is thrown if value.lower() not in {"true", "false"}: raise ValueError(f"Variable {credential_form_schema.variable} should be true or false") - value = value.lower() == "true" - return value diff --git a/api/core/model_runtime/schema_validators/model_credential_schema_validator.py b/api/core/model_runtime/schema_validators/model_credential_schema_validator.py index 7d1644d134..e686bcc9c8 100644 --- a/api/core/model_runtime/schema_validators/model_credential_schema_validator.py +++ b/api/core/model_runtime/schema_validators/model_credential_schema_validator.py @@ -11,17 +11,12 @@ class ModelCredentialSchemaValidator(CommonValidator): def validate_and_filter(self, credentials: dict) -> dict: """ Validate model credentials - :param credentials: model credentials :return: filtered credentials """ - if self.model_credential_schema is None: raise ValueError("Model credential schema is None") - # get the credential_form_schemas in provider_credential_schema credential_form_schemas = self.model_credential_schema.credential_form_schemas - credentials["__model_type"] = self.model_type.value - return self._validate_and_filter_credential_form_schemas(credential_form_schemas, credentials) diff --git a/api/core/model_runtime/schema_validators/provider_credential_schema_validator.py b/api/core/model_runtime/schema_validators/provider_credential_schema_validator.py index 
6dff2428ca..bea8cf45a9 100644 --- a/api/core/model_runtime/schema_validators/provider_credential_schema_validator.py +++ b/api/core/model_runtime/schema_validators/provider_credential_schema_validator.py @@ -9,11 +9,9 @@ class ProviderCredentialSchemaValidator(CommonValidator): def validate_and_filter(self, credentials: dict) -> dict: """ Validate provider credentials - :param credentials: provider credentials :return: validated provider credentials """ # get the credential_form_schemas in provider_credential_schema credential_form_schemas = self.provider_credential_schema.credential_form_schemas - return self._validate_and_filter_credential_form_schemas(credential_form_schemas, credentials) diff --git a/api/core/model_runtime/utils/encoders.py b/api/core/model_runtime/utils/encoders.py index a5c11aeeba..62bee8493a 100644 --- a/api/core/model_runtime/utils/encoders.py +++ b/api/core/model_runtime/utils/encoders.py @@ -32,15 +32,12 @@ def isoformat(o: Union[datetime.date, datetime.time]) -> str: def decimal_encoder(dec_value: Decimal) -> Union[int, float]: """ Encodes a Decimal as int of there's no exponent, otherwise float - This is useful when we use ConstrainedDecimal to represent Numeric(x,0) where a integer (but not int typed) is used. Encoding this as a float results in failed round-tripping between encode and parse. Our Id type is a prime example of this. 
- >>> decimal_encoder(Decimal("1.0")) 1.0 - >>> decimal_encoder(Decimal("1")) 1 """ @@ -191,13 +188,11 @@ def jsonable_encoder( ) ) return encoded_list - if type(obj) in ENCODERS_BY_TYPE: return ENCODERS_BY_TYPE[type(obj)](obj) for encoder, classes_tuple in encoders_by_class_tuples.items(): if isinstance(obj, classes_tuple): return encoder(obj) - try: data = dict(obj) except Exception as e: diff --git a/api/core/moderation/api/api.py b/api/core/moderation/api/api.py index c65a3885fd..d09f8c387c 100644 --- a/api/core/moderation/api/api.py +++ b/api/core/moderation/api/api.py @@ -27,17 +27,14 @@ class ApiModeration(Moderation): def validate_config(cls, tenant_id: str, config: dict) -> None: """ Validate the incoming form config data. - :param tenant_id: the id of workspace :param config: the form config data :return: """ cls._validate_inputs_and_outputs_config(config, False) - api_based_extension_id = config.get("api_based_extension_id") if not api_based_extension_id: raise ValueError("api_based_extension_id is required") - extension = cls._get_api_based_extension(tenant_id, api_based_extension_id) if not extension: raise ValueError("API-based Extension not found. 
Please check it again.") @@ -47,13 +44,10 @@ class ApiModeration(Moderation): preset_response = "" if self.config is None: raise ValueError("The config is not set.") - if self.config["inputs_config"]["enabled"]: params = ModerationInputParams(app_id=self.app_id, inputs=inputs, query=query) - result = self._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_INPUT, params.model_dump()) return ModerationInputsResult(**result) - return ModerationInputsResult( flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response ) @@ -63,13 +57,10 @@ class ApiModeration(Moderation): preset_response = "" if self.config is None: raise ValueError("The config is not set.") - if self.config["outputs_config"]["enabled"]: params = ModerationOutputParams(app_id=self.app_id, text=text) - result = self._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_OUTPUT, params.model_dump()) return ModerationOutputsResult(**result) - return ModerationOutputsResult( flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response ) @@ -81,7 +72,6 @@ class ApiModeration(Moderation): if not extension: raise ValueError("API-based Extension not found. Please check it again.") requestor = APIBasedExtensionRequestor(extension.api_endpoint, decrypt_token(self.tenant_id, extension.api_key)) - result = requestor.request(extension_point, params) return result @@ -92,5 +82,4 @@ class ApiModeration(Moderation): .filter(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id) .first() ) - return extension diff --git a/api/core/moderation/base.py b/api/core/moderation/base.py index d8c392d097..84f681b36d 100644 --- a/api/core/moderation/base.py +++ b/api/core/moderation/base.py @@ -43,7 +43,6 @@ class Moderation(Extensible, ABC): def validate_config(cls, tenant_id: str, config: dict) -> None: """ Validate the incoming form config data. 
- :param tenant_id: the id of workspace :param config: the form config data :return: @@ -56,7 +55,6 @@ class Moderation(Extensible, ABC): Moderation for inputs. After the user inputs, this method will be called to perform sensitive content review on the user inputs and return the processed results. - :param inputs: user inputs :param query: query string (required in chat app) :return: @@ -69,7 +67,6 @@ class Moderation(Extensible, ABC): Moderation for outputs. When LLM outputs content, the front end will pass the output content (may be segmented) to this method for sensitive content review, and the output content will be shielded if the review fails. - :param text: LLM output content :return: """ @@ -81,32 +78,25 @@ class Moderation(Extensible, ABC): inputs_config = config.get("inputs_config") if not isinstance(inputs_config, dict): raise ValueError("inputs_config must be a dict") - # outputs_config outputs_config = config.get("outputs_config") if not isinstance(outputs_config, dict): raise ValueError("outputs_config must be a dict") - inputs_config_enabled = inputs_config.get("enabled") outputs_config_enabled = outputs_config.get("enabled") if not inputs_config_enabled and not outputs_config_enabled: raise ValueError("At least one of inputs_config or outputs_config must be enabled") - # preset_response if not is_preset_response_required: return - if inputs_config_enabled: if not inputs_config.get("preset_response"): raise ValueError("inputs_config.preset_response is required") - if len(inputs_config.get("preset_response", 0)) > 100: raise ValueError("inputs_config.preset_response must be less than 100 characters") - if outputs_config_enabled: if not outputs_config.get("preset_response"): raise ValueError("outputs_config.preset_response is required") - if len(outputs_config.get("preset_response", 0)) > 100: raise ValueError("outputs_config.preset_response must be less than 100 characters") diff --git a/api/core/moderation/factory.py b/api/core/moderation/factory.py 
index 0ad4438c14..26b6901680 100644 --- a/api/core/moderation/factory.py +++ b/api/core/moderation/factory.py @@ -14,7 +14,6 @@ class ModerationFactory: def validate_config(cls, name: str, tenant_id: str, config: dict) -> None: """ Validate the incoming form config data. - :param name: the name of extension :param tenant_id: the id of workspace :param config: the form config data @@ -30,7 +29,6 @@ class ModerationFactory: Moderation for inputs. After the user inputs, this method will be called to perform sensitive content review on the user inputs and return the processed results. - :param inputs: user inputs :param query: query string (required in chat app) :return: @@ -42,7 +40,6 @@ class ModerationFactory: Moderation for outputs. When LLM outputs content, the front end will pass the output content (may be segmented) to this method for sensitive content review, and the output content will be shielded if the review fails. - :param text: LLM output content :return: """ diff --git a/api/core/moderation/input_moderation.py b/api/core/moderation/input_moderation.py index 3ac33966cb..69e21cf43e 100644 --- a/api/core/moderation/input_moderation.py +++ b/api/core/moderation/input_moderation.py @@ -37,17 +37,13 @@ class InputModeration: inputs = dict(inputs) if not app_config.sensitive_word_avoidance: return False, inputs, query - sensitive_word_avoidance_config = app_config.sensitive_word_avoidance moderation_type = sensitive_word_avoidance_config.type - moderation_factory = ModerationFactory( name=moderation_type, app_id=app_id, tenant_id=tenant_id, config=sensitive_word_avoidance_config.config ) - with measure_time() as timer: moderation_result = moderation_factory.moderation_for_inputs(inputs, query) - if trace_manager: trace_manager.add_trace_task( TraceTask( @@ -58,14 +54,11 @@ class InputModeration: timer=timer, ) ) - if not moderation_result.flagged: return False, inputs, query - if moderation_result.action == ModerationAction.DIRECT_OUTPUT: raise 
ModerationError(moderation_result.preset_response) elif moderation_result.action == ModerationAction.OVERRIDDEN: inputs = moderation_result.inputs query = moderation_result.query - return True, inputs, query diff --git a/api/core/moderation/keywords/keywords.py b/api/core/moderation/keywords/keywords.py index 9dd2665c3b..d292838139 100644 --- a/api/core/moderation/keywords/keywords.py +++ b/api/core/moderation/keywords/keywords.py @@ -11,19 +11,15 @@ class KeywordsModeration(Moderation): def validate_config(cls, tenant_id: str, config: dict) -> None: """ Validate the incoming form config data. - :param tenant_id: the id of workspace :param config: the form config data :return: """ cls._validate_inputs_and_outputs_config(config, True) - if not config.get("keywords"): raise ValueError("keywords is required") - if len(config.get("keywords", [])) > 10000: raise ValueError("keywords length must be less than 10000") - keywords_row_len = config["keywords"].split("\n") if len(keywords_row_len) > 100: raise ValueError("the number of rows for the keywords must be less than 100") @@ -33,18 +29,13 @@ class KeywordsModeration(Moderation): preset_response = "" if self.config is None: raise ValueError("The config is not set.") - if self.config["inputs_config"]["enabled"]: preset_response = self.config["inputs_config"]["preset_response"] - if query: inputs["query__"] = query - # Filter out empty values keywords_list = [keyword for keyword in self.config["keywords"].split("\n") if keyword] - flagged = self._is_violated(inputs, keywords_list) - return ModerationInputsResult( flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response ) @@ -54,14 +45,11 @@ class KeywordsModeration(Moderation): preset_response = "" if self.config is None: raise ValueError("The config is not set.") - if self.config["outputs_config"]["enabled"]: # Filter out empty values keywords_list = [keyword for keyword in self.config["keywords"].split("\n") if keyword] - flagged = 
self._is_violated({"text": text}, keywords_list) preset_response = self.config["outputs_config"]["preset_response"] - return ModerationOutputsResult( flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response ) diff --git a/api/core/moderation/openai_moderation/openai_moderation.py b/api/core/moderation/openai_moderation/openai_moderation.py index d64f17b383..5dd80a6566 100644 --- a/api/core/moderation/openai_moderation/openai_moderation.py +++ b/api/core/moderation/openai_moderation/openai_moderation.py @@ -10,7 +10,6 @@ class OpenAIModeration(Moderation): def validate_config(cls, tenant_id: str, config: dict) -> None: """ Validate the incoming form config data. - :param tenant_id: the id of workspace :param config: the form config data :return: @@ -22,14 +21,11 @@ class OpenAIModeration(Moderation): preset_response = "" if self.config is None: raise ValueError("The config is not set.") - if self.config["inputs_config"]["enabled"]: preset_response = self.config["inputs_config"]["preset_response"] - if query: inputs["query__"] = query flagged = self._is_violated(inputs) - return ModerationInputsResult( flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response ) @@ -39,11 +35,9 @@ class OpenAIModeration(Moderation): preset_response = "" if self.config is None: raise ValueError("The config is not set.") - if self.config["outputs_config"]["enabled"]: flagged = self._is_violated({"text": text}) preset_response = self.config["outputs_config"]["preset_response"] - return ModerationOutputsResult( flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response ) @@ -54,7 +48,5 @@ class OpenAIModeration(Moderation): model_instance = model_manager.get_model_instance( tenant_id=self.tenant_id, provider="openai", model_type=ModelType.MODERATION, model="text-moderation-stable" ) - openai_moderation = model_instance.invoke_moderation(text=text) - return openai_moderation diff --git 
a/api/core/moderation/output_moderation.py b/api/core/moderation/output_moderation.py index 2ec315417f..25fdff9446 100644 --- a/api/core/moderation/output_moderation.py +++ b/api/core/moderation/output_moderation.py @@ -23,10 +23,8 @@ class ModerationRule(BaseModel): class OutputModeration(BaseModel): tenant_id: str app_id: str - rule: ModerationRule queue_manager: AppQueueManager - thread: Optional[threading.Thread] = None thread_running: bool = True buffer: str = "" @@ -42,24 +40,19 @@ class OutputModeration(BaseModel): def append_new_token(self, token: str) -> None: self.buffer += token - if not self.thread: self.thread = self.start_thread() def moderation_completion(self, completion: str, public_event: bool = False) -> tuple[str, bool]: self.buffer = completion self.is_final_chunk = True - result = self.moderation(tenant_id=self.tenant_id, app_id=self.app_id, moderation_buffer=completion) - if not result or not result.flagged: return completion, False - if result.action == ModerationAction.DIRECT_OUTPUT: final_output = result.preset_response else: final_output = result.text - if public_event: self.queue_manager.publish( QueueMessageReplaceEvent( @@ -67,7 +60,6 @@ class OutputModeration(BaseModel): ), PublishFrom.TASK_PIPELINE, ) - return final_output, True def start_thread(self) -> threading.Thread: @@ -79,9 +71,7 @@ class OutputModeration(BaseModel): "buffer_size": buffer_size if buffer_size > 0 else dify_config.MODERATION_BUFFER_SIZE, }, ) - thread.start() - return thread def stop_thread(self): @@ -99,22 +89,17 @@ class OutputModeration(BaseModel): if 0 <= chunk_length < buffer_size: time.sleep(1) continue - current_length = buffer_length - result = self.moderation( tenant_id=self.tenant_id, app_id=self.app_id, moderation_buffer=moderation_buffer ) - if not result or not result.flagged: continue - if result.action == ModerationAction.DIRECT_OUTPUT: final_output = result.preset_response self.final_output = final_output else: final_output = result.text + 
self.buffer[len(moderation_buffer) :] - # trigger replace event if self.thread_running: self.queue_manager.publish( @@ -123,7 +108,6 @@ class OutputModeration(BaseModel): ), PublishFrom.TASK_PIPELINE, ) - if result.action == ModerationAction.DIRECT_OUTPUT: break @@ -132,10 +116,8 @@ class OutputModeration(BaseModel): moderation_factory = ModerationFactory( name=self.rule.type, app_id=app_id, tenant_id=tenant_id, config=self.rule.config ) - result: ModerationOutputsResult = moderation_factory.moderation_for_outputs(moderation_buffer) return result except Exception as e: logger.exception(f"Moderation Output error, app_id: {app_id}") - return None diff --git a/api/core/plugin/backwards_invocation/app.py b/api/core/plugin/backwards_invocation/app.py index 4e43561a15..bbb8b297c5 100644 --- a/api/core/plugin/backwards_invocation/app.py +++ b/api/core/plugin/backwards_invocation/app.py @@ -22,24 +22,19 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation): Fetch app info """ app = cls._get_app(app_id, tenant_id) - """Retrieve app parameters.""" if app.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: workflow = app.workflow if workflow is None: raise ValueError("unexpected app type") - features_dict = workflow.features_dict user_input_form = workflow.user_input_form(to_old_structure=True) else: app_model_config = app.app_model_config if app_model_config is None: raise ValueError("unexpected app type") - features_dict = app_model_config.to_dict() - user_input_form = features_dict.get("user_input_form", []) - return { "data": get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form), } @@ -64,19 +59,15 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation): user = create_or_update_end_user_for_user_id(app) else: user = cls._get_user(user_id) - conversation_id = conversation_id or "" - if app.mode in {AppMode.ADVANCED_CHAT.value, AppMode.AGENT_CHAT.value, AppMode.CHAT.value}: if not query: raise 
ValueError("missing query") - return cls.invoke_chat_app(app, user, conversation_id, query, stream, inputs, files) elif app.mode == AppMode.WORKFLOW: return cls.invoke_workflow_app(app, user, stream, inputs, files) elif app.mode == AppMode.COMPLETION: return cls.invoke_completion_app(app, user, stream, inputs, files) - raise ValueError("unexpected app type") @classmethod @@ -97,7 +88,6 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation): workflow = app.workflow if not workflow: raise ValueError("unexpected app type") - return AdvancedChatAppGenerator().generate( app_model=app, workflow=workflow, @@ -155,7 +145,6 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation): workflow = app.workflow if not workflow: raise ValueError("") - return WorkflowAppGenerator().generate( app_model=app, workflow=workflow, @@ -192,14 +181,11 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation): """ get the user by user id """ - user = db.session.query(EndUser).filter(EndUser.id == user_id).first() if not user: user = db.session.query(Account).filter(Account.id == user_id).first() - if not user: raise ValueError("user not found") - return user @classmethod @@ -211,8 +197,6 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation): app = db.session.query(App).filter(App.id == app_id).filter(App.tenant_id == tenant_id).first() except Exception: raise ValueError("app not found") - if not app: raise ValueError("app not found") - return app diff --git a/api/core/plugin/backwards_invocation/encrypt.py b/api/core/plugin/backwards_invocation/encrypt.py index 81a5d033a0..9ab3dabea7 100644 --- a/api/core/plugin/backwards_invocation/encrypt.py +++ b/api/core/plugin/backwards_invocation/encrypt.py @@ -12,7 +12,6 @@ class PluginEncrypter: provider_type=payload.namespace, provider_identity=payload.identity, ) - if payload.opt == "encrypt": return { "data": encrypter.encrypt(payload.data), diff --git a/api/core/plugin/backwards_invocation/model.py 
b/api/core/plugin/backwards_invocation/model.py index d07ab3d0c4..aa740567fa 100644 --- a/api/core/plugin/backwards_invocation/model.py +++ b/api/core/plugin/backwards_invocation/model.py @@ -47,7 +47,6 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): model_type=payload.model_type, model=payload.model, ) - # invoke model response = model_instance.invoke_llm( prompt_messages=payload.prompt_messages, @@ -57,7 +56,6 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): stream=True if payload.stream is None else payload.stream, user=user_id, ) - if isinstance(response, Generator): def handle() -> Generator[LLMResultChunk, None, None]: @@ -102,12 +100,9 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): model_type=payload.model_type, model=payload.model, ) - model_schema = model_instance.model_type_instance.get_model_schema(payload.model, model_instance.credentials) - if not model_schema: raise ValueError(f"Model schema not found for {payload.model}") - response = invoke_llm_with_structured_output( provider=payload.provider, model_schema=model_schema, @@ -120,7 +115,6 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): user=user_id, model_parameters=payload.completion_params, ) - if isinstance(response, Generator): def handle() -> Generator[LLMResultChunkWithStructuredOutput, None, None]: @@ -166,13 +160,11 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): model_type=payload.model_type, model=payload.model, ) - # invoke model response = model_instance.invoke_text_embedding( texts=payload.texts, user=user_id, ) - return response @classmethod @@ -186,7 +178,6 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): model_type=payload.model_type, model=payload.model, ) - # invoke model response = model_instance.invoke_rerank( query=payload.query, @@ -195,7 +186,6 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): top_n=payload.top_n, user=user_id, ) - return response 
@classmethod @@ -209,7 +199,6 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): model_type=payload.model_type, model=payload.model, ) - # invoke model response = model_instance.invoke_tts( content_text=payload.content_text, @@ -235,18 +224,15 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): model_type=payload.model_type, model=payload.model, ) - # invoke model with tempfile.NamedTemporaryFile(suffix=".mp3", mode="wb", delete=True) as temp: temp.write(unhexlify(payload.file)) temp.flush() temp.seek(0) - response = model_instance.invoke_speech2text( file=temp, user=user_id, ) - return { "result": response, } @@ -262,13 +248,11 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): model_type=payload.model_type, model=payload.model, ) - # invoke model response = model_instance.invoke_moderation( text=payload.text, user=user_id, ) - return { "result": response, } @@ -312,19 +296,16 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): """ max_tokens = cls.get_system_model_max_tokens(tenant_id=tenant.id) content = payload.text - SUMMARY_PROMPT = """You are a professional language researcher, you are interested in the language and you can quickly aimed at the main point of an webpage and reproduce it in your own words but retain the original meaning and keep the key points. however, the text you got is too long, what you got is possible a part of the text. Please summarize the text you got. 
- Here is the extra instruction you need to follow: {payload.instruction} """ - if ( cls.get_prompt_tokens( tenant_id=tenant.id, @@ -352,7 +333,6 @@ Here is the extra instruction you need to follow: UserPromptMessage(content=content), ], ) - assert isinstance(summary.message.content, str) return summary.message.content @@ -372,7 +352,6 @@ Here is the extra instruction you need to follow: new_lines.append(line) else: new_lines.append(line) - # merge lines into messages with max tokens messages: list[str] = [] for i in new_lines: # type: ignore @@ -385,15 +364,12 @@ Here is the extra instruction you need to follow: messages.append(i) # type: ignore else: messages[-1] += i # type: ignore - summaries = [] for i in range(len(messages)): message = messages[i] summary = summarize(message) summaries.append(summary) - result = "\n".join(summaries) - if ( cls.get_prompt_tokens( tenant_id=tenant.id, @@ -406,5 +382,4 @@ Here is the extra instruction you need to follow: tenant=tenant, payload=RequestInvokeSummary(text=result, instruction=payload.instruction), ) - return result diff --git a/api/core/plugin/backwards_invocation/node.py b/api/core/plugin/backwards_invocation/node.py index 7898795ce2..7d39e8e8b7 100644 --- a/api/core/plugin/backwards_invocation/node.py +++ b/api/core/plugin/backwards_invocation/node.py @@ -30,7 +30,6 @@ class PluginNodeBackwardsInvocation(BaseBackwardsInvocation): ) -> dict: """ Invoke parameter extractor node. - :param tenant_id: str :param user_id: str :param parameters: list[ParameterConfig] @@ -62,7 +61,6 @@ class PluginNodeBackwardsInvocation(BaseBackwardsInvocation): f"{node_id}.query": query, }, ) - return { "inputs": execution.inputs, "outputs": execution.outputs, @@ -81,7 +79,6 @@ class PluginNodeBackwardsInvocation(BaseBackwardsInvocation): ) -> dict: """ Invoke question classifier node. 
- :param tenant_id: str :param user_id: str :param model_config: ModelConfig @@ -111,7 +108,6 @@ class PluginNodeBackwardsInvocation(BaseBackwardsInvocation): f"{node_id}.query": query, }, ) - return { "inputs": execution.inputs, "outputs": execution.outputs, diff --git a/api/core/plugin/backwards_invocation/tool.py b/api/core/plugin/backwards_invocation/tool.py index 1d62743f13..4f600015a9 100644 --- a/api/core/plugin/backwards_invocation/tool.py +++ b/api/core/plugin/backwards_invocation/tool.py @@ -35,11 +35,9 @@ class PluginToolBackwardsInvocation(BaseBackwardsInvocation): response = ToolEngine.generic_invoke( tool_runtime, tool_parameters, user_id, DifyWorkflowCallbackHandler(), workflow_call_depth=1 ) - response = ToolFileMessageTransformer.transform_tool_invoke_messages( response, user_id=user_id, tenant_id=tenant_id ) - return response except Exception as e: raise e diff --git a/api/core/plugin/entities/parameters.py b/api/core/plugin/entities/parameters.py index 2b438a3c33..d909e4c6d6 100644 --- a/api/core/plugin/entities/parameters.py +++ b/api/core/plugin/entities/parameters.py @@ -39,7 +39,6 @@ class PluginParameterType(enum.StrEnum): MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR.value DYNAMIC_SELECT = CommonParameterType.DYNAMIC_SELECT.value - # deprecated, should not use. 
SYSTEM_FILES = CommonParameterType.SYSTEM_FILES.value @@ -94,7 +93,6 @@ def cast_parameter_value(typ: enum.StrEnum, value: Any, /): return "" else: return value if isinstance(value, str) else str(value) - case PluginParameterType.BOOLEAN: if value is None: return False @@ -110,7 +108,6 @@ def cast_parameter_value(typ: enum.StrEnum, value: Any, /): return bool(value) else: return value if isinstance(value, bool) else bool(value) - case PluginParameterType.NUMBER: if isinstance(value, int | float): return value @@ -156,11 +153,9 @@ def init_frontend_parameter(rule: PluginParameter, type: enum.StrEnum, value: An parameter_value = rule.default if not parameter_value and rule.required: raise ValueError(f"tool parameter {rule.name} not found in tool config") - if type == PluginParameterType.SELECT: # check if tool_parameter_config in options options = [x.value for x in rule.options] if parameter_value is not None and parameter_value not in options: raise ValueError(f"tool parameter {rule.name} value {parameter_value} not in options {options}") - return cast_parameter_value(type, parameter_value) diff --git a/api/core/plugin/entities/plugin.py b/api/core/plugin/entities/plugin.py index bdf7d5ce1f..e161f661ca 100644 --- a/api/core/plugin/entities/plugin.py +++ b/api/core/plugin/entities/plugin.py @@ -166,7 +166,6 @@ class GenericProviderID: value = f"langgenius/{value}/{value}" else: raise ValueError(f"Invalid plugin id {value}") - self.organization, self.plugin_name, self.provider_name = value.split("/") self.is_hardcoded = is_hardcoded diff --git a/api/core/plugin/entities/plugin_daemon.py b/api/core/plugin/entities/plugin_daemon.py index 592b42c0da..e02bd05633 100644 --- a/api/core/plugin/entities/plugin_daemon.py +++ b/api/core/plugin/entities/plugin_daemon.py @@ -66,7 +66,6 @@ class PluginBasicBooleanResponse(BaseModel): class PluginModelSchemaEntity(BaseModel): model_schema: AIModelEntity = Field(description="The model schema.") - # pydantic configs model_config = 
ConfigDict(protected_namespaces=()) diff --git a/api/core/plugin/entities/request.py b/api/core/plugin/entities/request.py index f9c81ed4d5..1a5df547b7 100644 --- a/api/core/plugin/entities/request.py +++ b/api/core/plugin/entities/request.py @@ -42,7 +42,6 @@ class BaseRequestInvokeModel(BaseModel): provider: str model: str model_type: ModelType - model_config = ConfigDict(protected_namespaces=()) @@ -58,7 +57,6 @@ class RequestInvokeLLM(BaseRequestInvokeModel): tools: Optional[list[PromptMessageTool]] = Field(default_factory=list[PromptMessageTool]) stop: Optional[list[str]] = Field(default_factory=list[str]) stream: Optional[bool] = False - model_config = ConfigDict(protected_namespaces=()) @field_validator("prompt_messages", mode="before") @@ -66,7 +64,6 @@ class RequestInvokeLLM(BaseRequestInvokeModel): def convert_prompt_messages(cls, v): if not isinstance(v, list): raise ValueError("prompt_messages must be a list") - for i in range(len(v)): if v[i]["role"] == PromptMessageRole.USER.value: v[i] = UserPromptMessage(**v[i]) @@ -78,7 +75,6 @@ class RequestInvokeLLM(BaseRequestInvokeModel): v[i] = ToolPromptMessage(**v[i]) else: v[i] = PromptMessage(**v[i]) - return v diff --git a/api/core/plugin/impl/agent.py b/api/core/plugin/impl/agent.py index 66b77c7489..6cb175c6c6 100644 --- a/api/core/plugin/impl/agent.py +++ b/api/core/plugin/impl/agent.py @@ -21,7 +21,6 @@ class PluginAgentClient(BasePluginClient): provider_name = declaration.get("identity", {}).get("name") for strategy in declaration.get("strategies", []): strategy["identity"]["provider"] = provider_name - return json_response response = self._request_with_plugin_daemon_response( @@ -31,14 +30,11 @@ class PluginAgentClient(BasePluginClient): params={"page": 1, "page_size": 256}, transformer=transformer, ) - for provider in response: provider.declaration.identity.name = f"{provider.plugin_id}/{provider.declaration.identity.name}" - # override the provider name for each tool to plugin_id/provider_name for 
strategy in provider.declaration.strategies: strategy.identity.provider = provider.declaration.identity.name - return response def fetch_agent_strategy_provider(self, tenant_id: str, provider: str) -> PluginAgentProviderEntity: @@ -51,10 +47,8 @@ class PluginAgentClient(BasePluginClient): # skip if error occurs if json_response.get("data") is None or json_response.get("data", {}).get("declaration") is None: return json_response - for strategy in json_response.get("data", {}).get("declaration", {}).get("strategies", []): strategy["identity"]["provider"] = agent_provider_id.provider_name - return json_response response = self._request_with_plugin_daemon_response( @@ -64,13 +58,10 @@ class PluginAgentClient(BasePluginClient): params={"provider": agent_provider_id.provider_name, "plugin_id": agent_provider_id.plugin_id}, transformer=transformer, ) - response.declaration.identity.name = f"{response.plugin_id}/{response.declaration.identity.name}" - # override the provider name for each tool to plugin_id/provider_name for strategy in response.declaration.strategies: strategy.identity.provider = response.declaration.identity.name - return response def invoke( @@ -87,9 +78,7 @@ class PluginAgentClient(BasePluginClient): """ Invoke the agent with the given tenant, user, plugin, provider, name and parameters. 
""" - agent_provider_id = GenericProviderID(agent_provider) - response = self._request_with_plugin_daemon_response_stream( "POST", f"plugin/{tenant_id}/dispatch/agent_strategy/invoke", diff --git a/api/core/plugin/impl/base.py b/api/core/plugin/impl/base.py index 7375726fa9..2e2aac8326 100644 --- a/api/core/plugin/impl/base.py +++ b/api/core/plugin/impl/base.py @@ -32,9 +32,7 @@ from core.plugin.impl.exc import ( ) plugin_daemon_inner_api_baseurl = URL(str(dify_config.PLUGIN_DAEMON_URL)) - T = TypeVar("T", bound=(BaseModel | dict | list | bool | str)) - logger = logging.getLogger(__name__) @@ -56,10 +54,8 @@ class BasePluginClient: headers = headers or {} headers["X-Api-Key"] = dify_config.PLUGIN_DAEMON_KEY headers["Accept-Encoding"] = "gzip, deflate, br" - if headers.get("Content-Type") == "application/json" and isinstance(data, dict): data = json.dumps(data) - try: response = requests.request( method=method, url=str(url), headers=headers, data=data, params=params, stream=stream, files=files @@ -67,7 +63,6 @@ class BasePluginClient: except requests.exceptions.ConnectionError: logger.exception("Request to Plugin Daemon Service failed") raise PluginDaemonInnerError(code=-500, message="Request to Plugin Daemon Service failed") - return response def _stream_request( @@ -147,7 +142,6 @@ class BasePluginClient: msg = f"Failed to request plugin daemon, url: {path}" logging.exception(msg) raise ValueError(msg) from e - try: json_response = response.json() if transformer: @@ -160,18 +154,15 @@ class BasePluginClient: ) logging.exception(msg) raise ValueError(msg) - if rep.code != 0: try: error = PluginDaemonError(**json.loads(rep.message)) except Exception: raise ValueError(f"{rep.message}, code: {rep.code}") - self._handle_plugin_daemon_error(error.error_type, error.message) if rep.data is None: frame = inspect.currentframe() raise ValueError(f"got empty data from plugin daemon: {frame.f_lineno if frame else 'unknown'}") - return rep.data def 
_request_with_plugin_daemon_response_stream( @@ -200,14 +191,12 @@ class BasePluginClient: # for `ValueError`. # Otherwise, use the `line` to provide better contextual information about the error. raise ValueError(line_data.get("error", line)) - if rep.code != 0: if rep.code == -500: try: error = PluginDaemonError(**json.loads(rep.message)) except Exception: raise PluginDaemonInnerError(code=rep.code, message=rep.message) - self._handle_plugin_daemon_error(error.error_type, error.message) raise ValueError(f"plugin daemon: {rep.message}, code: {rep.code}") if rep.data is None: diff --git a/api/core/plugin/impl/debugging.py b/api/core/plugin/impl/debugging.py index 523377895c..b003e53d3a 100644 --- a/api/core/plugin/impl/debugging.py +++ b/api/core/plugin/impl/debugging.py @@ -13,5 +13,4 @@ class PluginDebuggingClient(BasePluginClient): key: str response = self._request_with_plugin_daemon_response("POST", f"plugin/{tenant_id}/debugging/key", Response) - return response.key diff --git a/api/core/plugin/impl/dynamic_select.py b/api/core/plugin/impl/dynamic_select.py index 004412afd7..f7ea6aac4b 100644 --- a/api/core/plugin/impl/dynamic_select.py +++ b/api/core/plugin/impl/dynamic_select.py @@ -38,8 +38,6 @@ class DynamicSelectClient(BasePluginClient): "Content-Type": "application/json", }, ) - for options in response: return options - raise ValueError(f"Plugin service returned no options for parameter '{parameter}' in provider '{provider}'") diff --git a/api/core/plugin/impl/endpoint.py b/api/core/plugin/impl/endpoint.py index 5b88742be5..ede3ff9da4 100644 --- a/api/core/plugin/impl/endpoint.py +++ b/api/core/plugin/impl/endpoint.py @@ -8,7 +8,6 @@ class PluginEndpointClient(BasePluginClient): ) -> bool: """ Create an endpoint for the given plugin. - Errors will be raised if any error occurs. 
""" return self._request_with_plugin_daemon_response( diff --git a/api/core/plugin/impl/model.py b/api/core/plugin/impl/model.py index f7607eef8d..9e95309aaf 100644 --- a/api/core/plugin/impl/model.py +++ b/api/core/plugin/impl/model.py @@ -65,10 +65,8 @@ class PluginModelClient(BasePluginClient): "Content-Type": "application/json", }, ) - for resp in response: return resp.model_schema - return None def validate_provider_credentials( @@ -93,13 +91,10 @@ class PluginModelClient(BasePluginClient): "Content-Type": "application/json", }, ) - for resp in response: if resp.credentials and isinstance(resp.credentials, dict): credentials.update(resp.credentials) - return resp.result - return False def validate_model_credentials( @@ -133,13 +128,10 @@ class PluginModelClient(BasePluginClient): "Content-Type": "application/json", }, ) - for resp in response: if resp.credentials and isinstance(resp.credentials, dict): credentials.update(resp.credentials) - return resp.result - return False def invoke_llm( @@ -184,7 +176,6 @@ class PluginModelClient(BasePluginClient): "Content-Type": "application/json", }, ) - try: yield from response except PluginDaemonInnerError as e: @@ -227,10 +218,8 @@ class PluginModelClient(BasePluginClient): "Content-Type": "application/json", }, ) - for resp in response: return resp.num_tokens - return 0 def invoke_text_embedding( @@ -269,10 +258,8 @@ class PluginModelClient(BasePluginClient): "Content-Type": "application/json", }, ) - for resp in response: return resp - raise ValueError("Failed to invoke text embedding") def get_text_embedding_num_tokens( @@ -309,10 +296,8 @@ class PluginModelClient(BasePluginClient): "Content-Type": "application/json", }, ) - for resp in response: return resp.num_tokens - return [] def invoke_rerank( @@ -355,10 +340,8 @@ class PluginModelClient(BasePluginClient): "Content-Type": "application/json", }, ) - for resp in response: return resp - raise ValueError("Failed to invoke rerank") def invoke_tts( @@ -398,7 +381,6 
@@ class PluginModelClient(BasePluginClient): "Content-Type": "application/json", }, ) - try: for result in response: hex_str = result.result @@ -440,14 +422,11 @@ class PluginModelClient(BasePluginClient): "Content-Type": "application/json", }, ) - for resp in response: voices = [] for voice in resp.voices: voices.append({"name": voice.name, "value": voice.value}) - return voices - return [] def invoke_speech_to_text( @@ -484,10 +463,8 @@ class PluginModelClient(BasePluginClient): "Content-Type": "application/json", }, ) - for resp in response: return resp.result - raise ValueError("Failed to invoke speech to text") def invoke_moderation( @@ -524,8 +501,6 @@ class PluginModelClient(BasePluginClient): "Content-Type": "application/json", }, ) - for resp in response: return resp.result - raise ValueError("Failed to invoke moderation") diff --git a/api/core/plugin/impl/oauth.py b/api/core/plugin/impl/oauth.py index b006bf1d4b..774021d190 100644 --- a/api/core/plugin/impl/oauth.py +++ b/api/core/plugin/impl/oauth.py @@ -49,10 +49,8 @@ class OAuthHandler(BasePluginClient): """ Get credentials from the given request. """ - # encode request to raw http request raw_request_bytes = self._convert_request_to_raw_data(request) - response = self._request_with_plugin_daemon_response_stream( "POST", f"plugin/{tenant_id}/dispatch/oauth/get_credentials", @@ -78,10 +76,8 @@ class OAuthHandler(BasePluginClient): def _convert_request_to_raw_data(self, request: Request) -> bytes: """ Convert a Request object to raw HTTP data. - Args: request: The Request object to convert. - Returns: The raw HTTP data as bytes. 
""" @@ -90,17 +86,13 @@ class OAuthHandler(BasePluginClient): path = request.full_path protocol = request.headers.get("HTTP_VERSION", "HTTP/1.1") raw_data = f"{method} {path} {protocol}\r\n".encode() - # Add headers for header_name, header_value in request.headers.items(): raw_data += f"{header_name}: {header_value}\r\n".encode() - # Add empty line to separate headers from body raw_data += b"\r\n" - # Add body if exists body = request.get_data(as_text=False) if body: raw_data += body - return raw_data diff --git a/api/core/plugin/impl/plugin.py b/api/core/plugin/impl/plugin.py index b7f7b31655..a9567651a0 100644 --- a/api/core/plugin/impl/plugin.py +++ b/api/core/plugin/impl/plugin.py @@ -60,11 +60,9 @@ class PluginInstaller(BasePluginClient): body = { "dify_pkg": ("dify_pkg", pkg, "application/octet-stream"), } - data = { "verify_signature": "true" if verify_signature else "false", } - return self._request_with_plugin_daemon_response( "POST", f"plugin/{tenant_id}/management/install/upload/package", @@ -168,7 +166,6 @@ class PluginInstaller(BasePluginClient): """ Fetch a plugin manifest. 
""" - return self._request_with_plugin_daemon_response( "GET", f"plugin/{tenant_id}/management/fetch/manifest", diff --git a/api/core/plugin/impl/tool.py b/api/core/plugin/impl/tool.py index 19b26c8fe3..010f029d22 100644 --- a/api/core/plugin/impl/tool.py +++ b/api/core/plugin/impl/tool.py @@ -21,7 +21,6 @@ class PluginToolManager(BasePluginClient): provider_name = declaration.get("identity", {}).get("name") for tool in declaration.get("tools", []): tool["identity"]["provider"] = provider_name - return json_response response = self._request_with_plugin_daemon_response( @@ -31,14 +30,11 @@ class PluginToolManager(BasePluginClient): params={"page": 1, "page_size": 256}, transformer=transformer, ) - for provider in response: provider.declaration.identity.name = f"{provider.plugin_id}/{provider.declaration.identity.name}" - # override the provider name for each tool to plugin_id/provider_name for tool in provider.declaration.tools: tool.identity.provider = provider.declaration.identity.name - return response def fetch_tool_provider(self, tenant_id: str, provider: str) -> PluginToolProviderEntity: @@ -52,7 +48,6 @@ class PluginToolManager(BasePluginClient): if data: for tool in data.get("declaration", {}).get("tools", []): tool["identity"]["provider"] = tool_provider_id.provider_name - return json_response response = self._request_with_plugin_daemon_response( @@ -62,13 +57,10 @@ class PluginToolManager(BasePluginClient): params={"provider": tool_provider_id.provider_name, "plugin_id": tool_provider_id.plugin_id}, transformer=transformer, ) - response.declaration.identity.name = f"{response.plugin_id}/{response.declaration.identity.name}" - # override the provider name for each tool to plugin_id/provider_name for tool in response.declaration.tools: tool.identity.provider = response.declaration.identity.name - return response def invoke( @@ -86,9 +78,7 @@ class PluginToolManager(BasePluginClient): """ Invoke the tool with the given tenant, user, plugin, provider, name, 
credentials and parameters. """ - tool_provider_id = GenericProviderID(tool_provider) - response = self._request_with_plugin_daemon_response_stream( "POST", f"plugin/{tenant_id}/dispatch/tool/invoke", @@ -134,11 +124,9 @@ class PluginToolManager(BasePluginClient): total_length = resp.message.total_length blob_data = resp.message.blob is_end = resp.message.end - # Initialize buffer for this file if it doesn't exist if chunk_id not in files: files[chunk_id] = FileChunk(total_length) - # If this is the final chunk, yield a complete blob message if is_end: yield ToolInvokeMessage( @@ -153,12 +141,10 @@ class PluginToolManager(BasePluginClient): del files[chunk_id] # Skip yielding this message raise ValueError("File is too large which reached the limit of 30MB") - # Check if single chunk is too large (8KB limit) if len(blob_data) > 8192: # Skip yielding this message raise ValueError("File chunk is too large which reached the limit of 8KB") - # Append the blob data to the buffer files[chunk_id].data[ files[chunk_id].bytes_written : files[chunk_id].bytes_written + len(blob_data) @@ -174,7 +160,6 @@ class PluginToolManager(BasePluginClient): validate the credentials of the provider """ tool_provider_id = GenericProviderID(provider) - response = self._request_with_plugin_daemon_response_stream( "POST", f"plugin/{tenant_id}/dispatch/tool/validate_credentials", @@ -191,10 +176,8 @@ class PluginToolManager(BasePluginClient): "Content-Type": "application/json", }, ) - for resp in response: return resp.result - return False def get_runtime_parameters( @@ -236,8 +219,6 @@ class PluginToolManager(BasePluginClient): "Content-Type": "application/json", }, ) - for resp in response: return resp.parameters - return [] diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index 488a394679..05ba5ab726 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -60,7 +60,6 @@ class ProviderManager: def get_configurations(self, tenant_id: str) -> 
ProviderConfigurations: """ Get model provider configurations. - Construct ProviderConfiguration objects for each provider Including: 1. Basic information of the provider @@ -90,18 +89,15 @@ class ProviderManager: Append custom provider models to the list - Get provider instance - Switch selection priority - :param tenant_id: :return: """ # Get all provider records of the workspace provider_name_to_provider_records_dict = self._get_all_providers(tenant_id) - # Initialize trial provider records if not exist provider_name_to_provider_records_dict = self._init_trial_provider_records( tenant_id, provider_name_to_provider_records_dict ) - # append providers with langgenius/openai/openai provider_name_list = list(provider_name_to_provider_records_dict.keys()) for provider_name in provider_name_list: @@ -110,7 +106,6 @@ class ProviderManager: provider_name_to_provider_records_dict[str(provider_id)] = provider_name_to_provider_records_dict[ provider_name ] - # Get all provider model records of the workspace provider_name_to_provider_model_records_dict = self._get_all_provider_models(tenant_id) for provider_name in list(provider_name_to_provider_model_records_dict.keys()): @@ -119,11 +114,9 @@ class ProviderManager: provider_name_to_provider_model_records_dict[str(provider_id)] = ( provider_name_to_provider_model_records_dict[provider_name] ) - # Get all provider entities model_provider_factory = ModelProviderFactory(tenant_id) provider_entities = model_provider_factory.get_providers() - # Get All preferred provider types of the workspace provider_name_to_preferred_model_provider_records_dict = self._get_all_preferred_model_providers(tenant_id) # Ensure that both the original provider name and its ModelProviderID string representation @@ -135,17 +128,13 @@ class ProviderManager: provider_name_to_preferred_model_provider_records_dict[str(provider_id)] = ( provider_name_to_preferred_model_provider_records_dict[provider_name] ) - # Get All provider model settings 
provider_name_to_provider_model_settings_dict = self._get_all_provider_model_settings(tenant_id) - # Get All load balancing configs provider_name_to_provider_load_balancing_model_configs_dict = self._get_all_provider_load_balancing_configs( tenant_id ) - provider_configurations = ProviderConfigurations(tenant_id=tenant_id) - # Construct ProviderConfiguration objects for each provider for provider_entity in provider_entities: # handle include, exclude @@ -156,7 +145,6 @@ class ProviderManager: name_func=lambda x: x.provider, ): continue - provider_name = provider_entity.provider provider_records = provider_name_to_provider_records_dict.get(provider_entity.provider, []) provider_model_records = provider_name_to_provider_model_records_dict.get(provider_entity.provider, []) @@ -165,18 +153,14 @@ class ProviderManager: provider_model_records.extend( provider_name_to_provider_model_records_dict.get(provider_id_entity.provider_name, []) ) - # Convert to custom configuration custom_configuration = self._to_custom_configuration( tenant_id, provider_entity, provider_records, provider_model_records ) - # Convert to system configuration system_configuration = self._to_system_configuration(tenant_id, provider_entity, provider_records) - # Get preferred provider type preferred_provider_type_record = provider_name_to_preferred_model_provider_records_dict.get(provider_name) - if preferred_provider_type_record: preferred_provider_type = ProviderType.value_of(preferred_provider_type_record.preferred_provider_type) elif custom_configuration.provider or custom_configuration.models: @@ -185,29 +169,22 @@ class ProviderManager: preferred_provider_type = ProviderType.SYSTEM else: preferred_provider_type = ProviderType.CUSTOM - using_provider_type = preferred_provider_type has_valid_quota = any(quota_conf.is_valid for quota_conf in system_configuration.quota_configurations) - if preferred_provider_type == ProviderType.SYSTEM: if not system_configuration.enabled or not has_valid_quota: 
using_provider_type = ProviderType.CUSTOM - else: if not custom_configuration.provider and not custom_configuration.models: if system_configuration.enabled and has_valid_quota: using_provider_type = ProviderType.SYSTEM - # Get provider load balancing configs provider_model_settings = provider_name_to_provider_model_settings_dict.get(provider_name) - # Get provider load balancing configs provider_load_balancing_configs = provider_name_to_provider_load_balancing_model_configs_dict.get( provider_name ) - provider_id_entity = ModelProviderID(provider_name) - if provider_id_entity.is_langgenius(): if provider_model_settings is not None: provider_model_settings.extend( @@ -219,14 +196,12 @@ class ProviderManager: provider_id_entity.provider_name, [] ) ) - # Convert to model settings model_settings = self._to_model_settings( provider_entity=provider_entity, provider_model_settings=provider_model_settings, load_balancing_model_configs=provider_load_balancing_configs, ) - provider_configuration = ProviderConfiguration( tenant_id=tenant_id, provider=provider_entity, @@ -236,9 +211,7 @@ class ProviderManager: custom_configuration=custom_configuration, model_settings=model_settings, ) - provider_configurations[str(provider_id_entity)] = provider_configuration - # Return the encapsulated object return provider_configurations @@ -251,14 +224,11 @@ class ProviderManager: :return: """ provider_configurations = self.get_configurations(tenant_id) - # get provider instance provider_configuration = provider_configurations.get(provider) if not provider_configuration: raise ValueError(f"Provider {provider} does not exist.") - model_type_instance = provider_configuration.get_model_type_instance(model_type) - return ProviderModelBundle( configuration=provider_configuration, model_type_instance=model_type_instance, @@ -267,7 +237,6 @@ class ProviderManager: def get_default_model(self, tenant_id: str, model_type: ModelType) -> Optional[DefaultModelEntity]: """ Get default model. 
- :param tenant_id: workspace id :param model_type: model type :return: @@ -281,21 +250,17 @@ class ProviderManager: ) .first() ) - # If it does not exist, get the first available provider model from get_configurations # and update the TenantDefaultModel record if not default_model: # Get provider configurations provider_configurations = self.get_configurations(tenant_id) - # get available models from provider_configurations available_models = provider_configurations.get_models(model_type=model_type, only_active=True) - if available_models: available_model = next( (model for model in available_models if model.model == "gpt-4"), available_models[0] ) - default_model = TenantDefaultModel() default_model.tenant_id = tenant_id default_model.model_type = model_type.to_origin_model_type() @@ -303,13 +268,10 @@ class ProviderManager: default_model.model_name = available_model.model db.session.add(default_model) db.session.commit() - if not default_model: return None - model_provider_factory = ModelProviderFactory(tenant_id) provider_schema = model_provider_factory.get_provider_schema(provider=default_model.provider_name) - return DefaultModelEntity( model=default_model.model_name, model_type=model_type, @@ -325,19 +287,15 @@ class ProviderManager: def get_first_provider_first_model(self, tenant_id: str, model_type: ModelType) -> tuple[str | None, str | None]: """ Get names of first model and its provider - :param tenant_id: workspace id :param model_type: model type :return: provider name, model name """ provider_configurations = self.get_configurations(tenant_id) - # get available models from provider_configurations all_models = provider_configurations.get_models(model_type=model_type, only_active=False) - if not all_models: return None, None - return all_models[0].provider.provider, all_models[0].model def update_default_model_record( @@ -345,7 +303,6 @@ class ProviderManager: ) -> TenantDefaultModel: """ Update default model record. 
- :param tenant_id: workspace id :param model_type: model type :param provider: provider name @@ -355,15 +312,12 @@ class ProviderManager: provider_configurations = self.get_configurations(tenant_id) if provider not in provider_configurations: raise ValueError(f"Provider {provider} does not exist.") - # get available models from provider_configurations available_models = provider_configurations.get_models(model_type=model_type, only_active=True) - # check if the model is exist in available models model_names = [model.model for model in available_models] if model not in model_names: raise ValueError(f"Model {model} does not exist.") - # Get the list of available models from get_configurations and check if it is LLM default_model = ( db.session.query(TenantDefaultModel) @@ -373,7 +327,6 @@ class ProviderManager: ) .first() ) - # create or update TenantDefaultModel record if default_model: # update default model @@ -390,7 +343,6 @@ class ProviderManager: ) db.session.add(default_model) db.session.commit() - return default_model @staticmethod @@ -408,7 +360,6 @@ class ProviderManager: def _get_all_provider_models(tenant_id: str) -> dict[str, list[ProviderModel]]: """ Get all provider model records of the workspace. - :param tenant_id: workspace id :return: """ @@ -424,7 +375,6 @@ class ProviderManager: def _get_all_preferred_model_providers(tenant_id: str) -> dict[str, TenantPreferredModelProvider]: """ Get All preferred provider types of the workspace. - :param tenant_id: workspace id :return: """ @@ -442,7 +392,6 @@ class ProviderManager: def _get_all_provider_model_settings(tenant_id: str) -> dict[str, list[ProviderModelSetting]]: """ Get All provider model settings of the workspace. - :param tenant_id: workspace id :return: """ @@ -460,7 +409,6 @@ class ProviderManager: def _get_all_provider_load_balancing_configs(tenant_id: str) -> dict[str, list[LoadBalancingModelConfig]]: """ Get All provider load balancing configs of the workspace. 
- :param tenant_id: workspace id :return: """ @@ -472,10 +420,8 @@ class ProviderManager: else: cache_result = cache_result.decode("utf-8") model_load_balancing_enabled = cache_result == "True" - if not model_load_balancing_enabled: return {} - provider_name_to_provider_load_balancing_model_configs_dict = defaultdict(list) with Session(db.engine, expire_on_commit=False) as session: stmt = select(LoadBalancingModelConfig).where(LoadBalancingModelConfig.tenant_id == tenant_id) @@ -484,7 +430,6 @@ class ProviderManager: provider_name_to_provider_load_balancing_model_configs_dict[ provider_load_balancing_config.provider_name ].append(provider_load_balancing_config) - return provider_name_to_provider_load_balancing_model_configs_dict @staticmethod @@ -493,31 +438,25 @@ class ProviderManager: ) -> dict[str, list[Provider]]: """ Initialize trial provider records if not exists. - :param tenant_id: workspace id :param provider_name_to_provider_records_dict: provider name to provider records dict :return: """ # Get hosting configuration hosting_configuration = ext_hosting_provider.hosting_configuration - for provider_name, configuration in hosting_configuration.provider_map.items(): if not configuration.enabled: continue - provider_records = provider_name_to_provider_records_dict.get(provider_name) if not provider_records: provider_records = [] - provider_quota_to_provider_record_dict = {} for provider_record in provider_records: if provider_record.provider_type != ProviderType.SYSTEM.value: continue - provider_quota_to_provider_record_dict[ProviderQuotaType.value_of(provider_record.quota_type)] = ( provider_record ) - for quota in configuration.quotas: if quota.quota_type == ProviderQuotaType.TRIAL: # Init trial provider records if not exists @@ -551,13 +490,10 @@ class ProviderManager: ) if not existed_provider_record: continue - if not existed_provider_record.is_valid: existed_provider_record.is_valid = True db.session.commit() - 
provider_name_to_provider_records_dict[provider_name].append(existed_provider_record) - return provider_name_to_provider_records_dict def _to_custom_configuration( @@ -569,7 +505,6 @@ class ProviderManager: ) -> CustomConfiguration: """ Convert to custom configuration. - :param tenant_id: workspace id :param provider_entity: provider entity :param provider_records: provider records @@ -582,18 +517,14 @@ class ProviderManager: if provider_entity.provider_credential_schema else [] ) - # Get custom provider record custom_provider_record = None for provider_record in provider_records: if provider_record.provider_type == ProviderType.SYSTEM.value: continue - if not provider_record.encrypted_config: continue - custom_provider_record = provider_record - # Get custom provider credentials custom_provider_configuration = None if custom_provider_record: @@ -602,10 +533,8 @@ class ProviderManager: identity_id=custom_provider_record.id, cache_type=ProviderCredentialsCacheType.PROVIDER, ) - # Get cached provider credentials cached_provider_credentials = provider_credentials_cache.get() - if not cached_provider_credentials: try: # fix origin data @@ -617,11 +546,9 @@ class ProviderManager: provider_credentials = json.loads(custom_provider_record.encrypted_config) except JSONDecodeError: provider_credentials = {} - # Get decoding rsa key and cipher for decrypting credentials if self.decoding_rsa_key is None or self.decoding_cipher_rsa is None: self.decoding_rsa_key, self.decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id) - for variable in provider_credential_secret_variables: if variable in provider_credentials: try: @@ -632,44 +559,35 @@ class ProviderManager: ) except ValueError: pass - # cache provider credentials provider_credentials_cache.set(credentials=provider_credentials) else: provider_credentials = cached_provider_credentials - custom_provider_configuration = CustomProviderConfiguration(credentials=provider_credentials) - # Get provider model credential 
secret variables model_credential_secret_variables = self._extract_secret_variables( provider_entity.model_credential_schema.credential_form_schemas if provider_entity.model_credential_schema else [] ) - # Get custom provider model credentials custom_model_configurations = [] for provider_model_record in provider_model_records: if not provider_model_record.encrypted_config: continue - provider_model_credentials_cache = ProviderCredentialsCache( tenant_id=tenant_id, identity_id=provider_model_record.id, cache_type=ProviderCredentialsCacheType.MODEL ) - # Get cached provider model credentials cached_provider_model_credentials = provider_model_credentials_cache.get() - if not cached_provider_model_credentials: try: provider_model_credentials = json.loads(provider_model_record.encrypted_config) except JSONDecodeError: continue - # Get decoding rsa key and cipher for decrypting credentials if self.decoding_rsa_key is None or self.decoding_cipher_rsa is None: self.decoding_rsa_key, self.decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id) - for variable in model_credential_secret_variables: if variable in provider_model_credentials: try: @@ -680,12 +598,10 @@ class ProviderManager: ) except ValueError: pass - # cache provider model credentials provider_model_credentials_cache.set(credentials=provider_model_credentials) else: provider_model_credentials = cached_provider_model_credentials - custom_model_configurations.append( CustomModelConfiguration( model=provider_model_record.model_name, @@ -693,7 +609,6 @@ class ProviderManager: credentials=provider_model_credentials, ) ) - return CustomConfiguration(provider=custom_provider_configuration, models=custom_model_configurations) def _to_system_configuration( @@ -701,7 +616,6 @@ class ProviderManager: ) -> SystemConfiguration: """ Convert to system configuration. 
- :param tenant_id: workspace id :param provider_entity: provider entity :param provider_records: provider records @@ -709,17 +623,14 @@ class ProviderManager: """ # Get hosting configuration hosting_configuration = ext_hosting_provider.hosting_configuration - provider_hosting_configuration = hosting_configuration.provider_map.get(provider_entity.provider) if provider_hosting_configuration is None or not provider_hosting_configuration.enabled: return SystemConfiguration(enabled=False) - # Convert provider_records to dict quota_type_to_provider_records_dict: dict[ProviderQuotaType, Provider] = {} for provider_record in provider_records: if provider_record.provider_type != ProviderType.SYSTEM.value: continue - quota_type_to_provider_records_dict[ProviderQuotaType.value_of(provider_record.quota_type)] = ( provider_record ) @@ -739,12 +650,10 @@ class ProviderManager: continue else: provider_record = quota_type_to_provider_records_dict[provider_quota.quota_type] - if provider_record.quota_used is None: raise ValueError("quota_used is None") if provider_record.quota_limit is None: raise ValueError("quota_limit is None") - quota_configuration = QuotaConfiguration( quota_type=provider_quota.quota_type, quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS, @@ -754,45 +663,35 @@ class ProviderManager: or provider_record.quota_limit == -1, restrict_models=provider_quota.restrict_models, ) - quota_configurations.append(quota_configuration) - if len(quota_configurations) == 0: return SystemConfiguration(enabled=False) - current_quota_type = self._choice_current_using_quota_type(quota_configurations) - current_using_credentials = provider_hosting_configuration.credentials if current_quota_type == ProviderQuotaType.FREE: provider_record_quota_free = quota_type_to_provider_records_dict.get(current_quota_type) - if provider_record_quota_free: provider_credentials_cache = ProviderCredentialsCache( tenant_id=tenant_id, identity_id=provider_record_quota_free.id, 
cache_type=ProviderCredentialsCacheType.PROVIDER, ) - # Get cached provider credentials # error occurs cached_provider_credentials = provider_credentials_cache.get() - if not cached_provider_credentials: provider_credentials: dict[str, Any] = {} if provider_records and provider_records[0].encrypted_config: provider_credentials = json.loads(provider_records[0].encrypted_config) - # Get provider credential secret variables provider_credential_secret_variables = self._extract_secret_variables( provider_entity.provider_credential_schema.credential_form_schemas if provider_entity.provider_credential_schema else [] ) - # Get decoding rsa key and cipher for decrypting credentials if self.decoding_rsa_key is None or self.decoding_cipher_rsa is None: self.decoding_rsa_key, self.decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id) - for variable in provider_credential_secret_variables: if variable in provider_credentials: try: @@ -803,9 +702,7 @@ class ProviderManager: ) except ValueError: pass - current_using_credentials = provider_credentials or {} - # cache provider credentials provider_credentials_cache.set(credentials=current_using_credentials) else: @@ -813,7 +710,6 @@ class ProviderManager: else: current_using_credentials = {} quota_configurations = [] - return SystemConfiguration( enabled=True, current_quota_type=current_quota_type, @@ -827,7 +723,6 @@ class ProviderManager: Choice current using quota type. 
paid quotas > provider free quotas > hosting trial quotas If there is still quota for the corresponding quota type according to the sorting, - :param quota_configurations: :return: """ @@ -835,24 +730,20 @@ class ProviderManager: quota_type_to_quota_configuration_dict = { quota_configuration.quota_type: quota_configuration for quota_configuration in quota_configurations } - last_quota_configuration = None for quota_type in [ProviderQuotaType.PAID, ProviderQuotaType.FREE, ProviderQuotaType.TRIAL]: if quota_type in quota_type_to_quota_configuration_dict: last_quota_configuration = quota_type_to_quota_configuration_dict[quota_type] if last_quota_configuration.is_valid: return quota_type - if last_quota_configuration: return last_quota_configuration.quota_type - raise ValueError("No quota type available") @staticmethod def _extract_secret_variables(credential_form_schemas: list[CredentialFormSchema]) -> list[str]: """ Extract secret input form variables. - :param credential_form_schemas: :return: """ @@ -860,7 +751,6 @@ class ProviderManager: for credential_form_schema in credential_form_schemas: if credential_form_schema.type == FormType.SECRET_INPUT: secret_input_form_variables.append(credential_form_schema.variable) - return secret_input_form_variables def _to_model_settings( @@ -889,11 +779,9 @@ class ProviderManager: if provider_entity.model_credential_schema else [] ) - model_settings: list[ModelSettings] = [] if not provider_model_settings: return model_settings - for provider_model_setting in provider_model_settings: load_balancing_configs = [] if provider_model_setting.load_balancing_enabled and load_balancing_model_configs: @@ -904,7 +792,6 @@ class ProviderManager: ): if not load_balancing_model_config.enabled: continue - if not load_balancing_model_config.encrypted_config: if load_balancing_model_config.name == "__inherit__": load_balancing_configs.append( @@ -915,28 +802,23 @@ class ProviderManager: ) ) continue - provider_model_credentials_cache = 
ProviderCredentialsCache( tenant_id=load_balancing_model_config.tenant_id, identity_id=load_balancing_model_config.id, cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL, ) - # Get cached provider model credentials cached_provider_model_credentials = provider_model_credentials_cache.get() - if not cached_provider_model_credentials: try: provider_model_credentials = json.loads(load_balancing_model_config.encrypted_config) except JSONDecodeError: continue - # Get decoding rsa key and cipher for decrypting credentials if self.decoding_rsa_key is None or self.decoding_cipher_rsa is None: self.decoding_rsa_key, self.decoding_cipher_rsa = encrypter.get_decrypt_decoding( load_balancing_model_config.tenant_id ) - for variable in model_credential_secret_variables: if variable in provider_model_credentials: try: @@ -947,12 +829,10 @@ class ProviderManager: ) except ValueError: pass - # cache provider model credentials provider_model_credentials_cache.set(credentials=provider_model_credentials) else: provider_model_credentials = cached_provider_model_credentials - load_balancing_configs.append( ModelLoadBalancingConfiguration( id=load_balancing_model_config.id, @@ -960,7 +840,6 @@ class ProviderManager: credentials=provider_model_credentials, ) ) - model_settings.append( ModelSettings( model=provider_model_setting.model_name, @@ -969,5 +848,4 @@ class ProviderManager: load_balancing_configs=load_balancing_configs if len(load_balancing_configs) > 1 else [], ) ) - return model_settings diff --git a/api/core/tools/__base/tool.py b/api/core/tools/__base/tool.py index 35e16b5c8f..890f8b74c7 100644 --- a/api/core/tools/__base/tool.py +++ b/api/core/tools/__base/tool.py @@ -5,7 +5,6 @@ from typing import TYPE_CHECKING, Any, Optional if TYPE_CHECKING: from models.model import File - from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.tool_entities import ( ToolEntity, @@ -41,7 +40,6 @@ class Tool(ABC): def tool_provider_type(self) -> 
ToolProviderType: """ get the tool provider type - :return: the tool provider type """ @@ -55,10 +53,8 @@ class Tool(ABC): ) -> Generator[ToolInvokeMessage]: if self.runtime and self.runtime.runtime_parameters: tool_parameters.update(self.runtime.runtime_parameters) - # try parse tool parameters into the correct type tool_parameters = self._transform_tool_parameters_type(tool_parameters) - result = self._invoke( user_id=user_id, tool_parameters=tool_parameters, @@ -66,7 +62,6 @@ class Tool(ABC): app_id=app_id, message_id=message_id, ) - if isinstance(result, ToolInvokeMessage): def single_generator() -> Generator[ToolInvokeMessage, None, None]: @@ -91,7 +86,6 @@ class Tool(ABC): for parameter in self.entity.parameters or []: if parameter.name in tool_parameters: result[parameter.name] = parameter.type.cast_value(tool_parameters[parameter.name]) - return result @abstractmethod @@ -113,9 +107,7 @@ class Tool(ABC): ) -> list[ToolParameter]: """ get the runtime parameters - interface for developer to dynamic change the parameters of a tool depends on the variables pool - :return: the runtime parameters """ return self.entity.parameters @@ -128,14 +120,12 @@ class Tool(ABC): ) -> list[ToolParameter]: """ get merged runtime parameters - :return: merged runtime parameters """ parameters = self.entity.parameters parameters = parameters.copy() user_parameters = self.get_runtime_parameters() or [] user_parameters = user_parameters.copy() - # override parameters for parameter in user_parameters: # check if parameter in tool parameters @@ -152,7 +142,6 @@ class Tool(ABC): else: # add new parameter parameters.append(parameter) - return parameters def create_image_message( @@ -161,7 +150,6 @@ class Tool(ABC): ) -> ToolInvokeMessage: """ create an image message - :param image: the url of the image :return: the image message """ @@ -179,7 +167,6 @@ class Tool(ABC): def create_link_message(self, link: str) -> ToolInvokeMessage: """ create a link message - :param link: the url of 
the link :return: the link message """ @@ -190,7 +177,6 @@ class Tool(ABC): def create_text_message(self, text: str) -> ToolInvokeMessage: """ create a text message - :param text: the text :return: the text message """ @@ -202,7 +188,6 @@ class Tool(ABC): def create_blob_message(self, blob: bytes, meta: Optional[dict] = None) -> ToolInvokeMessage: """ create a blob message - :param blob: the blob :param meta: the meta info of blob object :return: the blob message diff --git a/api/core/tools/__base/tool_provider.py b/api/core/tools/__base/tool_provider.py index d096fc7df7..240de80b47 100644 --- a/api/core/tools/__base/tool_provider.py +++ b/api/core/tools/__base/tool_provider.py @@ -20,7 +20,6 @@ class ToolProviderController(ABC): def get_credentials_schema(self) -> list[ProviderConfig]: """ returns the credentials schema of the provider - :return: the credentials schema """ return deepcopy(self.entity.credentials_schema) @@ -29,7 +28,6 @@ class ToolProviderController(ABC): def get_tool(self, tool_name: str) -> Tool: """ returns a tool that the provider can provide - :return: tool """ pass @@ -38,7 +36,6 @@ class ToolProviderController(ABC): def provider_type(self) -> ToolProviderType: """ returns the type of the provider - :return: type of the provider """ return ToolProviderType.BUILT_IN @@ -46,55 +43,43 @@ class ToolProviderController(ABC): def validate_credentials_format(self, credentials: dict[str, Any]) -> None: """ validate the format of the credentials of the provider and set the default value if needed - :param credentials: the credentials of the tool """ credentials_schema = dict[str, ProviderConfig]() if credentials_schema is None: return - for credential in self.entity.credentials_schema: credentials_schema[credential.name] = credential - credentials_need_to_validate: dict[str, ProviderConfig] = {} for credential_name in credentials_schema: credentials_need_to_validate[credential_name] = credentials_schema[credential_name] - for credential_name in 
credentials: if credential_name not in credentials_need_to_validate: raise ToolProviderCredentialValidationError( f"credential {credential_name} not found in provider {self.entity.identity.name}" ) - # check type credential_schema = credentials_need_to_validate[credential_name] if not credential_schema.required and credentials[credential_name] is None: continue - if credential_schema.type in {ProviderConfig.Type.SECRET_INPUT, ProviderConfig.Type.TEXT_INPUT}: if not isinstance(credentials[credential_name], str): raise ToolProviderCredentialValidationError(f"credential {credential_name} should be string") - elif credential_schema.type == ProviderConfig.Type.SELECT: if not isinstance(credentials[credential_name], str): raise ToolProviderCredentialValidationError(f"credential {credential_name} should be string") - options = credential_schema.options if not isinstance(options, list): raise ToolProviderCredentialValidationError(f"credential {credential_name} options should be list") - if credentials[credential_name] not in [x.value for x in options]: raise ToolProviderCredentialValidationError( f"credential {credential_name} should be one of {options}" ) - credentials_need_to_validate.pop(credential_name) - for credential_name in credentials_need_to_validate: credential_schema = credentials_need_to_validate[credential_name] if credential_schema.required: raise ToolProviderCredentialValidationError(f"credential {credential_name} is required") - # the credential is not set currently, set the default value if needed if credential_schema.default is not None: default_value = credential_schema.default @@ -105,5 +90,4 @@ class ToolProviderController(ABC): ProviderConfig.Type.SELECT, }: default_value = str(default_value) - credentials[credential_name] = default_value diff --git a/api/core/tools/builtin_tool/provider.py b/api/core/tools/builtin_tool/provider.py index cf75bd3d7e..ef957a0fe2 100644 --- a/api/core/tools/builtin_tool/provider.py +++ 
b/api/core/tools/builtin_tool/provider.py @@ -20,7 +20,6 @@ class BuiltinToolProviderController(ToolProviderController): def __init__(self, **data: Any) -> None: self.tools = [] - # load provider yaml provider = self.__class__.__module__.split(".")[-1] yaml_path = path.join(path.dirname(path.realpath(__file__)), "providers", provider, f"{provider}.yaml") @@ -28,24 +27,20 @@ class BuiltinToolProviderController(ToolProviderController): provider_yaml = load_yaml_file(yaml_path, ignore_error=False) except Exception as e: raise ToolProviderNotFoundError(f"can not load provider yaml for {provider}: {e}") - if "credentials_for_provider" in provider_yaml and provider_yaml["credentials_for_provider"] is not None: # set credentials name for credential_name in provider_yaml["credentials_for_provider"]: provider_yaml["credentials_for_provider"][credential_name]["name"] = credential_name - credentials_schema = [] for credential in provider_yaml.get("credentials_for_provider", {}): credential_dict = provider_yaml.get("credentials_for_provider", {}).get(credential, {}) credentials_schema.append(credential_dict) - super().__init__( entity=ToolProviderEntity( identity=provider_yaml["identity"], credentials_schema=credentials_schema, ), ) - self._load_tools() def _load_tools(self): @@ -58,7 +53,6 @@ class BuiltinToolProviderController(ToolProviderController): # get tool name tool_name = tool_file.split(".")[0] tool = load_yaml_file(path.join(tool_path, tool_file), ignore_error=False) - # get tool class, import the module assistant_tool_class: type[BuiltinTool] = load_single_subclass_from_source( module_name=f"core.tools.builtin_tool.providers.{provider}.tools.{tool_name}", @@ -80,13 +74,11 @@ class BuiltinToolProviderController(ToolProviderController): runtime=ToolRuntime(tenant_id=""), ) ) - self.tools = tools def _get_builtin_tools(self) -> list[BuiltinTool]: """ returns a list of tools that the provider can provide - :return: list of tools """ return self.tools @@ -94,18 +86,15 
@@ class BuiltinToolProviderController(ToolProviderController): def get_credentials_schema(self) -> list[ProviderConfig]: """ returns the credentials schema of the provider - :return: the credentials schema """ if not self.entity.credentials_schema: return [] - return self.entity.credentials_schema.copy() def get_tools(self) -> list[BuiltinTool]: """ returns a list of tools that the provider can provide - :return: list of tools """ return self._get_builtin_tools() @@ -120,7 +109,6 @@ class BuiltinToolProviderController(ToolProviderController): def need_credentials(self) -> bool: """ returns whether the provider needs credentials - :return: whether the provider needs credentials """ return self.entity.credentials_schema is not None and len(self.entity.credentials_schema) != 0 @@ -129,7 +117,6 @@ class BuiltinToolProviderController(ToolProviderController): def provider_type(self) -> ToolProviderType: """ returns the type of the provider - :return: type of the provider """ return ToolProviderType.BUILT_IN @@ -138,7 +125,6 @@ class BuiltinToolProviderController(ToolProviderController): def tool_labels(self) -> list[str]: """ returns the labels of the provider - :return: labels of the provider """ label_enums = self._get_tool_labels() @@ -153,13 +139,11 @@ class BuiltinToolProviderController(ToolProviderController): def validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None: """ validate the credentials of the provider - :param user_id: use id :param credentials: the credentials of the tool """ # validate credentials format self.validate_credentials_format(credentials) - # validate credentials self._validate_credentials(user_id, credentials) @@ -167,7 +151,6 @@ class BuiltinToolProviderController(ToolProviderController): def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None: """ validate the credentials of the provider - :param user_id: use id :param credentials: the credentials of the tool """ diff --git 
a/api/core/tools/builtin_tool/providers/_positions.py b/api/core/tools/builtin_tool/providers/_positions.py index 44a90db038..84166d68b6 100644 --- a/api/core/tools/builtin_tool/providers/_positions.py +++ b/api/core/tools/builtin_tool/providers/_positions.py @@ -16,5 +16,4 @@ class BuiltinToolProviderSort: return provider.name sorted_providers = sort_by_position_map(cls._position, providers, name_func) - return sorted_providers diff --git a/api/core/tools/builtin_tool/providers/audio/tools/asr.py b/api/core/tools/builtin_tool/providers/audio/tools/asr.py index 5c24920871..2051869002 100644 --- a/api/core/tools/builtin_tool/providers/audio/tools/asr.py +++ b/api/core/tools/builtin_tool/providers/audio/tools/asr.py @@ -61,12 +61,10 @@ class ASRTool(BuiltinTool): message_id: Optional[str] = None, ) -> list[ToolParameter]: parameters = [] - options = [] for provider, model in self.get_available_models(): option = PluginParameterOption(value=f"{provider}#{model}", label=I18nObject(en_US=f"{model}({provider})")) options.append(option) - parameters.append( ToolParameter( name="model", diff --git a/api/core/tools/builtin_tool/providers/audio/tools/tts.py b/api/core/tools/builtin_tool/providers/audio/tools/tts.py index f191968812..da5ee040c7 100644 --- a/api/core/tools/builtin_tool/providers/audio/tools/tts.py +++ b/api/core/tools/builtin_tool/providers/audio/tools/tts.py @@ -48,7 +48,6 @@ class TTSTool(BuiltinTool): buffer = io.BytesIO() for chunk in tts: buffer.write(chunk) - wav_bytes = buffer.getvalue() yield self.create_text_message("Audio generated successfully") yield self.create_blob_message( @@ -77,7 +76,6 @@ class TTSTool(BuiltinTool): message_id: Optional[str] = None, ) -> list[ToolParameter]: parameters = [] - options = [] for provider, model, voices in self.get_available_models(): option = PluginParameterOption(value=f"{provider}#{model}", label=I18nObject(en_US=f"{model}({provider})")) @@ -96,7 +94,6 @@ class TTSTool(BuiltinTool): ], ) ) - parameters.insert( 
0, ToolParameter( diff --git a/api/core/tools/builtin_tool/providers/code/tools/simple_code.py b/api/core/tools/builtin_tool/providers/code/tools/simple_code.py index b4e650e0ed..4a033f0e4b 100644 --- a/api/core/tools/builtin_tool/providers/code/tools/simple_code.py +++ b/api/core/tools/builtin_tool/providers/code/tools/simple_code.py @@ -19,13 +19,10 @@ class SimpleCode(BuiltinTool): """ invoke simple code """ - language = tool_parameters.get("language", CodeLanguage.PYTHON3) code = tool_parameters.get("code", "") - if language not in {CodeLanguage.PYTHON3, CodeLanguage.JAVASCRIPT}: raise ValueError(f"Only python3 and javascript are supported, not {language}") - try: result = CodeExecutor.execute_code(language, "", code) yield self.create_text_message(result) diff --git a/api/core/tools/builtin_tool/providers/time/tools/current_time.py b/api/core/tools/builtin_tool/providers/time/tools/current_time.py index d054afac96..c7a6ad5c5a 100644 --- a/api/core/tools/builtin_tool/providers/time/tools/current_time.py +++ b/api/core/tools/builtin_tool/providers/time/tools/current_time.py @@ -26,7 +26,6 @@ class CurrentTimeTool(BuiltinTool): if tz == "UTC": yield self.create_text_message(f"{datetime.now(UTC).strftime(fm)}") return - try: tz = pytz_timezone(tz) except Exception: diff --git a/api/core/tools/builtin_tool/providers/time/tools/localtime_to_timestamp.py b/api/core/tools/builtin_tool/providers/time/tools/localtime_to_timestamp.py index 1639dd687f..d4d6e6cc72 100644 --- a/api/core/tools/builtin_tool/providers/time/tools/localtime_to_timestamp.py +++ b/api/core/tools/builtin_tool/providers/time/tools/localtime_to_timestamp.py @@ -26,12 +26,10 @@ class LocaltimeToTimestampTool(BuiltinTool): if not timezone: timezone = None time_format = "%Y-%m-%d %H:%M:%S" - timestamp = self.localtime_to_timestamp(localtime, time_format, timezone) # type: ignore if not timestamp: yield self.create_text_message(f"Invalid localtime: {localtime}") return - yield 
self.create_text_message(f"{timestamp}") @staticmethod diff --git a/api/core/tools/builtin_tool/providers/time/tools/timestamp_to_localtime.py b/api/core/tools/builtin_tool/providers/time/tools/timestamp_to_localtime.py index 0ef6331530..8604ef8e0e 100644 --- a/api/core/tools/builtin_tool/providers/time/tools/timestamp_to_localtime.py +++ b/api/core/tools/builtin_tool/providers/time/tools/timestamp_to_localtime.py @@ -26,14 +26,11 @@ class TimestampToLocaltimeTool(BuiltinTool): if not timezone: timezone = None time_format = "%Y-%m-%d %H:%M:%S" - locatime = self.timestamp_to_localtime(timestamp, timezone) if not locatime: yield self.create_text_message(f"Invalid timestamp: {timestamp}") return - localtime_format = locatime.strftime(time_format) - yield self.create_text_message(f"{localtime_format}") @staticmethod diff --git a/api/core/tools/builtin_tool/providers/time/tools/timezone_conversion.py b/api/core/tools/builtin_tool/providers/time/tools/timezone_conversion.py index f9b776b3b9..a2fece13a7 100644 --- a/api/core/tools/builtin_tool/providers/time/tools/timezone_conversion.py +++ b/api/core/tools/builtin_tool/providers/time/tools/timezone_conversion.py @@ -30,7 +30,6 @@ class TimezoneConversionTool(BuiltinTool): f"Invalid datatime and timezone: {current_time},{current_timezone},{target_timezone}" ) return - yield self.create_text_message(f"{target_time}") @staticmethod diff --git a/api/core/tools/builtin_tool/providers/time/tools/weekday.py b/api/core/tools/builtin_tool/providers/time/tools/weekday.py index 158ce701c0..4098402a58 100644 --- a/api/core/tools/builtin_tool/providers/time/tools/weekday.py +++ b/api/core/tools/builtin_tool/providers/time/tools/weekday.py @@ -24,12 +24,10 @@ class WeekdayTool(BuiltinTool): if month is None: raise ValueError("Month is required") day = tool_parameters.get("day") - date_obj = self.convert_datetime(year, month, day) if not date_obj: yield self.create_text_message(f"Invalid date: Year {year}, Month {month}, Day {day}.") 
return - weekday_name = calendar.day_name[date_obj.weekday()] month_name = calendar.month_name[month] readable_date = f"{month_name} {date_obj.day}, {date_obj.year}" @@ -41,7 +39,6 @@ class WeekdayTool(BuiltinTool): # allowed range in datetime module if not (year >= 1 and 1 <= month <= 12 and 1 <= day <= 31): return None - year = int(year) month = int(month) day = int(day) diff --git a/api/core/tools/builtin_tool/providers/webscraper/tools/webscraper.py b/api/core/tools/builtin_tool/providers/webscraper/tools/webscraper.py index 3bee710879..fee6e058bd 100644 --- a/api/core/tools/builtin_tool/providers/webscraper/tools/webscraper.py +++ b/api/core/tools/builtin_tool/providers/webscraper/tools/webscraper.py @@ -25,10 +25,8 @@ class WebscraperTool(BuiltinTool): if not url: yield self.create_text_message("Please input url") return - # get webpage result = get_url(url, user_agent=user_agent) - if tool_parameters.get("generate_summary"): # summarize and return yield self.create_text_message(self.summary(user_id=user_id, content=result)) diff --git a/api/core/tools/builtin_tool/tool.py b/api/core/tools/builtin_tool/tool.py index 724a2291c6..ec5b5b4f4b 100644 --- a/api/core/tools/builtin_tool/tool.py +++ b/api/core/tools/builtin_tool/tool.py @@ -16,7 +16,6 @@ Please summarize the text you got. 
class BuiltinTool(Tool): """ Builtin tool - :param meta: the meta data of a tool call processing """ @@ -40,7 +39,6 @@ class BuiltinTool(Tool): def invoke_model(self, user_id: str, prompt_messages: list[PromptMessage], stop: list[str]) -> LLMResult: """ invoke model - :param user_id: the user id :param prompt_messages: the prompt messages :param stop: the stop words @@ -61,12 +59,10 @@ class BuiltinTool(Tool): def get_max_tokens(self) -> int: """ get max tokens - :return: the max tokens """ if self.runtime is None: raise ValueError("runtime is required") - return ModelInvocationUtils.get_max_llm_context_tokens( tenant_id=self.runtime.tenant_id or "", ) @@ -74,20 +70,17 @@ class BuiltinTool(Tool): def get_prompt_tokens(self, prompt_messages: list[PromptMessage]) -> int: """ get prompt tokens - :param prompt_messages: the prompt messages :return: the tokens """ if self.runtime is None: raise ValueError("runtime is required") - return ModelInvocationUtils.calculate_tokens( tenant_id=self.runtime.tenant_id or "", prompt_messages=prompt_messages ) def summary(self, user_id: str, content: str) -> str: max_tokens = self.get_max_tokens() - if self.get_prompt_tokens(prompt_messages=[UserPromptMessage(content=content)]) < max_tokens * 0.6: return content @@ -102,7 +95,6 @@ class BuiltinTool(Tool): prompt_messages=[SystemPromptMessage(content=_SUMMARY_PROMPT), UserPromptMessage(content=content)], stop=[], ) - assert isinstance(summary.message.content, str) return summary.message.content @@ -122,7 +114,6 @@ class BuiltinTool(Tool): new_lines.append(line) else: new_lines.append(line) - # merge lines into messages with max tokens messages: list[str] = [] for j in new_lines: @@ -135,16 +126,12 @@ class BuiltinTool(Tool): messages.append(j) else: messages[-1] += j - summaries = [] for i in range(len(messages)): message = messages[i] summary = summarize(message) summaries.append(summary) - result = "\n".join(summaries) - if 
self.get_prompt_tokens(prompt_messages=[UserPromptMessage(content=result)]) > max_tokens * 0.7: return self.summary(user_id=user_id, content=result) - return result diff --git a/api/core/tools/custom_tool/provider.py b/api/core/tools/custom_tool/provider.py index 3137d32013..d49481c545 100644 --- a/api/core/tools/custom_tool/provider.py +++ b/api/core/tools/custom_tool/provider.py @@ -76,10 +76,8 @@ class ApiToolProviderController(ToolProviderController): ] elif auth_type == ApiProviderAuthType.NONE: pass - user = db_provider.user user_name = user.name if user else "" - return ApiToolProviderController( entity=ToolProviderEntity( identity=ToolProviderIdentity( @@ -103,7 +101,6 @@ class ApiToolProviderController(ToolProviderController): def _parse_tool_bundle(self, tool_bundle: ApiToolBundle) -> ApiTool: """ parse tool bundle to tool - :param tool_bundle: the tool bundle :return: the tool """ @@ -133,54 +130,44 @@ class ApiToolProviderController(ToolProviderController): def load_bundled_tools(self, tools: list[ApiToolBundle]): """ load bundled tools - :param tools: the bundled tools :return: the tools """ self.tools = [self._parse_tool_bundle(tool) for tool in tools] - return self.tools def get_tools(self, tenant_id: str) -> list[ApiTool]: """ fetch tools from database - :param tenant_id: the tenant id :return: the tools """ if len(self.tools) > 0: return self.tools - tools: list[ApiTool] = [] - # get tenant api providers db_providers: list[ApiToolProvider] = ( db.session.query(ApiToolProvider) .filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == self.entity.identity.name) .all() ) - if db_providers and len(db_providers) != 0: for db_provider in db_providers: for tool in db_provider.tools: assistant_tool = self._parse_tool_bundle(tool) tools.append(assistant_tool) - self.tools = tools return tools def get_tool(self, tool_name: str): """ get tool by name - :param tool_name: the name of the tool :return: the tool """ if self.tools is None: 
self.get_tools(self.tenant_id) - for tool in self.tools: if tool.entity.identity.name == tool_name: return tool - raise ValueError(f"tool {tool_name} not found") diff --git a/api/core/tools/custom_tool/tool.py b/api/core/tools/custom_tool/tool.py index 2f5cc6d4c0..fd62722d30 100644 --- a/api/core/tools/custom_tool/tool.py +++ b/api/core/tools/custom_tool/tool.py @@ -23,7 +23,6 @@ API_TOOL_DEFAULT_TIMEOUT = ( class ApiTool(Tool): api_bundle: ApiToolBundle provider_id: str - """ Api tool """ @@ -55,10 +54,8 @@ class ApiTool(Tool): """ # assemble validate request and request parameters headers = self.assembling_request(parameters) - if format_only: return "" - response = self.do_http_request(self.api_bundle.server_url, self.api_bundle.method, headers, parameters) # validate response return self.validate_and_parse_response(response) @@ -69,26 +66,20 @@ class ApiTool(Tool): def assembling_request(self, parameters: dict[str, Any]) -> dict[str, Any]: if self.runtime is None: raise ToolProviderCredentialValidationError("runtime not initialized") - headers = {} if self.runtime is None: raise ValueError("runtime is required") credentials = self.runtime.credentials or {} - if "auth_type" not in credentials: raise ToolProviderCredentialValidationError("Missing auth_type") - if credentials["auth_type"] == "api_key": api_key_header = "api_key" - if "api_key_header" in credentials: api_key_header = credentials["api_key_header"] - if "api_key_value" not in credentials: raise ToolProviderCredentialValidationError("Missing api_key_value") elif not isinstance(credentials["api_key_value"], str): raise ToolProviderCredentialValidationError("api_key_value must be a string") - if "api_key_header_prefix" in credentials: api_key_header_prefix = credentials["api_key_header_prefix"] if api_key_header_prefix == "basic" and credentials["api_key_value"]: @@ -97,9 +88,7 @@ class ApiTool(Tool): credentials["api_key_value"] = f"Bearer {credentials['api_key_value']}" elif api_key_header_prefix == 
"custom": pass - headers[api_key_header] = credentials["api_key_value"] - needed_parameters = [parameter for parameter in (self.api_bundle.parameters or []) if parameter.required] for parameter in needed_parameters: if parameter.required and parameter.name not in parameters: @@ -107,7 +96,6 @@ class ApiTool(Tool): parameters[parameter.name] = parameter.default else: raise ToolParameterValidationError(f"Missing required parameter {parameter.name}") - return headers def validate_and_parse_response(self, response: httpx.Response) -> str: @@ -146,30 +134,24 @@ class ApiTool(Tool): do http request depending on api bundle """ method = method.lower() - params = {} path_params = {} # FIXME: body should be a dict[str, Any] but it changed a lot in this function body: Any = {} cookies = {} files = [] - # check parameters for parameter in self.api_bundle.openapi.get("parameters", []): value = self.get_parameter_value(parameter, parameters) if parameter["in"] == "path": path_params[parameter["name"]] = value - elif parameter["in"] == "query": if value != "": params[parameter["name"]] = value - elif parameter["in"] == "cookie": cookies[parameter["name"]] = value - elif parameter["in"] == "header": headers[parameter["name"]] = str(value) - # check if there is a request body and handle it if "requestBody" in self.api_bundle.openapi and self.api_bundle.openapi["requestBody"] is not None: # handle json request body @@ -177,7 +159,6 @@ class ApiTool(Tool): for content_type in self.api_bundle.openapi["requestBody"]["content"]: headers["Content-Type"] = content_type body_schema = self.api_bundle.openapi["requestBody"]["content"][content_type]["schema"] - # handle ref schema if "$ref" in body_schema: ref_path = body_schema["$ref"].split("/") @@ -188,7 +169,6 @@ class ApiTool(Tool): ): if ref_name in self.api_bundle.openapi["components"]["schemas"]: body_schema = self.api_bundle.openapi["components"]["schemas"][ref_name] - required = body_schema.get("required", []) properties = 
body_schema.get("properties", {}) for name, property in properties.items(): @@ -215,11 +195,9 @@ class ApiTool(Tool): else: body[name] = None break - # replace path parameters for name, value in path_params.items(): url = url.replace(f"{{{name}}}", f"{value}") - # parse http body data if needed if "Content-Type" in headers: if headers["Content-Type"] == "application/json": @@ -228,14 +206,12 @@ class ApiTool(Tool): body = urlencode(body) else: body = body - # if there is a file upload, remove the Content-Type header # so that httpx can automatically generate the boundary header required for multipart/form-data. # issue: https://github.com/langgenius/dify/issues/13684 # reference: https://stackoverflow.com/questions/39280438/fetch-missing-boundary-in-multipart-form-data-post if files: headers.pop("Content-Type", None) - if method in { "get", "head", @@ -352,12 +328,9 @@ class ApiTool(Tool): response: httpx.Response | str = "" # assemble request headers = self.assembling_request(tool_parameters) - # do http request response = self.do_http_request(self.api_bundle.server_url, self.api_bundle.method, headers, tool_parameters) - # validate response response = self.validate_and_parse_response(response) - # assemble invoke message yield self.create_text_message(response) diff --git a/api/core/tools/entities/api_entities.py b/api/core/tools/entities/api_entities.py index b96c994cff..369842cf0c 100644 --- a/api/core/tools/entities/api_entities.py +++ b/api/core/tools/entities/api_entities.py @@ -53,7 +53,6 @@ class ToolProviderApiEntity(BaseModel): if parameter.get("type") == ToolParameter.ToolParameterType.SYSTEM_FILES.value: parameter["type"] = "files" # ------------- - return { "id": self.id, "author": self.author, diff --git a/api/core/tools/entities/file_entities.py b/api/core/tools/entities/file_entities.py index 8b13789179..e69de29bb2 100644 --- a/api/core/tools/entities/file_entities.py +++ b/api/core/tools/entities/file_entities.py @@ -1 +0,0 @@ - diff --git 
a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index d2c28076ae..36d231fa71 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -54,7 +54,6 @@ class ToolProviderType(enum.StrEnum): def value_of(cls, value: str) -> "ToolProviderType": """ Get value of given mode. - :param value: mode value :return: mode """ @@ -78,7 +77,6 @@ class ApiProviderSchemaType(Enum): def value_of(cls, value: str) -> "ApiProviderSchemaType": """ Get value of given mode. - :param value: mode value :return: mode """ @@ -100,7 +98,6 @@ class ApiProviderAuthType(Enum): def value_of(cls, value: str) -> "ApiProviderAuthType": """ Get value of given mode. - :param value: mode value :return: mode """ @@ -144,12 +141,10 @@ class ToolInvokeMessage(BaseModel): value = values.get("variable_value") if not isinstance(value, dict | list | str | int | float | bool): raise ValueError("Only basic types and lists are allowed.") - # if stream is true, the value must be a string if values.get("stream"): if not isinstance(value, str): raise ValueError("When 'stream' is True, 'variable_value' must be a string.") - return values @field_validator("variable_name", mode="before") @@ -241,7 +236,6 @@ class ToolParameter(PluginParameter): APP_SELECTOR = PluginParameterType.APP_SELECTOR.value MODEL_SELECTOR = PluginParameterType.MODEL_SELECTOR.value DYNAMIC_SELECT = PluginParameterType.DYNAMIC_SELECT.value - # deprecated, should not use. 
SYSTEM_FILES = PluginParameterType.SYSTEM_FILES.value @@ -272,7 +266,6 @@ class ToolParameter(PluginParameter): ) -> "ToolParameter": """ get a simple tool parameter - :param name: the name of the parameter :param llm_description: the description presented to the LLM :param typ: the type of the parameter @@ -287,7 +280,6 @@ class ToolParameter(PluginParameter): ] else: option_objs = [] - return cls( name=name, label=I18nObject(en_US="", zh_Hans=""), @@ -335,7 +327,6 @@ class ToolEntity(BaseModel): description: Optional[ToolDescription] = None output_schema: Optional[dict] = None has_runtime_parameters: bool = Field(default=False, description="Whether the tool has runtime parameters") - # pydantic configs model_config = ConfigDict(protected_namespaces=()) diff --git a/api/core/tools/entities/values.py b/api/core/tools/entities/values.py index f460df7e25..3832906713 100644 --- a/api/core/tools/entities/values.py +++ b/api/core/tools/entities/values.py @@ -51,7 +51,6 @@ ICONS = { """, # noqa: E501 } - default_tool_label_dict = { ToolLabelEnum.SEARCH: ToolLabel( name="search", label=I18nObject(en_US="Search", zh_Hans="搜索"), icon=ICONS[ToolLabelEnum.SEARCH] @@ -106,6 +105,5 @@ default_tool_label_dict = { name="other", label=I18nObject(en_US="Other", zh_Hans="其他"), icon=ICONS[ToolLabelEnum.OTHER] ), } - default_tool_labels = [v for k, v in default_tool_label_dict.items()] default_tool_label_name_list = [label.name for label in default_tool_labels] diff --git a/api/core/tools/plugin_tool/provider.py b/api/core/tools/plugin_tool/provider.py index 494b8e209c..c60af7ee5d 100644 --- a/api/core/tools/plugin_tool/provider.py +++ b/api/core/tools/plugin_tool/provider.py @@ -26,7 +26,6 @@ class PluginToolProviderController(BuiltinToolProviderController): def provider_type(self) -> ToolProviderType: """ returns the type of the provider - :return: type of the provider """ return ToolProviderType.PLUGIN @@ -51,10 +50,8 @@ class 
PluginToolProviderController(BuiltinToolProviderController): tool_entity = next( (tool_entity for tool_entity in self.entity.tools if tool_entity.identity.name == tool_name), None ) - if not tool_entity: raise ValueError(f"Tool with name {tool_name} not found") - return PluginTool( entity=tool_entity, runtime=ToolRuntime(tenant_id=self.tenant_id), diff --git a/api/core/tools/plugin_tool/tool.py b/api/core/tools/plugin_tool/tool.py index d21e3d7d1c..66a896ea5d 100644 --- a/api/core/tools/plugin_tool/tool.py +++ b/api/core/tools/plugin_tool/tool.py @@ -35,9 +35,7 @@ class PluginTool(Tool): message_id: Optional[str] = None, ) -> Generator[ToolInvokeMessage, None, None]: manager = PluginToolManager() - tool_parameters = convert_parameters_to_plugin_format(tool_parameters) - yield from manager.invoke( tenant_id=self.tenant_id, user_id=user_id, @@ -70,10 +68,8 @@ class PluginTool(Tool): """ if not self.entity.has_runtime_parameters: return self.entity.parameters - if self.runtime_parameters is not None: return self.runtime_parameters - manager = PluginToolManager() self.runtime_parameters = manager.get_runtime_parameters( tenant_id=self.tenant_id, @@ -85,5 +81,4 @@ class PluginTool(Tool): app_id=app_id, message_id=message_id, ) - return self.runtime_parameters diff --git a/api/core/tools/signature.py b/api/core/tools/signature.py index e80005d7bf..bb7631e55d 100644 --- a/api/core/tools/signature.py +++ b/api/core/tools/signature.py @@ -13,14 +13,12 @@ def sign_tool_file(tool_file_id: str, extension: str) -> str: """ base_url = dify_config.FILES_URL file_preview_url = f"{base_url}/files/tools/{tool_file_id}{extension}" - timestamp = str(int(time.time())) nonce = os.urandom(16).hex() data_to_sign = f"file-preview|{tool_file_id}|{timestamp}|{nonce}" secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b"" sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() encoded_sign = base64.urlsafe_b64encode(sign).decode() - return 
f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}" @@ -32,10 +30,8 @@ def verify_tool_file_signature(file_id: str, timestamp: str, nonce: str, sign: s secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b"" recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode() - # verify signature if sign != recalculated_encoded_sign: return False - current_time = int(time.time()) return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT diff --git a/api/core/tools/tool_engine.py b/api/core/tools/tool_engine.py index 178f2b9689..acdd25c5ad 100644 --- a/api/core/tools/tool_engine.py +++ b/api/core/tools/tool_engine.py @@ -75,11 +75,9 @@ class ToolEngine: pass if not isinstance(tool_parameters, dict): raise ValueError(f"tool_parameters should be a dict, but got a string: {tool_parameters}") - try: # hit the callback handler agent_tool_callback.on_tool_start(tool_name=tool.entity.identity.name, tool_inputs=tool_parameters) - messages = ToolEngine._invoke(tool, tool_parameters, user_id, conversation_id, app_id, message_id) invocation_meta_dict: dict[str, ToolInvokeMeta] = {} @@ -98,20 +96,15 @@ class ToolEngine: tenant_id=tenant_id, conversation_id=message.conversation_id, ) - message_list = list(messages) - # extract binary data from tool invoke message binary_files = ToolEngine._extract_tool_response_binary_and_text(message_list) # create message file message_files = ToolEngine._create_message_files( tool_messages=binary_files, agent_message=message, invoke_from=invoke_from, user_id=user_id ) - plain_text = ToolEngine._convert_tool_response_to_str(message_list) - meta = invocation_meta_dict["meta"] - # hit the callback handler agent_tool_callback.on_tool_end( tool_name=tool.entity.identity.name, @@ -120,7 +113,6 @@ class ToolEngine: message_id=message.id, trace_manager=trace_manager, ) - # transform tool 
invoke message to get LLM friendly message return plain_text, message_files, meta except ToolProviderCredentialValidationError as e: @@ -143,7 +135,6 @@ class ToolEngine: except Exception as e: error_response = f"unknown error: {e}" agent_tool_callback.on_tool_error(e) - return error_response, [], ToolInvokeMeta.error_instance(error_response) @staticmethod @@ -164,14 +155,11 @@ class ToolEngine: try: # hit the callback handler workflow_tool_callback.on_tool_start(tool_name=tool.entity.identity.name, tool_inputs=tool_parameters) - if isinstance(tool, WorkflowTool): tool.workflow_call_depth = workflow_call_depth + 1 tool.thread_pool_id = thread_pool_id - if tool.runtime and tool.runtime.runtime_parameters: tool_parameters = {**tool.runtime.runtime_parameters, **tool_parameters} - response = tool.invoke( user_id=user_id, tool_parameters=tool_parameters, @@ -179,14 +167,12 @@ class ToolEngine: app_id=app_id, message_id=message_id, ) - # hit the callback handler response = workflow_tool_callback.on_tool_execution( tool_name=tool.entity.identity.name, tool_inputs=tool_parameters, tool_outputs=response, ) - return response except Exception as e: workflow_tool_callback.on_tool_error(e) @@ -251,7 +237,6 @@ class ToolEngine: ) else: result += str(response.message) - return result @staticmethod @@ -277,10 +262,8 @@ class ToolEngine: mimetype = guess_type_result except Exception: pass - if not mimetype: mimetype = "image/jpeg" - yield ToolInvokeMessageBinary( mimetype=response.meta.get("mime_type", "image/jpeg"), url=cast(ToolInvokeMessage.TextMessage, response.message).text, @@ -288,7 +271,6 @@ class ToolEngine: elif response.type == ToolInvokeMessage.MessageType.BLOB: if not response.meta: raise ValueError("missing meta data") - yield ToolInvokeMessageBinary( mimetype=response.meta.get("mime_type", "application/octet-stream"), url=cast(ToolInvokeMessage.TextMessage, response.message).text, @@ -312,11 +294,9 @@ class ToolEngine: ) -> list[str]: """ Create message file - 
:return: message file ids """ result = [] - for message in tool_messages: if "image" in message.mimetype: file_type = FileType.IMAGE @@ -328,7 +308,6 @@ class ToolEngine: file_type = FileType.DOCUMENT else: file_type = FileType.CUSTOM - # extract tool file id from url tool_file_id = message.url.split("/")[-1].split(".")[0] message_file = MessageFile( @@ -345,13 +324,9 @@ class ToolEngine: ), created_by=user_id, ) - db.session.add(message_file) db.session.commit() db.session.refresh(message_file) - result.append(message_file.id) - db.session.close() - return result diff --git a/api/core/tools/tool_file_manager.py b/api/core/tools/tool_file_manager.py index b849f51064..ecb65f8cff 100644 --- a/api/core/tools/tool_file_manager.py +++ b/api/core/tools/tool_file_manager.py @@ -20,7 +20,6 @@ from models.model import MessageFile from models.tools import ToolFile logger = logging.getLogger(__name__) - from sqlalchemy.engine import Engine @@ -39,14 +38,12 @@ class ToolFileManager: """ base_url = dify_config.FILES_URL file_preview_url = f"{base_url}/files/tools/{tool_file_id}{extension}" - timestamp = str(int(time.time())) nonce = os.urandom(16).hex() data_to_sign = f"file-preview|{tool_file_id}|{timestamp}|{nonce}" secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b"" sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() encoded_sign = base64.urlsafe_b64encode(sign).decode() - return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}" @staticmethod @@ -58,11 +55,9 @@ class ToolFileManager: secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b"" recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode() - # verify signature if sign != recalculated_encoded_sign: return False - current_time = int(time.time()) return current_time - int(timestamp) <= 
dify_config.FILES_ACCESS_TIMEOUT @@ -87,7 +82,6 @@ class ToolFileManager: present_filename = filename if has_extension else f"{filename}{extension}" filepath = f"tools/{tenant_id}/{unique_filename}" storage.save(filepath, file_binary) - with Session(self._engine, expire_on_commit=False) as session: tool_file = ToolFile( user_id=user_id, @@ -98,11 +92,9 @@ class ToolFileManager: name=present_filename, size=len(file_binary), ) - session.add(tool_file) session.commit() session.refresh(tool_file) - return tool_file def create_file_by_url( @@ -119,7 +111,6 @@ class ToolFileManager: blob = response.content except httpx.TimeoutException: raise ValueError(f"timeout when downloading file from {file_url}") - mimetype = ( guess_type(file_url)[0] or response.headers.get("Content-Type", "").split(";")[0].strip() @@ -130,7 +121,6 @@ class ToolFileManager: filename = f"{unique_name}{extension}" filepath = f"tools/{tenant_id}/{filename}" storage.save(filepath, blob) - with Session(self._engine, expire_on_commit=False) as session: tool_file = ToolFile( user_id=user_id, @@ -142,18 +132,14 @@ class ToolFileManager: name=filename, size=len(blob), ) - session.add(tool_file) session.commit() - return tool_file def get_file_binary(self, id: str) -> Union[tuple[bytes, str], None]: """ get file binary - :param id: the id of the file - :return: the binary of the file, mime type """ with Session(self._engine, expire_on_commit=False) as session: @@ -164,20 +150,15 @@ class ToolFileManager: ) .first() ) - if not tool_file: return None - blob = storage.load_once(tool_file.file_key) - return blob, tool_file.mimetype def get_file_binary_by_message_file_id(self, id: str) -> Union[tuple[bytes, str], None]: """ get file binary - :param id: the id of the file - :return: the binary of the file, mime type """ with Session(self._engine, expire_on_commit=False) as session: @@ -188,7 +169,6 @@ class ToolFileManager: ) .first() ) - # Check if message_file is not None if message_file is not None: # get tool 
file id @@ -200,7 +180,6 @@ class ToolFileManager: tool_file_id = None else: tool_file_id = None - tool_file: ToolFile | None = ( session.query(ToolFile) .filter( @@ -208,20 +187,15 @@ class ToolFileManager: ) .first() ) - if not tool_file: return None - blob = storage.load_once(tool_file.file_key) - return blob, tool_file.mimetype def get_file_generator_by_tool_file_id(self, tool_file_id: str) -> tuple[Optional[Generator], Optional[ToolFile]]: """ get file binary - :param tool_file_id: the id of the tool file - :return: the binary of the file, mime type """ with Session(self._engine, expire_on_commit=False) as session: @@ -232,12 +206,9 @@ class ToolFileManager: ) .first() ) - if not tool_file: return None, None - stream = storage.load_stream(tool_file.file_key) - return stream, tool_file diff --git a/api/core/tools/tool_label_manager.py b/api/core/tools/tool_label_manager.py index 4787d7d79c..798deb1a68 100644 --- a/api/core/tools/tool_label_manager.py +++ b/api/core/tools/tool_label_manager.py @@ -22,15 +22,12 @@ class ToolLabelManager: Update tool labels """ labels = cls.filter_tool_labels(labels) - if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController): provider_id = controller.provider_id else: raise ValueError("Unsupported tool type") - # delete old labels db.session.query(ToolLabelBinding).filter(ToolLabelBinding.tool_id == provider_id).delete() - # insert new labels for label in labels: db.session.add( @@ -40,7 +37,6 @@ class ToolLabelManager: label_name=label, ) ) - db.session.commit() @classmethod @@ -54,7 +50,6 @@ class ToolLabelManager: return controller.tool_labels else: raise ValueError("Unsupported tool type") - labels = ( db.session.query(ToolLabelBinding.label_name) .filter( @@ -63,39 +58,30 @@ class ToolLabelManager: ) .all() ) - return [label.label_name for label in labels] @classmethod def get_tools_labels(cls, tool_providers: list[ToolProviderController]) -> dict[str, list[str]]: """ Get tools labels - :param 
tool_providers: list of tool providers - :return: dict of tool labels :key: tool id :value: list of tool labels """ if not tool_providers: return {} - for controller in tool_providers: if not isinstance(controller, ApiToolProviderController | WorkflowToolProviderController): raise ValueError("Unsupported tool type") - provider_ids = [] for controller in tool_providers: assert isinstance(controller, ApiToolProviderController | WorkflowToolProviderController) provider_ids.append(controller.provider_id) - labels: list[ToolLabelBinding] = ( db.session.query(ToolLabelBinding).filter(ToolLabelBinding.tool_id.in_(provider_ids)).all() ) - tool_labels: dict[str, list[str]] = {label.tool_id: [] for label in labels} - for label in labels: tool_labels[label.tool_id].append(label.label_name) - return tool_labels diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 0bfe6329b1..9adef90165 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -19,8 +19,6 @@ from core.tools.workflow_as_tool.provider import WorkflowToolProviderController if TYPE_CHECKING: from core.workflow.nodes.tool.entities import ToolEntity - - from configs import dify_config from core.agent.entities import AgentToolEntity from core.app.entities.app_invoke_entities import InvokeFrom @@ -69,7 +67,6 @@ class ToolManager: if len(cls._hardcoded_providers) == 0: # init the builtin providers cls.load_hardcoded_providers_cache() - return cls._hardcoded_providers[provider] @classmethod @@ -78,23 +75,19 @@ class ToolManager: ) -> BuiltinToolProviderController | PluginToolProviderController: """ get the builtin provider - :param provider: the name of the provider :param tenant_id: the id of the tenant :return: the provider """ # split provider to - if len(cls._hardcoded_providers) == 0: # init the builtin providers cls.load_hardcoded_providers_cache() - if provider not in cls._hardcoded_providers: # get plugin provider plugin_provider = 
cls.get_plugin_provider(provider, tenant_id) if plugin_provider: return plugin_provider - return cls._hardcoded_providers[provider] @classmethod @@ -108,33 +101,27 @@ class ToolManager: except LookupError: contexts.plugin_tool_providers.set({}) contexts.plugin_tool_providers_lock.set(Lock()) - with contexts.plugin_tool_providers_lock.get(): plugin_tool_providers = contexts.plugin_tool_providers.get() if provider in plugin_tool_providers: return plugin_tool_providers[provider] - manager = PluginToolManager() provider_entity = manager.fetch_tool_provider(tenant_id, provider) if not provider_entity: raise ToolProviderNotFoundError(f"plugin provider {provider} not found") - controller = PluginToolProviderController( entity=provider_entity.declaration, plugin_id=provider_entity.plugin_id, plugin_unique_identifier=provider_entity.plugin_unique_identifier, tenant_id=tenant_id, ) - plugin_tool_providers[provider] = controller - return controller @classmethod def get_builtin_tool(cls, provider: str, tool_name: str, tenant_id: str) -> BuiltinTool | PluginTool | None: """ get the builtin tool - :param provider: the name of the provider :param tool_name: the name of the tool :param tenant_id: the id of the tenant @@ -144,7 +131,6 @@ class ToolManager: tool = provider_controller.get_tool(tool_name) if tool is None: raise ToolNotFoundError(f"tool {tool_name} not found") - return tool @classmethod @@ -159,24 +145,20 @@ class ToolManager: ) -> Union[BuiltinTool, PluginTool, ApiTool, WorkflowTool]: """ get the tool runtime - :param provider_type: the type of the provider :param provider_id: the id of the provider :param tool_name: the name of the tool :param tenant_id: the tenant id :param invoke_from: invoke from :param tool_invoke_from: the tool invoke from - :return: the tool """ if provider_type == ToolProviderType.BUILT_IN: # check if the builtin tool need credentials provider_controller = cls.get_builtin_provider(provider_id, tenant_id) - builtin_tool = 
provider_controller.get_tool(tool_name) if not builtin_tool: raise ToolProviderNotFoundError(f"builtin tool {tool_name} not found") - if not provider_controller.need_credentials: return cast( BuiltinTool, @@ -189,7 +171,6 @@ class ToolManager: ) ), ) - if isinstance(provider_controller, PluginToolProviderController): provider_id_entity = ToolProviderID(provider_id) # get credentials @@ -202,7 +183,6 @@ class ToolManager: ) .first() ) - if builtin_provider is None: raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found") else: @@ -211,10 +191,8 @@ class ToolManager: .filter(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id)) .first() ) - if builtin_provider is None: raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found") - # decrypt the credentials credentials = builtin_provider.credentials tool_configuration = ProviderConfigEncrypter( @@ -223,9 +201,7 @@ class ToolManager: provider_type=provider_controller.provider_type.value, provider_identity=provider_controller.entity.identity.name, ) - decrypted_credentials = tool_configuration.decrypt(credentials) - return cast( BuiltinTool, builtin_tool.fork_tool_runtime( @@ -238,10 +214,8 @@ class ToolManager: ) ), ) - elif provider_type == ToolProviderType.API: api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id) - # decrypt the credentials tool_configuration = ProviderConfigEncrypter( tenant_id=tenant_id, @@ -250,7 +224,6 @@ class ToolManager: provider_identity=api_provider.entity.identity.name, ) decrypted_credentials = tool_configuration.decrypt(credentials) - return cast( ApiTool, api_provider.get_tool(tool_name).fork_tool_runtime( @@ -268,15 +241,12 @@ class ToolManager: .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id) .first() ) - if workflow_provider is None: raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found") - controller = 
ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider) controller_tools: list[WorkflowTool] = controller.get_tools(tenant_id=workflow_provider.tenant_id) if controller_tools is None or len(controller_tools) == 0: raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found") - return cast( WorkflowTool, controller.get_tools(tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime( @@ -328,12 +298,10 @@ class ToolManager: and parameter.required ): raise ValueError(f"file type parameter {parameter.name} not supported in agent") - if parameter.form == ToolParameter.ToolParameterForm.FORM: # save tool parameter to tool entity memory value = parameter.init_frontend_parameter(agent_tool.tool_parameters.get(parameter.name)) runtime_parameters[parameter.name] = value - # decrypt runtime parameters encryption_manager = ToolParameterConfigurationManager( tenant_id=tenant_id, @@ -345,7 +313,6 @@ class ToolManager: runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters) if tool_entity.runtime is None or tool_entity.runtime.runtime_parameters is None: raise ValueError("runtime not found or runtime parameters not found") - tool_entity.runtime.runtime_parameters.update(runtime_parameters) return tool_entity @@ -371,13 +338,11 @@ class ToolManager: ) runtime_parameters = {} parameters = tool_runtime.get_merged_runtime_parameters() - for parameter in parameters: # save tool parameter to tool entity memory if parameter.form == ToolParameter.ToolParameterForm.FORM: value = parameter.init_frontend_parameter(workflow_tool.tool_configurations.get(parameter.name)) runtime_parameters[parameter.name] = value - # decrypt runtime parameters encryption_manager = ToolParameterConfigurationManager( tenant_id=tenant_id, @@ -386,10 +351,8 @@ class ToolManager: provider_type=workflow_tool.provider_type, identity_id=f"WORKFLOW.{app_id}.{node_id}", ) - if runtime_parameters: runtime_parameters = 
encryption_manager.decrypt_tool_parameters(runtime_parameters) - tool_runtime.runtime.runtime_parameters.update(runtime_parameters) return tool_runtime @@ -420,7 +383,6 @@ class ToolManager: # save tool parameter to tool entity memory value = parameter.init_frontend_parameter(tool_parameters.get(parameter.name)) runtime_parameters[parameter.name] = value - tool_entity.runtime.runtime_parameters.update(runtime_parameters) return tool_entity @@ -428,13 +390,11 @@ class ToolManager: def get_hardcoded_provider_icon(cls, provider: str) -> tuple[str, str]: """ get the absolute path of the icon of the hardcoded provider - :param provider: the name of the provider :return: the absolute path of the icon, the mime type of the icon """ # get provider provider_controller = cls.get_hardcoded_provider(provider) - absolute_path = path.join( path.dirname(path.realpath(__file__)), "builtin_tool", @@ -446,11 +406,9 @@ class ToolManager: # check if the icon exists if not path.exists(absolute_path): raise ToolProviderNotFoundError(f"builtin provider {provider} icon not found") - # get the mime type mime_type, _ = mimetypes.guess_type(absolute_path) mime_type = mime_type or "application/octet-stream" - return absolute_path, mime_type @classmethod @@ -459,12 +417,10 @@ class ToolManager: if cls._builtin_providers_loaded: yield from list(cls._hardcoded_providers.values()) return - with cls._builtin_provider_lock: if cls._builtin_providers_loaded: yield from list(cls._hardcoded_providers.values()) return - yield from cls._list_hardcoded_providers() @classmethod @@ -503,11 +459,9 @@ class ToolManager: for provider_path in listdir(path.join(path.dirname(path.realpath(__file__)), "builtin_tool", "providers")): if provider_path.startswith("__"): continue - if path.isdir(path.join(path.dirname(path.realpath(__file__)), "builtin_tool", "providers", provider_path)): if provider_path.startswith("__"): continue - # init provider try: provider_class = load_single_subclass_from_source( @@ -526,7 
+480,6 @@ class ToolManager: for tool in provider.get_tools(): cls._builtin_tools_labels[tool.entity.identity.name] = tool.entity.identity.label yield provider - except Exception: logger.exception(f"load builtin provider {provider_path}") continue @@ -547,18 +500,14 @@ class ToolManager: def get_tool_label(cls, tool_name: str) -> Union[I18nObject, None]: """ get the tool label - :param tool_name: the name of the tool - :return: the label of the tool """ if len(cls._builtin_tools_labels) == 0: # init the builtin providers cls.load_hardcoded_providers_cache() - if tool_name not in cls._builtin_tools_labels: return None - return cls._builtin_tools_labels[tool_name] @classmethod @@ -566,23 +515,19 @@ class ToolManager: cls, user_id: str, tenant_id: str, typ: ToolProviderTypeApiLiteral ) -> list[ToolProviderApiEntity]: result_providers: dict[str, ToolProviderApiEntity] = {} - filters = [] if not typ: filters.extend(["builtin", "api", "workflow"]) else: filters.append(typ) - with db.session.no_autoflush: if "builtin" in filters: # get builtin providers builtin_providers = cls.list_builtin_providers(tenant_id) - # get db builtin providers db_builtin_providers: list[BuiltinToolProvider] = ( db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all() ) - # rewrite db_builtin_providers for db_provider in db_builtin_providers: tool_provider_id = str(ToolProviderID(db_provider.provider)) @@ -601,33 +546,26 @@ class ToolManager: name_func=lambda x: x.identity.name, ): continue - user_provider = ToolTransformService.builtin_provider_to_user_provider( provider_controller=provider, db_provider=find_db_builtin_provider(provider.entity.identity.name), decrypt_credentials=False, ) - if isinstance(provider, PluginToolProviderController): result_providers[f"plugin_provider.{user_provider.name}"] = user_provider else: result_providers[f"builtin_provider.{user_provider.name}"] = user_provider - # get db api providers - if "api" in filters: 
db_api_providers: list[ApiToolProvider] = ( db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all() ) - api_provider_controllers: list[dict[str, Any]] = [ {"provider": provider, "controller": ToolTransformService.api_provider_to_controller(provider)} for provider in db_api_providers ] - # get labels labels = ToolLabelManager.get_tools_labels([x["controller"] for x in api_provider_controllers]) - for api_provider_controller in api_provider_controllers: user_provider = ToolTransformService.api_provider_to_user_provider( provider_controller=api_provider_controller["controller"], @@ -636,13 +574,11 @@ class ToolManager: labels=labels.get(api_provider_controller["controller"].provider_id, []), ) result_providers[f"api_provider.{user_provider.name}"] = user_provider - if "workflow" in filters: # get workflow providers workflow_providers: list[WorkflowToolProvider] = ( db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all() ) - workflow_provider_controllers: list[WorkflowToolProviderController] = [] for workflow_provider in workflow_providers: try: @@ -652,18 +588,15 @@ class ToolManager: except Exception: # app has been deleted pass - labels = ToolLabelManager.get_tools_labels( [cast(ToolProviderController, controller) for controller in workflow_provider_controllers] ) - for provider_controller in workflow_provider_controllers: user_provider = ToolTransformService.workflow_provider_to_user_provider( provider_controller=provider_controller, labels=labels.get(provider_controller.provider_id, []), ) result_providers[f"workflow_provider.{user_provider.name}"] = user_provider - return BuiltinToolProviderSort.sort(list(result_providers.values())) @classmethod @@ -672,10 +605,8 @@ class ToolManager: ) -> tuple[ApiToolProviderController, dict[str, Any]]: """ get the api provider - :param tenant_id: the id of the tenant :param provider_id: the id of the provider - :return: the provider controller, the 
credentials """ provider: ApiToolProvider | None = ( @@ -686,16 +617,13 @@ class ToolManager: ) .first() ) - if provider is None: raise ToolProviderNotFoundError(f"api provider {provider_id} not found") - controller = ApiToolProviderController.from_db( provider, ApiProviderAuthType.API_KEY if provider.credentials["auth_type"] == "api_key" else ApiProviderAuthType.NONE, ) controller.load_bundled_tools(provider.tools) - return controller, provider.credentials @classmethod @@ -715,15 +643,12 @@ class ToolManager: ) .first() ) - if provider_obj is None: raise ValueError(f"you have not added provider {provider_name}") - try: credentials = json.loads(provider_obj.credentials_str) or {} except Exception: credentials = {} - # package tool provider controller controller = ApiToolProviderController.from_db( provider_obj, @@ -736,18 +661,14 @@ class ToolManager: provider_type=controller.provider_type.value, provider_identity=controller.entity.identity.name, ) - decrypted_credentials = tool_configuration.decrypt(credentials) masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials) - try: icon = json.loads(provider_obj.icon) except Exception: icon = {"background": "#252525", "content": "\ud83d\ude01"} - # add tool labels labels = ToolLabelManager.get_tool_labels(controller) - return cast( dict, jsonable_encoder( @@ -800,10 +721,8 @@ class ToolManager: .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id) .first() ) - if workflow_provider is None: raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found") - icon: dict = json.loads(workflow_provider.icon) return icon except Exception: @@ -817,10 +736,8 @@ class ToolManager: .filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.id == provider_id) .first() ) - if api_provider is None: raise ToolProviderNotFoundError(f"api provider {provider_id} not found") - icon: dict = json.loads(api_provider.icon) return icon except Exception: @@ 
-835,7 +752,6 @@ class ToolManager: ) -> Union[str, dict]: """ get the tool icon - :param tenant_id: the id of the tenant :param provider_type: the type of the provider :param provider_id: the id of the provider diff --git a/api/core/tools/utils/configuration.py b/api/core/tools/utils/configuration.py index 1f23e90351..3752f29402 100644 --- a/api/core/tools/utils/configuration.py +++ b/api/core/tools/utils/configuration.py @@ -29,37 +29,30 @@ class ProviderConfigEncrypter(BaseModel): def encrypt(self, data: dict[str, str]) -> dict[str, str]: """ encrypt tool credentials with tenant id - return a deep copy of credentials with encrypted values """ data = self._deep_copy(data) - # get fields need to be decrypted fields = dict[str, BasicProviderConfig]() for credential in self.config: fields[credential.name] = credential - for field_name, field in fields.items(): if field.type == BasicProviderConfig.Type.SECRET_INPUT: if field_name in data: encrypted = encrypter.encrypt_token(self.tenant_id, data[field_name] or "") data[field_name] = encrypted - return data def mask_tool_credentials(self, data: dict[str, Any]) -> dict[str, Any]: """ mask tool credentials - return a deep copy of credentials with masked values """ data = self._deep_copy(data) - # get fields need to be decrypted fields = dict[str, BasicProviderConfig]() for credential in self.config: fields[credential.name] = credential - for field_name, field in fields.items(): if field.type == BasicProviderConfig.Type.SECRET_INPUT: if field_name in data: @@ -69,13 +62,11 @@ class ProviderConfigEncrypter(BaseModel): ) else: data[field_name] = "*" * len(data[field_name]) - return data def decrypt(self, data: dict[str, str]) -> dict[str, str]: """ decrypt tool credentials with tenant id - return a deep copy of credentials with decrypted values """ cache = ToolProviderCredentialsCache( @@ -86,13 +77,11 @@ class ProviderConfigEncrypter(BaseModel): cached_credentials = cache.get() if cached_credentials: return 
cached_credentials - data = self._deep_copy(data) # get fields need to be decrypted fields = dict[str, BasicProviderConfig]() for credential in self.config: fields[credential.name] = credential - for field_name, field in fields.items(): if field.type == BasicProviderConfig.Type.SECRET_INPUT: if field_name in data: @@ -100,11 +89,9 @@ class ProviderConfigEncrypter(BaseModel): # if the value is None or empty string, skip decrypt if not data[field_name]: continue - data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name]) except Exception: pass - cache.set(data) return data @@ -160,23 +147,18 @@ class ToolParameterConfigurationManager: current_parameters[index] = runtime_parameter found = True break - if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM: current_parameters.append(runtime_parameter) - return current_parameters def mask_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]: """ mask tool parameters - return a deep copy of parameters with masked values """ parameters = self._deep_copy(parameters) - # override parameters current_parameters = self._merge_parameters() - for parameter in current_parameters: if ( parameter.form == ToolParameter.ToolParameterForm.FORM @@ -191,20 +173,16 @@ class ToolParameterConfigurationManager: ) else: parameters[parameter.name] = "*" * len(parameters[parameter.name]) - return parameters def encrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]: """ encrypt tool parameters with tenant id - return a deep copy of parameters with encrypted values """ # override parameters current_parameters = self._merge_parameters() - parameters = self._deep_copy(parameters) - for parameter in current_parameters: if ( parameter.form == ToolParameter.ToolParameterForm.FORM @@ -213,16 +191,13 @@ class ToolParameterConfigurationManager: if parameter.name in parameters: encrypted = encrypter.encrypt_token(self.tenant_id, parameters[parameter.name]) 
parameters[parameter.name] = encrypted - return parameters def decrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]: """ decrypt tool parameters with tenant id - return a deep copy of parameters with decrypted values """ - cache = ToolParameterCache( tenant_id=self.tenant_id, provider=f"{self.provider_type.value}.{self.provider_name}", @@ -233,11 +208,9 @@ class ToolParameterConfigurationManager: cached_parameters = cache.get() if cached_parameters: return cached_parameters - # override parameters current_parameters = self._merge_parameters() has_secret_input = False - for parameter in current_parameters: if ( parameter.form == ToolParameter.ToolParameterForm.FORM @@ -249,10 +222,8 @@ class ToolParameterConfigurationManager: parameters[parameter.name] = encrypter.decrypt_token(self.tenant_id, parameters[parameter.name]) except Exception: pass - if has_secret_input: cache.set(parameters) - return parameters def delete_tool_parameters_cache(self): diff --git a/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py index 2cbc4b9821..093b4dde20 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py @@ -71,18 +71,14 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): model_type=ModelType.RERANK, model=self.reranking_model_name, ) - rerank_runner = RerankModelRunner(rerank_model_instance) all_documents = rerank_runner.run(query, all_documents, self.score_threshold, self.top_k) - for hit_callback in self.hit_callbacks: hit_callback.on_tool_end(all_documents) - document_score_list = {} for item in all_documents: if item.metadata and item.metadata.get("score"): document_score_list[item.metadata["doc_id"]] = item.metadata["score"] - document_context_list = [] index_node_ids = [document.metadata["doc_id"] for document in all_documents if 
document.metadata] segments = ( @@ -96,7 +92,6 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): ) .all() ) - if segments: index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)} sorted_segments = sorted( @@ -134,7 +129,6 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): score=document_score_list.get(segment.index_node_id, None), doc_metadata=document.doc_metadata, ) - if self.retriever_from == "dev": source.hit_count = segment.hit_count source.word_count = segment.word_count @@ -146,10 +140,8 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): source.content = segment.content context_list.append(source) resource_number += 1 - for hit_callback in self.hit_callbacks: hit_callback.return_retriever_resource_info(context_list) - return str("\n".join(document_context_list)) return "" @@ -165,16 +157,12 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): dataset = ( db.session.query(Dataset).filter(Dataset.tenant_id == self.tenant_id, Dataset.id == dataset_id).first() ) - if not dataset: return [] - for hit_callback in hit_callbacks: hit_callback.on_query(query, dataset.id) - # get retrieval model , if the model is not setting , using default retrieval_model = dataset.retrieval_model or default_retrieval_model - if dataset.indexing_technique == "economy": # use keyword table query documents = RetrievalService.retrieve( @@ -202,5 +190,4 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model", weights=retrieval_model.get("weights", None), ) - all_documents.extend(documents) diff --git a/api/core/tools/utils/dataset_retriever/dataset_retriever_base_tool.py b/api/core/tools/utils/dataset_retriever/dataset_retriever_base_tool.py index a4d2de3b1c..92c7b3705f 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_retriever_base_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_retriever_base_tool.py @@ -27,7 
+27,6 @@ class DatasetRetrieverBaseTool(BaseModel, ABC): **kwargs: Any, ) -> Any: """Use the tool. - Add run_manager: Optional[CallbackManagerForToolRun] = None to child implementations to enable tracing, """ diff --git a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py index ff1d9021ce..351206e3d6 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py @@ -45,7 +45,6 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): description = dataset.description if not description: description = "useful for when you want to answer queries about the " + dataset.name - description = description.replace("\n", "").replace("\r", "") return cls( name=f"dataset_{dataset.id.replace('-', '_')}", @@ -59,7 +58,6 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): dataset = ( db.session.query(Dataset).filter(Dataset.tenant_id == self.tenant_id, Dataset.id == self.dataset_id).first() ) - if not dataset: return "" for hit_callback in self.hit_callbacks: @@ -119,7 +117,6 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): context_list.append(source) for hit_callback in self.hit_callbacks: hit_callback.return_retriever_resource_info(context_list) - return str("\n".join([item.page_content for item in results])) else: if metadata_condition and not document_ids_filter: @@ -183,7 +180,6 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): score=record.score, ) ) - if self.return_resource: for record in records: segment = record.segment @@ -209,7 +205,6 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): score=record.score or 0.0, doc_metadata=document.doc_metadata, # type: ignore ) - if self.retriever_from == "dev": source.hit_count = segment.hit_count source.word_count = segment.word_count @@ -220,7 +215,6 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): else: source.content = 
segment.content retrieval_resource_list.append(source) - if self.return_resource and retrieval_resource_list: retrieval_resource_list = sorted( retrieval_resource_list, diff --git a/api/core/tools/utils/dataset_retriever_tool.py b/api/core/tools/utils/dataset_retriever_tool.py index ec0575f6c3..637ec9810d 100644 --- a/api/core/tools/utils/dataset_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever_tool.py @@ -45,9 +45,7 @@ class DatasetRetrieverTool(Tool): return [] if retrieve_config is None: return [] - feature = DatasetRetrieval() - # save original retrieve strategy, and set retrieve strategy to SINGLE # Agent only support SINGLE mode original_retriever_mode = retrieve_config.retrieve_strategy @@ -64,10 +62,8 @@ class DatasetRetrieverTool(Tool): ) if retrieval_tools is None or len(retrieval_tools) == 0: return [] - # restore retrieve strategy retrieve_config.retrieve_strategy = original_retriever_mode - # convert retrieval tools to Tools tools = [] for retrieval_tool in retrieval_tools: @@ -82,9 +78,7 @@ class DatasetRetrieverTool(Tool): ), runtime=ToolRuntime(tenant_id=tenant_id), ) - tools.append(tool) - return tools def get_runtime_parameters( diff --git a/api/core/tools/utils/message_transformer.py b/api/core/tools/utils/message_transformer.py index 9998de0465..7df6ae125b 100644 --- a/api/core/tools/utils/message_transformer.py +++ b/api/core/tools/utils/message_transformer.py @@ -38,9 +38,7 @@ class ToolFileMessageTransformer: file_url=message.message.text, conversation_id=conversation_id, ) - url = f"/files/tools/{tool_file.id}{guess_extension(tool_file.mimetype) or '.png'}" - yield ToolInvokeMessage( type=ToolInvokeMessage.MessageType.IMAGE_LINK, message=ToolInvokeMessage.TextMessage(text=url), @@ -57,15 +55,12 @@ class ToolFileMessageTransformer: elif message.type == ToolInvokeMessage.MessageType.BLOB: # get mime type and save blob to storage meta = message.meta or {} - mimetype = meta.get("mime_type", "application/octet-stream") # get 
filename from meta filename = meta.get("filename", None) # if message is str, encode it to bytes - if not isinstance(message.message, ToolInvokeMessage.BlobMessage): raise ValueError("unexpected message type") - assert isinstance(message.message.blob, bytes) tool_file_manager = ToolFileManager() tool_file = tool_file_manager.create_file_by_raw( @@ -76,9 +71,7 @@ class ToolFileMessageTransformer: mimetype=mimetype, filename=filename, ) - url = cls.get_tool_file_url(tool_file_id=tool_file.id, extension=guess_extension(tool_file.mimetype)) - # check if file is image if "image" in mimetype: yield ToolInvokeMessage( diff --git a/api/core/tools/utils/model_invocation_utils.py b/api/core/tools/utils/model_invocation_utils.py index 3f59b3f472..4d9636d83e 100644 --- a/api/core/tools/utils/model_invocation_utils.py +++ b/api/core/tools/utils/model_invocation_utils.py @@ -1,6 +1,5 @@ """ For some reason, model will be used in tools like WebScraperTool, WikipediaSearchTool etc. - Therefore, a model manager is needed to list/invoke/validate models. 
""" @@ -41,20 +40,15 @@ class ModelInvocationUtils: tenant_id=tenant_id, model_type=ModelType.LLM, ) - if not model_instance: raise InvokeModelError("Model not found") - llm_model = cast(LargeLanguageModel, model_instance.model_type_instance) schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials) - if not schema: raise InvokeModelError("No model schema found") - max_tokens: Optional[int] = schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE, None) if max_tokens is None: return 2048 - return max_tokens @staticmethod @@ -62,17 +56,13 @@ class ModelInvocationUtils: """ calculate tokens from prompt messages and model parameters """ - # get model instance model_manager = ModelManager() model_instance = model_manager.get_default_model_instance(tenant_id=tenant_id, model_type=ModelType.LLM) - if not model_instance: raise InvokeModelError("Model not found") - # get tokens tokens = model_instance.get_llm_num_tokens(prompt_messages) - return tokens @staticmethod @@ -81,7 +71,6 @@ class ModelInvocationUtils: ) -> LLMResult: """ invoke model with parameters in user's own context - :param user_id: user id :param tenant_id: tenant id, the tenant id of the creator of the tool :param tool_type: tool type @@ -89,7 +78,6 @@ class ModelInvocationUtils: :param prompt_messages: prompt messages :return: AssistantPromptMessage """ - # get model manager model_manager = ModelManager() # get model instance @@ -97,15 +85,12 @@ class ModelInvocationUtils: tenant_id=tenant_id, model_type=ModelType.LLM, ) - # get prompt tokens prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages) - model_parameters = { "temperature": 0.8, "top_p": 0.8, } - # create tool model invoke tool_model_invoke = ToolModelInvoke( user_id=user_id, @@ -124,10 +109,8 @@ class ModelInvocationUtils: total_price=0, currency="USD", ) - db.session.add(tool_model_invoke) db.session.commit() - try: response: LLMResult = cast( LLMResult, @@ -153,7 +136,6 @@ class 
ModelInvocationUtils: raise InvokeModelError(f"Invoke server unavailable error: {e}") except Exception as e: raise InvokeModelError(f"Invoke error: {e}") - # update tool model invoke tool_model_invoke.model_response = response.message.content if response.usage: @@ -163,7 +145,5 @@ class ModelInvocationUtils: tool_model_invoke.provider_response_latency = response.usage.latency tool_model_invoke.total_price = response.usage.total_price tool_model_invoke.currency = response.usage.currency - db.session.commit() - return response diff --git a/api/core/tools/utils/parser.py b/api/core/tools/utils/parser.py index 3f844e8234..c0b1fd14e3 100644 --- a/api/core/tools/utils/parser.py +++ b/api/core/tools/utils/parser.py @@ -22,19 +22,15 @@ class ApiBasedToolSchemaParser: ) -> list[ApiToolBundle]: warning = warning if warning is not None else {} extra_info = extra_info if extra_info is not None else {} - # set description to extra_info extra_info["description"] = openapi["info"].get("description", "") - if len(openapi["servers"]) == 0: raise ToolProviderNotFoundError("No server found in the openapi yaml.") - server_url = openapi["servers"][0]["url"] request_env = request.headers.get("X-Request-Env") if request_env: matched_servers = [server["url"] for server in openapi["servers"] if server["env"] == request_env] server_url = matched_servers[0] if matched_servers else server_url - # list all interfaces interfaces = [] for path, path_item in openapi["paths"].items(): @@ -48,7 +44,6 @@ class ApiBasedToolSchemaParser: "operation": path_item[method], } ) - # get all parameters bundles = [] for interface in interfaces: @@ -80,12 +75,10 @@ class ApiBasedToolSchemaParser: en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "") ), ) - # check if there is a type typ = ApiBasedToolSchemaParser._get_tool_parameter_type(parameter) if typ: tool_parameter.type = typ - parameters.append(tool_parameter) # create tool bundle # check if there is a request body @@ -96,7 
+89,6 @@ class ApiBasedToolSchemaParser: # if there is a reference, get the reference and overwrite the content if "schema" not in content: continue - if "$ref" in content["schema"]: # get the reference root = openapi @@ -105,7 +97,6 @@ class ApiBasedToolSchemaParser: root = root[ref] # overwrite the content interface["operation"]["requestBody"]["content"][content_type]["schema"] = root - # parse body parameters if "schema" in interface["operation"]["requestBody"]["content"][content_type]: body_schema = interface["operation"]["requestBody"]["content"][content_type]["schema"] @@ -127,14 +118,11 @@ class ApiBasedToolSchemaParser: en_US=property.get("description", ""), zh_Hans=property.get("description", "") ), ) - # check if there is a type typ = ApiBasedToolSchemaParser._get_tool_parameter_type(property) if typ: tool.type = typ - parameters.append(tool) - # check if parameters is duplicated parameters_count = {} for parameter in parameters: @@ -144,7 +132,6 @@ class ApiBasedToolSchemaParser: for name, count in parameters_count.items(): if count > 1: warning["duplicated_parameter"] = f"Parameter {name} is duplicated." 
- # check if there is a operation id, use $path_$method as operation id if not if "operationId" not in interface["operation"]: # remove special characters like / to ensure the operation id is valid ^[a-zA-Z0-9_-]{1,64}$ @@ -155,9 +142,7 @@ class ApiBasedToolSchemaParser: path = re.sub(r"[^a-zA-Z0-9_-]", "", path) if not path: path = str(uuid.uuid4()) - interface["operation"]["operationId"] = f"{path}_{interface['method']}" - bundles.append( ApiToolBundle( server_url=server_url + interface["path"], @@ -172,7 +157,6 @@ class ApiBasedToolSchemaParser: openapi=interface["operation"], ) ) - return bundles @staticmethod @@ -181,12 +165,10 @@ class ApiBasedToolSchemaParser: typ: Optional[str] = None if parameter.get("format") == "binary": return ToolParameter.ToolParameterType.FILE - if "type" in parameter: typ = parameter["type"] elif "schema" in parameter and "type" in parameter["schema"]: typ = parameter["schema"]["type"] - if typ in {"integer", "number"}: return ToolParameter.ToolParameterType.NUMBER elif typ == "boolean": @@ -205,7 +187,6 @@ class ApiBasedToolSchemaParser: ) -> list[ApiToolBundle]: """ parse openapi yaml to tool bundle - :param yaml: the yaml string :param extra_info: the extra info :param warning: the warning message @@ -213,7 +194,6 @@ class ApiBasedToolSchemaParser: """ warning = warning if warning is not None else {} extra_info = extra_info if extra_info is not None else {} - openapi: dict = safe_load(yaml) if openapi is None: raise ToolApiSchemaError("Invalid openapi yaml.") @@ -224,18 +204,14 @@ class ApiBasedToolSchemaParser: warning = warning or {} """ parse swagger to openapi - :param swagger: the swagger dict :return: the openapi dict """ # convert swagger to openapi info = swagger.get("info", {"title": "Swagger", "description": "Swagger", "version": "1.0.0"}) - servers = swagger.get("servers", []) - if len(servers) == 0: raise ToolApiSchemaError("No server found in the swagger yaml.") - openapi = { "openapi": "3.0.0", "info": { @@ -247,24 
+223,20 @@ class ApiBasedToolSchemaParser: "paths": {}, "components": {"schemas": {}}, } - # check paths if "paths" not in swagger or len(swagger["paths"]) == 0: raise ToolApiSchemaError("No paths found in the swagger yaml.") - # convert paths for path, path_item in swagger["paths"].items(): openapi["paths"][path] = {} for method, operation in path_item.items(): if "operationId" not in operation: raise ToolApiSchemaError(f"No operationId found in operation {method} {path}.") - if ("summary" not in operation or len(operation["summary"]) == 0) and ( "description" not in operation or len(operation["description"]) == 0 ): if warning is not None: warning["missing_summary"] = f"No summary or description found in operation {method} {path}." - openapi["paths"][path][method] = { "operationId": operation["operationId"], "summary": operation.get("summary", ""), @@ -272,14 +244,11 @@ class ApiBasedToolSchemaParser: "parameters": operation.get("parameters", []), "responses": operation.get("responses", {}), } - if "requestBody" in operation: openapi["paths"][path][method]["requestBody"] = operation["requestBody"] - # convert definitions for name, definition in swagger["definitions"].items(): openapi["components"]["schemas"][name] = definition - return openapi @staticmethod @@ -288,7 +257,6 @@ class ApiBasedToolSchemaParser: ) -> list[ApiToolBundle]: """ parse openapi plugin yaml to tool bundle - :param json: the json string :param extra_info: the extra info :param warning: the warning message @@ -296,7 +264,6 @@ class ApiBasedToolSchemaParser: """ warning = warning if warning is not None else {} extra_info = extra_info if extra_info is not None else {} - try: openai_plugin = json_loads(json) api = openai_plugin["api"] @@ -304,16 +271,12 @@ class ApiBasedToolSchemaParser: api_type = api["type"] except JSONDecodeError: raise ToolProviderNotFoundError("Invalid openai plugin json.") - if api_type != "openapi": raise ToolNotSupportedError("Only openapi is supported now.") - # get 
openapi yaml response = get(api_url, headers={"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "}, timeout=5) - if response.status_code != 200: raise ToolProviderNotFoundError("cannot get openapi yaml from url.") - return ApiBasedToolSchemaParser.parse_openapi_yaml_to_tool_bundle( response.text, extra_info=extra_info, warning=warning ) @@ -324,7 +287,6 @@ class ApiBasedToolSchemaParser: ) -> tuple[list[ApiToolBundle], str]: """ auto parse to tool bundle - :param content: the content :param extra_info: the extra info :param warning: the warning message @@ -332,17 +294,14 @@ class ApiBasedToolSchemaParser: """ warning = warning if warning is not None else {} extra_info = extra_info if extra_info is not None else {} - content = content.strip() loaded_content = None json_error = None yaml_error = None - try: loaded_content = json_loads(content) except JSONDecodeError as e: json_error = e - if loaded_content is None: try: loaded_content = safe_load(content) @@ -353,12 +312,10 @@ class ApiBasedToolSchemaParser: f"Invalid api schema, schema is neither json nor yaml. 
json error: {str(json_error)}," f" yaml error: {str(yaml_error)}" ) - swagger_error = None openapi_error = None openapi_plugin_error = None schema_type = None - try: openapi = ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle( loaded_content, extra_info=extra_info, warning=warning @@ -367,7 +324,6 @@ class ApiBasedToolSchemaParser: return openapi, schema_type except ToolApiSchemaError as e: openapi_error = e - # openai parse error, fallback to swagger try: converted_swagger = ApiBasedToolSchemaParser.parse_swagger_to_openapi( @@ -379,7 +335,6 @@ class ApiBasedToolSchemaParser: ), schema_type except ToolApiSchemaError as e: swagger_error = e - # swagger parse error, fallback to openai plugin try: openapi_plugin = ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle( @@ -389,7 +344,6 @@ class ApiBasedToolSchemaParser: except ToolNotSupportedError as e: # maybe it's not plugin at all openapi_plugin_error = e - raise ToolApiSchemaError( f"Invalid api schema, openapi error: {str(openapi_error)}, swagger error: {str(swagger_error)}," f" openapi plugin error: {str(openapi_plugin_error)}" diff --git a/api/core/tools/utils/text_processing_utils.py b/api/core/tools/utils/text_processing_utils.py index 105823f896..1a5e583940 100644 --- a/api/core/tools/utils/text_processing_utils.py +++ b/api/core/tools/utils/text_processing_utils.py @@ -4,10 +4,8 @@ import re def remove_leading_symbols(text: str) -> str: """ Remove leading punctuation or symbols from the given text. - Args: text (str): The input text to process. - Returns: str: The text with leading punctuation or symbols removed. 
""" diff --git a/api/core/tools/utils/web_reader_tool.py b/api/core/tools/utils/web_reader_tool.py index cbd06fc186..5f07e2e7e7 100644 --- a/api/core/tools/utils/web_reader_tool.py +++ b/api/core/tools/utils/web_reader_tool.py @@ -17,7 +17,6 @@ FULL_TEMPLATE = """ TITLE: {title} AUTHOR: {author} TEXT: - {text} """ @@ -35,11 +34,9 @@ def get_url(url: str, user_agent: Optional[str] = None) -> str: } if user_agent: headers["User-Agent"] = user_agent - main_content_type = None supported_content_types = extract_processor.SUPPORT_URL_CONTENT_TYPES + ["text/html"] response = ssrf_proxy.head(url, headers=headers, follow_redirects=True, timeout=(5, 10)) - if response.status_code == 200: # check content-type content_type = response.headers.get("Content-Type") @@ -53,22 +50,17 @@ def get_url(url: str, user_agent: Optional[str] = None) -> str: extension = re.search(r"\.(\w+)$", filename) if extension: main_content_type = mimetypes.guess_type(filename)[0] - if main_content_type not in supported_content_types: return "Unsupported content-type [{}] of URL.".format(main_content_type) - if main_content_type in extract_processor.SUPPORT_URL_CONTENT_TYPES: return cast(str, ExtractProcessor.load_from_url(url, return_text=True)) - response = ssrf_proxy.get(url, headers=headers, follow_redirects=True, timeout=(120, 300)) elif response.status_code == 403: scraper = cloudscraper.create_scraper() scraper.perform_request = ssrf_proxy.make_request # type: ignore response = scraper.get(url, headers=headers, follow_redirects=True, timeout=(120, 300)) # type: ignore - if response.status_code != 200: return "URL returned status code {}.".format(response.status_code) - # Detect encoding using chardet detected_encoding = chardet.detect(response.content) encoding = detected_encoding["encoding"] @@ -79,18 +71,14 @@ def get_url(url: str, user_agent: Optional[str] = None) -> str: content = response.text else: content = response.text - article = extract_using_readabilipy(content) - if not article.text: 
return "" - res = FULL_TEMPLATE.format( title=article.title, author=article.auther, text=article.text, ) - return res @@ -108,7 +96,6 @@ def extract_using_readabilipy(html: str): auther=json_article.get("byline") or "", text=json_article.get("plain_text") or [], ) - return article diff --git a/api/core/tools/utils/workflow_configuration_sync.py b/api/core/tools/utils/workflow_configuration_sync.py index d16d6fc576..ce6933768b 100644 --- a/api/core/tools/utils/workflow_configuration_sync.py +++ b/api/core/tools/utils/workflow_configuration_sync.py @@ -18,10 +18,8 @@ class WorkflowToolConfigurationUtils: """ nodes = graph.get("nodes", []) start_node = next(filter(lambda x: x.get("data", {}).get("type") == "start", nodes), None) - if not start_node: return [] - return [VariableEntity.model_validate(variable) for variable in start_node.get("data", {}).get("variables", [])] @classmethod @@ -30,14 +28,11 @@ class WorkflowToolConfigurationUtils: ): """ check is synced - raise ValueError if not synced """ variable_names = [variable.variable for variable in variables] - if len(tool_configurations) != len(variables): raise ValueError("parameter configuration mismatch, please republish the tool to update") - for parameter in tool_configurations: if parameter.name not in variable_names: raise ValueError("parameter configuration mismatch, please republish the tool to update") diff --git a/api/core/tools/utils/yaml_utils.py b/api/core/tools/utils/yaml_utils.py index ee7ca11e05..76280d1745 100644 --- a/api/core/tools/utils/yaml_utils.py +++ b/api/core/tools/utils/yaml_utils.py @@ -23,7 +23,6 @@ def load_yaml_file(file_path: str, ignore_error: bool = True, default_value: Any return default_value else: raise FileNotFoundError(f"File not found: {file_path}") - with open(file_path, encoding="utf-8") as yaml_file: try: yaml_content = yaml.safe_load(yaml_file) diff --git a/api/core/tools/workflow_as_tool/provider.py b/api/core/tools/workflow_as_tool/provider.py index 
7661e1e6a5..ef735207a6 100644 --- a/api/core/tools/workflow_as_tool/provider.py +++ b/api/core/tools/workflow_as_tool/provider.py @@ -46,10 +46,8 @@ class WorkflowToolProviderController(ToolProviderController): @classmethod def from_db(cls, db_provider: WorkflowToolProvider) -> "WorkflowToolProviderController": app = db_provider.app - if not app: raise ValueError("app not found") - controller = WorkflowToolProviderController( entity=ToolProviderEntity( identity=ToolProviderIdentity( @@ -64,11 +62,8 @@ class WorkflowToolProviderController(ToolProviderController): ), provider_id=db_provider.id or "", ) - # init tools - controller.tools = [controller._get_db_provider_tool(db_provider, app)] - return controller @property @@ -87,15 +82,12 @@ class WorkflowToolProviderController(ToolProviderController): .filter(Workflow.app_id == db_provider.app_id, Workflow.version == db_provider.version) .first() ) - if not workflow: raise ValueError("workflow not found") - # fetch start node graph: Mapping = workflow.graph_dict features_dict: Mapping = workflow.features_dict features = WorkflowAppConfigManager.convert_features(config_dict=features_dict, app_mode=AppMode.WORKFLOW) - parameters = db_provider.parameter_configurations variables = WorkflowToolConfigurationUtils.get_workflow_graph_variables(graph) @@ -103,7 +95,6 @@ class WorkflowToolProviderController(ToolProviderController): return next(filter(lambda x: x.variable == variable_name, variables), None) # type: ignore user = db_provider.user - workflow_tool_parameters = [] for parameter in parameters: variable = fetch_workflow_variable(parameter.name) @@ -113,13 +104,11 @@ class WorkflowToolProviderController(ToolProviderController): if variable.type not in VARIABLE_TO_PARAMETER_TYPE_MAPPING: raise ValueError(f"unsupported variable type {variable.type}") parameter_type = VARIABLE_TO_PARAMETER_TYPE_MAPPING[variable.type] - if variable.type == VariableEntityType.SELECT and variable.options: options = [ 
PluginParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option)) for option in variable.options ] - workflow_tool_parameters.append( ToolParameter( name=parameter.name, @@ -148,7 +137,6 @@ class WorkflowToolProviderController(ToolProviderController): ) else: raise ValueError("variable not found") - return WorkflowTool( workflow_as_tool_id=db_provider.id, entity=ToolEntity( @@ -181,13 +169,11 @@ class WorkflowToolProviderController(ToolProviderController): def get_tools(self, tenant_id: str) -> list[WorkflowTool]: """ fetch tools from database - :param tenant_id: the tenant id :return: the tools """ if self.tools is not None: return self.tools - db_providers: WorkflowToolProvider | None = ( db.session.query(WorkflowToolProvider) .filter( @@ -196,32 +182,25 @@ class WorkflowToolProviderController(ToolProviderController): ) .first() ) - if not db_providers: return [] if not db_providers.app: raise ValueError("app not found") - app = db_providers.app if not app: raise ValueError("can not read app of workflow") - self.tools = [self._get_db_provider_tool(db_providers, app)] - return self.tools def get_tool(self, tool_name: str) -> Optional[WorkflowTool]: # type: ignore """ get tool by name - :param tool_name: the name of the tool :return: the tool """ if self.tools is None: return None - for tool in self.tools: if tool.entity.identity.name == tool_name: return tool - return None diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py index 57c93d1d45..2d0c7b921b 100644 --- a/api/core/tools/workflow_as_tool/tool.py +++ b/api/core/tools/workflow_as_tool/tool.py @@ -26,9 +26,7 @@ class WorkflowTool(Tool): workflow_call_depth: int thread_pool_id: Optional[str] = None workflow_as_tool_id: str - label: str - """ Workflow tool. 
""" @@ -52,13 +50,11 @@ class WorkflowTool(Tool): self.workflow_call_depth = workflow_call_depth self.thread_pool_id = thread_pool_id self.label = label - super().__init__(entity=entity, runtime=runtime) def tool_provider_type(self) -> ToolProviderType: """ get the tool provider type - :return: the tool provider type """ return ToolProviderType.WORKFLOW @@ -76,16 +72,13 @@ class WorkflowTool(Tool): """ app = self._get_app(app_id=self.workflow_app_id) workflow = self._get_workflow(app_id=self.workflow_app_id, version=self.version) - # transform the tool parameters tool_parameters, files = self._transform_args(tool_parameters=tool_parameters) - from core.app.apps.workflow.app_generator import WorkflowAppGenerator generator = WorkflowAppGenerator() assert self.runtime is not None assert self.runtime.invoke_from is not None - result = generator.generate( app_model=app, workflow=workflow, @@ -98,10 +91,8 @@ class WorkflowTool(Tool): ) assert isinstance(result, dict) data = result.get("data", {}) - if err := data.get("error"): raise ToolInvokeError(err) - outputs = data.get("outputs") if outputs is None: outputs = {} @@ -109,14 +100,12 @@ class WorkflowTool(Tool): outputs, files = self._extract_files(outputs) # type: ignore for file in files: yield self.create_file_message(file) # type: ignore - yield self.create_text_message(json.dumps(outputs, ensure_ascii=False)) yield self.create_json_message(outputs) def fork_tool_runtime(self, runtime: ToolRuntime) -> "WorkflowTool": """ fork a new tool with metadata - :return: the new tool """ return self.__class__( @@ -143,10 +132,8 @@ class WorkflowTool(Tool): ) else: workflow = db.session.query(Workflow).filter(Workflow.app_id == app_id, Workflow.version == version).first() - if not workflow: raise ValueError("workflow not found or not published") - return workflow def _get_app(self, app_id: str) -> App: @@ -156,13 +143,11 @@ class WorkflowTool(Tool): app = db.session.query(App).filter(App.id == app_id).first() if not app: 
raise ValueError("app not found") - return app def _transform_args(self, tool_parameters: dict) -> tuple[dict, list[dict]]: """ transform the tool parameters - :param tool_parameters: the tool parameters :return: tool_parameters, files """ @@ -186,19 +171,16 @@ class WorkflowTool(Tool): file_dict["upload_file_id"] = file.related_id elif file.transfer_method == FileTransferMethod.REMOTE_URL: file_dict["url"] = file.generate_url() - files.append(file_dict) except Exception: logger.exception(f"Failed to transform file {file}") else: parameters_result[parameter.name] = tool_parameters.get(parameter.name) - return parameters_result, files def _extract_files(self, outputs: dict) -> tuple[dict, list[File]]: """ extract files from the result - :return: the result, files """ files: list[File] = [] @@ -220,9 +202,7 @@ class WorkflowTool(Tool): tenant_id=str(cast(ToolRuntime, self.runtime).tenant_id), ) files.append(file) - result[key] = value - return result, files def _update_file_mapping(self, file_dict: dict) -> dict: diff --git a/api/core/variables/segments.py b/api/core/variables/segments.py index 6cf09e0372..09bde08ccd 100644 --- a/api/core/variables/segments.py +++ b/api/core/variables/segments.py @@ -12,7 +12,6 @@ from .types import SegmentType class Segment(BaseModel): model_config = ConfigDict(frozen=True) - value_type: SegmentType value: Any diff --git a/api/core/variables/types.py b/api/core/variables/types.py index 68d3d82883..2df7517163 100644 --- a/api/core/variables/types.py +++ b/api/core/variables/types.py @@ -6,17 +6,13 @@ class SegmentType(StrEnum): STRING = "string" OBJECT = "object" SECRET = "secret" - FILE = "file" - ARRAY_ANY = "array[any]" ARRAY_STRING = "array[string]" ARRAY_NUMBER = "array[number]" ARRAY_OBJECT = "array[object]" ARRAY_FILE = "array[file]" - NONE = "none" - GROUP = "group" def is_array_type(self): diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py index 00bbb37752..b081c1ea1f 
100644 --- a/api/core/workflow/nodes/agent/agent_node.py +++ b/api/core/workflow/nodes/agent/agent_node.py @@ -103,10 +103,10 @@ class AgentNode(ToolNode): try: # convert tool messages - agent_thoughts = [] + agent_thoughts: list[dict[str, Any]] = [] from core.tools.entities.tool_entities import ToolInvokeMessage - + thought_log_message = ToolInvokeMessage( type=ToolInvokeMessage.MessageType.LOG, message=ToolInvokeMessage.LogMessage( @@ -128,11 +128,10 @@ class AgentNode(ToolNode): ) from core.tools.entities.tool_entities import ToolInvokeMessage - + def enhanced_message_stream(): - yield thought_log_message - + yield from message_stream yield from self._transform_message( diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index 795a16663e..cb2ead3adf 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -371,7 +371,7 @@ class ToolNode(BaseNode[ToolNodeData]): yield agent_log # Add agent_logs to outputs['json'] to ensure frontend can access thinking process - json_output = json.copy() + json_output: list[dict[Any, Any]] | dict[str, Any] = json.copy() if agent_logs: if not json_output: json_output = {} @@ -381,11 +381,11 @@ class ToolNode(BaseNode[ToolNodeData]): elif isinstance(json_output, list): # If json is a list with multiple elements, create a dictionary containing all data json_output = {"data": json_output} - + # Ensure json_output is a dictionary type if not isinstance(json_output, dict): json_output = {"data": json_output} - + # Add agent_logs to json output json_output["agent_logs"] = [ { diff --git a/api/factories/agent_factory.py b/api/factories/agent_factory.py index 4b12afb528..ed330aff0d 100644 --- a/api/factories/agent_factory.py +++ b/api/factories/agent_factory.py @@ -11,5 +11,4 @@ def get_plugin_agent_strategy( for agent_strategy in agent_provider.declaration.strategies: if agent_strategy.identity.name == agent_strategy_name: return 
PluginAgentStrategy(tenant_id, agent_strategy) - raise ValueError(f"Agent strategy {agent_strategy_name} not found") diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py index 25d1390492..3587b118a0 100644 --- a/api/factories/file_factory.py +++ b/api/factories/file_factory.py @@ -56,24 +56,20 @@ def build_from_mapping( strict_type_validation: bool = False, ) -> File: transfer_method = FileTransferMethod.value_of(mapping.get("transfer_method")) - build_functions: dict[FileTransferMethod, Callable] = { FileTransferMethod.LOCAL_FILE: _build_from_local_file, FileTransferMethod.REMOTE_URL: _build_from_remote_url, FileTransferMethod.TOOL_FILE: _build_from_tool_file, } - build_func = build_functions.get(transfer_method) if not build_func: raise ValueError(f"Invalid file transfer method: {transfer_method}") - file: File = build_func( mapping=mapping, tenant_id=tenant_id, transfer_method=transfer_method, strict_type_validation=strict_type_validation, ) - if config and not _is_file_valid_with_config( input_file_type=mapping.get("type", FileType.CUSTOM), file_extension=file.extension or "", @@ -81,7 +77,6 @@ def build_from_mapping( config=config, ): raise ValueError(f"File validation failed for file: {file.filename}") - return file @@ -103,7 +98,6 @@ def build_from_mappings( ) for mapping in mappings ] - if ( config # If image config is set. @@ -114,7 +108,6 @@ def build_from_mappings( raise ValueError(f"Number of image files exceeds the maximum limit {config.image_config.number_limits}") if config and config.number_limits and len(files) > config.number_limits: raise ValueError(f"Number of files exceeds the maximum limit {config.number_limits}") - return files @@ -137,21 +130,16 @@ def _build_from_local_file( UploadFile.id == upload_file_id, UploadFile.tenant_id == tenant_id, ) - row = db.session.scalar(stmt) if row is None: raise ValueError("Invalid upload file") - detected_file_type = _standardize_file_type(extension="." 
+ row.extension, mime_type=row.mime_type) specified_type = mapping.get("type", "custom") - if strict_type_validation and detected_file_type.value != specified_type: raise ValueError("Detected file type does not match the specified type. Please verify the file.") - file_type = ( FileType(specified_type) if specified_type and specified_type != FileType.CUSTOM.value else detected_file_type ) - return File( id=mapping.get("id"), filename=row.name, @@ -184,26 +172,20 @@ def _build_from_remote_url( UploadFile.id == upload_file_id, UploadFile.tenant_id == tenant_id, ) - upload_file = db.session.scalar(stmt) if upload_file is None: raise ValueError("Invalid upload file") - detected_file_type = _standardize_file_type( extension="." + upload_file.extension, mime_type=upload_file.mime_type ) - specified_type = mapping.get("type") - if strict_type_validation and specified_type and detected_file_type.value != specified_type: raise ValueError("Detected file type does not match the specified type. Please verify the file.") - file_type = ( FileType(specified_type) if specified_type and specified_type != FileType.CUSTOM.value else detected_file_type ) - return File( id=mapping.get("id"), filename=upload_file.name, @@ -220,14 +202,11 @@ def _build_from_remote_url( url = mapping.get("url") or mapping.get("remote_url") if not url: raise ValueError("Invalid file url") - mime_type, filename, file_size = _get_remote_file_info(url) extension = mimetypes.guess_extension(mime_type) or ("." + filename.split(".")[-1] if "." in filename else ".bin") - file_type = _standardize_file_type(extension=extension, mime_type=mime_type) if file_type.value != mapping.get("type", "custom"): raise ValueError("Detected file type does not match the specified type. 
Please verify the file.") - return File( id=mapping.get("id"), filename=filename, @@ -246,7 +225,6 @@ def _get_remote_file_info(url: str): file_size = -1 filename = url.split("/")[-1].split("?")[0] or "unknown_file" mime_type = mimetypes.guess_type(filename)[0] or "" - resp = ssrf_proxy.head(url, follow_redirects=True) resp = cast(httpx.Response, resp) if resp.status_code == httpx.codes.OK: @@ -254,7 +232,6 @@ def _get_remote_file_info(url: str): filename = str(content_disposition.split("filename=")[-1].strip('"')) file_size = int(resp.headers.get("Content-Length", file_size)) mime_type = mime_type or str(resp.headers.get("Content-Type", "")) - return mime_type, filename, file_size @@ -273,23 +250,16 @@ def _build_from_tool_file( ) .first() ) - if tool_file is None: raise ValueError(f"ToolFile {mapping.get('tool_file_id')} not found") - extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin" - detected_file_type = _standardize_file_type(extension="." + extension, mime_type=tool_file.mimetype) - specified_type = mapping.get("type") - if strict_type_validation and specified_type and detected_file_type.value != specified_type: raise ValueError("Detected file type does not match the specified type. 
Please verify the file.") - file_type = ( FileType(specified_type) if specified_type and specified_type != FileType.CUSTOM.value else detected_file_type ) - return File( id=mapping.get("id"), tenant_id=tenant_id, @@ -318,14 +288,12 @@ def _is_file_valid_with_config( and input_file_type != FileType.CUSTOM ): return False - if ( input_file_type == FileType.CUSTOM and config.allowed_file_extensions is not None and file_extension not in config.allowed_file_extensions ): return False - if input_file_type == FileType.IMAGE: if ( config.image_config @@ -335,7 +303,6 @@ def _is_file_valid_with_config( return False elif config.allowed_file_upload_methods and file_transfer_method not in config.allowed_file_upload_methods: return False - return True @@ -396,7 +363,6 @@ class StorageKeyLoader: UploadFile.id.in_(upload_file_ids), UploadFile.tenant_id == self._tenant_id, ) - return {uuid.UUID(i.id): i for i in self._session.scalars(stmt)} def _load_tool_files(self, tool_file_ids: Sequence[uuid.UUID]) -> Mapping[uuid.UUID, ToolFile]: @@ -409,16 +375,13 @@ class StorageKeyLoader: def load_storage_keys(self, files: Sequence[File]): """Loads storage keys for a sequence of files by retrieving the corresponding `UploadFile` or `ToolFile` records from the database based on their transfer method. - This method doesn't modify the input sequence structure but updates the `_storage_key` property of each file object by extracting the relevant key from its database record. - Performance note: This is a batched operation where database query count remains constant regardless of input size. However, for optimal performance, input sequences should contain fewer than 1000 files. For larger collections, split into smaller batches and process each batch separately. 
""" - upload_file_ids: list[uuid.UUID] = [] tool_file_ids: list[uuid.UUID] = [] for file in files: @@ -432,12 +395,10 @@ class StorageKeyLoader: ) raise ValueError(err_msg) model_id = uuid.UUID(related_model_id) - if file.transfer_method in (FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL): upload_file_ids.append(model_id) elif file.transfer_method == FileTransferMethod.TOOL_FILE: tool_file_ids.append(model_id) - tool_files = self._load_tool_files(tool_file_ids) upload_files = self._load_upload_files(upload_file_ids) for file in files: diff --git a/api/factories/variable_factory.py b/api/factories/variable_factory.py index 250ee4695e..edb8fa957b 100644 --- a/api/factories/variable_factory.py +++ b/api/factories/variable_factory.py @@ -84,7 +84,6 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen raise VariableError("missing value type") if (value := mapping.get("value")) is None: raise VariableError("missing value") - result: Variable match value_type: case SegmentType.STRING: @@ -156,32 +155,24 @@ def build_segment(value: Any, /) -> Segment: def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment: """ Build a segment with explicit type checking. - This function creates a segment from a value while enforcing type compatibility with the specified segment_type. It provides stricter type validation compared to the standard build_segment function. 
- Args: segment_type: The expected SegmentType for the resulting segment value: The value to be converted into a segment - Returns: Segment: A segment instance of the appropriate type - Raises: TypeMismatchError: If the value type doesn't match the expected segment_type - Special Cases: - For empty list [] values, if segment_type is array[*], returns the corresponding array type - Type validation is performed before segment creation - Examples: >>> build_segment_with_type(SegmentType.STRING, "hello") StringSegment(value="hello") - >>> build_segment_with_type(SegmentType.ARRAY_STRING, []) ArrayStringSegment(value=[]) - >>> build_segment_with_type(SegmentType.STRING, 123) # Raises TypeMismatchError """ @@ -191,7 +182,6 @@ def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment: return NoneSegment() else: raise TypeMismatchError(f"Expected {segment_type}, but got None") - # Handle empty list special case for array types if isinstance(value, list) and len(value) == 0: if segment_type == SegmentType.ARRAY_ANY: @@ -206,15 +196,12 @@ def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment: return ArrayFileSegment(value=value) else: raise TypeMismatchError(f"Expected {segment_type}, but got empty list") - # Build segment using existing logic to infer actual type inferred_segment = build_segment(value) inferred_type = inferred_segment.value_type - # Type compatibility checking if inferred_type == segment_type: return inferred_segment - # Type mismatch - raise error with descriptive message raise TypeMismatchError( f"Type mismatch: expected {segment_type}, but value '{value}' " @@ -234,11 +221,9 @@ def segment_to_variable( return segment name = name or selector[-1] id = id or str(uuid4()) - segment_type = type(segment) if segment_type not in SEGMENT_TO_VARIABLE_MAP: raise UnsupportedSegmentTypeError(f"not supported segment type {segment_type}") - variable_class = SEGMENT_TO_VARIABLE_MAP[segment_type] return cast( Variable, diff 
--git a/api/migrations/env.py b/api/migrations/env.py index a5d815dcfd..f7ce8d9752 100644 --- a/api/migrations/env.py +++ b/api/migrations/env.py @@ -106,5 +106,4 @@ def run_migrations_online(): if context.is_offline_mode(): run_migrations_offline() else: - run_migrations_online() - + run_migrations_online() \ No newline at end of file diff --git a/api/migrations/versions/00bacef91f18_rename_api_provider_description.py b/api/migrations/versions/00bacef91f18_rename_api_provider_description.py index 5ae9e8769a..90e15ec054 100644 --- a/api/migrations/versions/00bacef91f18_rename_api_provider_description.py +++ b/api/migrations/versions/00bacef91f18_rename_api_provider_description.py @@ -30,4 +30,4 @@ def downgrade(): batch_op.add_column(sa.Column('description_str', sa.TEXT(), autoincrement=False, nullable=False)) batch_op.drop_column('description') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/03f98355ba0e_add_workflow_tool_label_and_tool_.py b/api/migrations/versions/03f98355ba0e_add_workflow_tool_label_and_tool_.py index 8cd4ec552b..8f95163451 100644 --- a/api/migrations/versions/03f98355ba0e_add_workflow_tool_label_and_tool_.py +++ b/api/migrations/versions/03f98355ba0e_add_workflow_tool_label_and_tool_.py @@ -30,4 +30,4 @@ def downgrade(): batch_op.drop_column('label') with op.batch_alter_table('tool_label_bindings', schema=None) as batch_op: - batch_op.drop_constraint('unique_tool_label_bind', type_='unique') + batch_op.drop_constraint('unique_tool_label_bind', type_='unique') \ No newline at end of file diff --git a/api/migrations/versions/04c602f5dc9b_update_appmodelconfig_and_add_table_.py b/api/migrations/versions/04c602f5dc9b_update_appmodelconfig_and_add_table_.py index 153861a71a..e5abeeaec7 100644 --- a/api/migrations/versions/04c602f5dc9b_update_appmodelconfig_and_add_table_.py +++ b/api/migrations/versions/04c602f5dc9b_update_appmodelconfig_and_add_table_.py @@ -36,4 
+36,4 @@ def downgrade(): # ### commands auto generated by Alembic - please adjust! ## op.drop_table('tracing_app_configs') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/053da0c1d756_add_api_tool_privacy.py b/api/migrations/versions/053da0c1d756_add_api_tool_privacy.py index a589f1f08b..dd7a341fa4 100644 --- a/api/migrations/versions/053da0c1d756_add_api_tool_privacy.py +++ b/api/migrations/versions/053da0c1d756_add_api_tool_privacy.py @@ -48,4 +48,4 @@ def downgrade(): batch_op.drop_column('privacy_policy') op.drop_table('tool_conversation_variables') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/114eed84c228_remove_tool_id_from_model_invoke.py b/api/migrations/versions/114eed84c228_remove_tool_id_from_model_invoke.py index 58863fe3a7..c67dbaed75 100644 --- a/api/migrations/versions/114eed84c228_remove_tool_id_from_model_invoke.py +++ b/api/migrations/versions/114eed84c228_remove_tool_id_from_model_invoke.py @@ -29,4 +29,4 @@ def downgrade(): with op.batch_alter_table('tool_model_invokes', schema=None) as batch_op: batch_op.add_column(sa.Column('tool_id', postgresql.UUID(), autoincrement=False, nullable=False)) - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/161cadc1af8d_add_dataset_permission_tenant_id.py b/api/migrations/versions/161cadc1af8d_add_dataset_permission_tenant_id.py index 8907f78117..4e2ff74c28 100644 --- a/api/migrations/versions/161cadc1af8d_add_dataset_permission_tenant_id.py +++ b/api/migrations/versions/161cadc1af8d_add_dataset_permission_tenant_id.py @@ -31,4 +31,4 @@ def downgrade(): with op.batch_alter_table('dataset_permissions', schema=None) as batch_op: batch_op.drop_column('tenant_id') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git 
a/api/migrations/versions/16830a790f0f_.py b/api/migrations/versions/16830a790f0f_.py index 38d6e4940a..54fd53896a 100644 --- a/api/migrations/versions/16830a790f0f_.py +++ b/api/migrations/versions/16830a790f0f_.py @@ -28,4 +28,4 @@ def downgrade(): with op.batch_alter_table('tenant_account_joins', schema=None) as batch_op: batch_op.drop_column('current') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/16fa53d9faec_add_provider_model_support.py b/api/migrations/versions/16fa53d9faec_add_provider_model_support.py index 6791cf4578..bede7e1031 100644 --- a/api/migrations/versions/16fa53d9faec_add_provider_model_support.py +++ b/api/migrations/versions/16fa53d9faec_add_provider_model_support.py @@ -76,4 +76,4 @@ def downgrade(): batch_op.drop_index('provider_model_tenant_id_provider_idx') op.drop_table('provider_models') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/17b5ab037c40_add_keyworg_table_storage_type.py b/api/migrations/versions/17b5ab037c40_add_keyworg_table_storage_type.py index 7707148489..1108621d50 100644 --- a/api/migrations/versions/17b5ab037c40_add_keyworg_table_storage_type.py +++ b/api/migrations/versions/17b5ab037c40_add_keyworg_table_storage_type.py @@ -30,4 +30,4 @@ def downgrade(): with op.batch_alter_table('dataset_keyword_tables', schema=None) as batch_op: batch_op.drop_column('data_source_type') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/187385f442fc_modify_provider_model_name_length.py b/api/migrations/versions/187385f442fc_modify_provider_model_name_length.py index 13a823f7ec..4c81b01a43 100644 --- a/api/migrations/versions/187385f442fc_modify_provider_model_name_length.py +++ b/api/migrations/versions/187385f442fc_modify_provider_model_name_length.py @@ -34,4 +34,4 @@ def downgrade(): 
type_=sa.VARCHAR(length=40), existing_nullable=False) - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/2024_08_09_0801-1787fbae959a_update_tools_original_url_length.py b/api/migrations/versions/2024_08_09_0801-1787fbae959a_update_tools_original_url_length.py index db966252f1..ce30380e81 100644 --- a/api/migrations/versions/2024_08_09_0801-1787fbae959a_update_tools_original_url_length.py +++ b/api/migrations/versions/2024_08_09_0801-1787fbae959a_update_tools_original_url_length.py @@ -36,4 +36,4 @@ def downgrade(): type_=sa.VARCHAR(length=255), existing_nullable=True) - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/2024_08_13_0633-63a83fcf12ba_support_conversation_variables.py b/api/migrations/versions/2024_08_13_0633-63a83fcf12ba_support_conversation_variables.py index 16e1efd4ef..38b8ae9510 100644 --- a/api/migrations/versions/2024_08_13_0633-63a83fcf12ba_support_conversation_variables.py +++ b/api/migrations/versions/2024_08_13_0633-63a83fcf12ba_support_conversation_variables.py @@ -48,4 +48,4 @@ def downgrade(): batch_op.drop_index(batch_op.f('workflow__conversation_variables_app_id_idx')) op.drop_table('workflow__conversation_variables') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/2024_08_14_1354-8782057ff0dc_add_conversations_dialogue_count.py b/api/migrations/versions/2024_08_14_1354-8782057ff0dc_add_conversations_dialogue_count.py index eba78e2e77..d56eb75105 100644 --- a/api/migrations/versions/2024_08_14_1354-8782057ff0dc_add_conversations_dialogue_count.py +++ b/api/migrations/versions/2024_08_14_1354-8782057ff0dc_add_conversations_dialogue_count.py @@ -30,4 +30,4 @@ def downgrade(): with op.batch_alter_table('conversations', schema=None) as batch_op: batch_op.drop_column('dialogue_count') - # ### end 
Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/2024_08_15_0956-0251a1c768cc_add_tidb_auth_binding.py b/api/migrations/versions/2024_08_15_0956-0251a1c768cc_add_tidb_auth_binding.py index ca2e410442..0f8d9b5653 100644 --- a/api/migrations/versions/2024_08_15_0956-0251a1c768cc_add_tidb_auth_binding.py +++ b/api/migrations/versions/2024_08_15_0956-0251a1c768cc_add_tidb_auth_binding.py @@ -48,4 +48,4 @@ def downgrade(): batch_op.drop_index('tidb_auth_bindings_active_idx') batch_op.drop_index('tidb_auth_bindings_status_idx') op.drop_table('tidb_auth_bindings') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/2024_08_15_1001-a6be81136580_app_and_site_icon_type.py b/api/migrations/versions/2024_08_15_1001-a6be81136580_app_and_site_icon_type.py index d814666eef..4e5e13e548 100644 --- a/api/migrations/versions/2024_08_15_1001-a6be81136580_app_and_site_icon_type.py +++ b/api/migrations/versions/2024_08_15_1001-a6be81136580_app_and_site_icon_type.py @@ -36,4 +36,4 @@ def downgrade(): with op.batch_alter_table('apps', schema=None) as batch_op: batch_op.drop_column('icon_type') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/2024_08_20_0455-2dbe42621d96_rename_workflow__conversation_variables_.py b/api/migrations/versions/2024_08_20_0455-2dbe42621d96_rename_workflow__conversation_variables_.py index 3dc7fed818..5ce715ac19 100644 --- a/api/migrations/versions/2024_08_20_0455-2dbe42621d96_rename_workflow__conversation_variables_.py +++ b/api/migrations/versions/2024_08_20_0455-2dbe42621d96_rename_workflow__conversation_variables_.py @@ -25,4 +25,4 @@ def upgrade(): def downgrade(): # ### commands auto generated by Alembic - please adjust! 
### op.rename_table('workflow_conversation_variables', 'workflow__conversation_variables') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/2024_08_25_0441-d0187d6a88dd_add_created_by_and_updated_by_to_app_.py b/api/migrations/versions/2024_08_25_0441-d0187d6a88dd_add_created_by_and_updated_by_to_app_.py index e0066a302c..31e8e45090 100644 --- a/api/migrations/versions/2024_08_25_0441-d0187d6a88dd_add_created_by_and_updated_by_to_app_.py +++ b/api/migrations/versions/2024_08_25_0441-d0187d6a88dd_add_created_by_and_updated_by_to_app_.py @@ -49,4 +49,4 @@ def downgrade(): batch_op.drop_column("updated_by") batch_op.drop_column("created_by") - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/2024_09_01_1255-030f4915f36a_add_use_icon_as_answer_icon_fields_for_.py b/api/migrations/versions/2024_09_01_1255-030f4915f36a_add_use_icon_as_answer_icon_fields_for_.py index 4406d51ed0..2d23a15b8e 100644 --- a/api/migrations/versions/2024_09_01_1255-030f4915f36a_add_use_icon_as_answer_icon_fields_for_.py +++ b/api/migrations/versions/2024_09_01_1255-030f4915f36a_add_use_icon_as_answer_icon_fields_for_.py @@ -42,4 +42,4 @@ def downgrade(): with op.batch_alter_table("apps", schema=None) as batch_op: batch_op.drop_column("use_icon_as_answer_icon") - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/2024_09_11_1012-d57ba9ebb251_add_parent_message_id_to_messages.py b/api/migrations/versions/2024_09_11_1012-d57ba9ebb251_add_parent_message_id_to_messages.py index fd957eeafb..1451502a47 100644 --- a/api/migrations/versions/2024_09_11_1012-d57ba9ebb251_add_parent_message_id_to_messages.py +++ b/api/migrations/versions/2024_09_11_1012-d57ba9ebb251_add_parent_message_id_to_messages.py @@ -33,4 +33,4 @@ def downgrade(): with 
op.batch_alter_table('messages', schema=None) as batch_op: batch_op.drop_column('parent_message_id') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/2024_09_24_0922-6af6a521a53e_update_retrieval_resource.py b/api/migrations/versions/2024_09_24_0922-6af6a521a53e_update_retrieval_resource.py index 5337b340db..9e700a0d49 100644 --- a/api/migrations/versions/2024_09_24_0922-6af6a521a53e_update_retrieval_resource.py +++ b/api/migrations/versions/2024_09_24_0922-6af6a521a53e_update_retrieval_resource.py @@ -45,4 +45,4 @@ def downgrade(): existing_type=sa.UUID(), nullable=False) - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/2024_09_25_0434-33f5fac87f29_external_knowledge_api.py b/api/migrations/versions/2024_09_25_0434-33f5fac87f29_external_knowledge_api.py index 3cb76e72c1..37c369f53e 100644 --- a/api/migrations/versions/2024_09_25_0434-33f5fac87f29_external_knowledge_api.py +++ b/api/migrations/versions/2024_09_25_0434-33f5fac87f29_external_knowledge_api.py @@ -70,4 +70,4 @@ def downgrade(): batch_op.drop_index('external_knowledge_apis_name_idx') op.drop_table('external_knowledge_apis') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/2024_09_29_0835-ddcc8bbef391_increase_max_length_of_builtin_tool_provider.py b/api/migrations/versions/2024_09_29_0835-ddcc8bbef391_increase_max_length_of_builtin_tool_provider.py index 71006679e1..b989d2747d 100644 --- a/api/migrations/versions/2024_09_29_0835-ddcc8bbef391_increase_max_length_of_builtin_tool_provider.py +++ b/api/migrations/versions/2024_09_29_0835-ddcc8bbef391_increase_max_length_of_builtin_tool_provider.py @@ -36,4 +36,4 @@ def downgrade(): type_=sa.VARCHAR(length=40), existing_nullable=False) - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at 
end of file diff --git a/api/migrations/versions/2024_10_09_1329-d8e744d88ed6_fix_wrong_service_api_history.py b/api/migrations/versions/2024_10_09_1329-d8e744d88ed6_fix_wrong_service_api_history.py index 38a5cdf8e5..a8158956bb 100644 --- a/api/migrations/versions/2024_10_09_1329-d8e744d88ed6_fix_wrong_service_api_history.py +++ b/api/migrations/versions/2024_10_09_1329-d8e744d88ed6_fix_wrong_service_api_history.py @@ -45,4 +45,4 @@ WHERE AND parent_message_id = '{UUID_NIL}' AND created_at >= '{v0_9_0_release_date}';""" op.execute(sql) - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/2024_10_10_0516-bbadea11becb_add_name_and_size_to_tool_files.py b/api/migrations/versions/2024_10_10_0516-bbadea11becb_add_name_and_size_to_tool_files.py index 00f2b15802..969cb4d06d 100644 --- a/api/migrations/versions/2024_10_10_0516-bbadea11becb_add_name_and_size_to_tool_files.py +++ b/api/migrations/versions/2024_10_10_0516-bbadea11becb_add_name_and_size_to_tool_files.py @@ -62,4 +62,4 @@ def downgrade(): with op.batch_alter_table("tool_files", schema=None) as batch_op: batch_op.drop_column("size") batch_op.drop_column("name") - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/2024_10_22_0959-43fa78bc3b7d_add_white_list.py b/api/migrations/versions/2024_10_22_0959-43fa78bc3b7d_add_white_list.py index 9daf148bc4..71c90534fd 100644 --- a/api/migrations/versions/2024_10_22_0959-43fa78bc3b7d_add_white_list.py +++ b/api/migrations/versions/2024_10_22_0959-43fa78bc3b7d_add_white_list.py @@ -39,4 +39,4 @@ def downgrade(): batch_op.drop_index('whitelists_tenant_idx') op.drop_table('whitelists') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/2024_10_28_0720-08ec4f75af5e_add_tenant_plugin_permisisons.py 
b/api/migrations/versions/2024_10_28_0720-08ec4f75af5e_add_tenant_plugin_permisisons.py index 51a0b1b211..1e514090c0 100644 --- a/api/migrations/versions/2024_10_28_0720-08ec4f75af5e_add_tenant_plugin_permisisons.py +++ b/api/migrations/versions/2024_10_28_0720-08ec4f75af5e_add_tenant_plugin_permisisons.py @@ -34,4 +34,4 @@ def upgrade(): def downgrade(): # ### commands auto generated by Alembic - please adjust! ### op.drop_table('account_plugin_permissions') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/2024_11_01_0434-d3f6769a94a3_add_upload_files_source_url.py b/api/migrations/versions/2024_11_01_0434-d3f6769a94a3_add_upload_files_source_url.py index a749c8bddf..5221c05ec0 100644 --- a/api/migrations/versions/2024_11_01_0434-d3f6769a94a3_add_upload_files_source_url.py +++ b/api/migrations/versions/2024_11_01_0434-d3f6769a94a3_add_upload_files_source_url.py @@ -28,4 +28,4 @@ def downgrade(): # ### commands auto generated by Alembic - please adjust! 
### with op.batch_alter_table('upload_files', schema=None) as batch_op: batch_op.drop_column('source_url') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/2024_11_01_0449-93ad8c19c40b_rename_conversation_variables_index_name.py b/api/migrations/versions/2024_11_01_0449-93ad8c19c40b_rename_conversation_variables_index_name.py index 81a7978f73..cbd2fb67bf 100644 --- a/api/migrations/versions/2024_11_01_0449-93ad8c19c40b_rename_conversation_variables_index_name.py +++ b/api/migrations/versions/2024_11_01_0449-93ad8c19c40b_rename_conversation_variables_index_name.py @@ -49,4 +49,4 @@ def downgrade(): batch_op.create_index('workflow__conversation_variables_created_at_idx', ['created_at'], unique=False) batch_op.create_index('workflow__conversation_variables_app_id_idx', ['app_id'], unique=False) - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/2024_11_01_0540-f4d7ce70a7ca_update_upload_files_source_url.py b/api/migrations/versions/2024_11_01_0540-f4d7ce70a7ca_update_upload_files_source_url.py index 222379a490..93d099d5e0 100644 --- a/api/migrations/versions/2024_11_01_0540-f4d7ce70a7ca_update_upload_files_source_url.py +++ b/api/migrations/versions/2024_11_01_0540-f4d7ce70a7ca_update_upload_files_source_url.py @@ -38,4 +38,4 @@ def downgrade(): existing_nullable=False, existing_server_default=sa.text("''::character varying")) - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/2024_11_01_0622-d07474999927_update_type_of_custom_disclaimer_to_text.py b/api/migrations/versions/2024_11_01_0622-d07474999927_update_type_of_custom_disclaimer_to_text.py index 9a4ccf352d..c33bd4796c 100644 --- a/api/migrations/versions/2024_11_01_0622-d07474999927_update_type_of_custom_disclaimer_to_text.py +++ 
b/api/migrations/versions/2024_11_01_0622-d07474999927_update_type_of_custom_disclaimer_to_text.py @@ -64,4 +64,4 @@ def downgrade(): type_=sa.VARCHAR(length=255), nullable=True) - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/2024_11_01_0623-09a8d1878d9b_update_workflows_graph_features_and_.py b/api/migrations/versions/2024_11_01_0623-09a8d1878d9b_update_workflows_graph_features_and_.py index 117a7351cd..da414a4937 100644 --- a/api/migrations/versions/2024_11_01_0623-09a8d1878d9b_update_workflows_graph_features_and_.py +++ b/api/migrations/versions/2024_11_01_0623-09a8d1878d9b_update_workflows_graph_features_and_.py @@ -70,4 +70,4 @@ def downgrade(): existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=True) - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/2024_11_12_0925-01d6889832f7_add_created_at_index_for_messages.py b/api/migrations/versions/2024_11_12_0925-01d6889832f7_add_created_at_index_for_messages.py index d94508edcf..c8ed64569a 100644 --- a/api/migrations/versions/2024_11_12_0925-01d6889832f7_add_created_at_index_for_messages.py +++ b/api/migrations/versions/2024_11_12_0925-01d6889832f7_add_created_at_index_for_messages.py @@ -28,4 +28,4 @@ def downgrade(): # ### commands auto generated by Alembic - please adjust! 
### with op.batch_alter_table('messages', schema=None) as batch_op: batch_op.drop_index('message_created_at_idx') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/2024_11_22_0701-e19037032219_parent_child_index.py b/api/migrations/versions/2024_11_22_0701-e19037032219_parent_child_index.py index 9238e5a0a8..0fc43447eb 100644 --- a/api/migrations/versions/2024_11_22_0701-e19037032219_parent_child_index.py +++ b/api/migrations/versions/2024_11_22_0701-e19037032219_parent_child_index.py @@ -52,4 +52,4 @@ def downgrade(): batch_op.drop_index('child_chunk_dataset_id_idx') op.drop_table('child_chunks') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/2024_11_28_0553-cf8f4fc45278_add_exceptions_count_field_to_.py b/api/migrations/versions/2024_11_28_0553-cf8f4fc45278_add_exceptions_count_field_to_.py index 8c576339ba..92dbe747dd 100644 --- a/api/migrations/versions/2024_11_28_0553-cf8f4fc45278_add_exceptions_count_field_to_.py +++ b/api/migrations/versions/2024_11_28_0553-cf8f4fc45278_add_exceptions_count_field_to_.py @@ -30,4 +30,4 @@ def downgrade(): with op.batch_alter_table('workflow_runs', schema=None) as batch_op: batch_op.drop_column('exceptions_count') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/2024_12_19_1746-11b07f66c737_remove_unused_tool_providers.py b/api/migrations/versions/2024_12_19_1746-11b07f66c737_remove_unused_tool_providers.py index 881a9e3c1e..c0ff4ce5d0 100644 --- a/api/migrations/versions/2024_12_19_1746-11b07f66c737_remove_unused_tool_providers.py +++ b/api/migrations/versions/2024_12_19_1746-11b07f66c737_remove_unused_tool_providers.py @@ -36,4 +36,4 @@ def downgrade(): sa.PrimaryKeyConstraint('id', name='tool_provider_pkey'), sa.UniqueConstraint('tenant_id', 'tool_name', 
name='unique_tool_provider_tool_name') ) - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/2024_12_20_0628-e1944c35e15e_add_retry_index_field_to_node_execution_.py b/api/migrations/versions/2024_12_20_0628-e1944c35e15e_add_retry_index_field_to_node_execution_.py index ae9f2de9b1..0facd0ecc0 100644 --- a/api/migrations/versions/2024_12_20_0628-e1944c35e15e_add_retry_index_field_to_node_execution_.py +++ b/api/migrations/versions/2024_12_20_0628-e1944c35e15e_add_retry_index_field_to_node_execution_.py @@ -35,4 +35,4 @@ def downgrade(): # batch_op.drop_column('retry_index') pass - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/2024_12_23_1154-d7999dfa4aae_remove_workflow_node_executions_retry_.py b/api/migrations/versions/2024_12_23_1154-d7999dfa4aae_remove_workflow_node_executions_retry_.py index adf6421e57..1f49d35dbe 100644 --- a/api/migrations/versions/2024_12_23_1154-d7999dfa4aae_remove_workflow_node_executions_retry_.py +++ b/api/migrations/versions/2024_12_23_1154-d7999dfa4aae_remove_workflow_node_executions_retry_.py @@ -41,4 +41,4 @@ def upgrade(): def downgrade(): # No downgrade needed as we don't want to restore the column - pass + pass \ No newline at end of file diff --git a/api/migrations/versions/2024_12_25_1137-923752d42eb6_add_auto_disabled_dataset_logs.py b/api/migrations/versions/2024_12_25_1137-923752d42eb6_add_auto_disabled_dataset_logs.py index 6dadd4e4a8..6b7ea53c03 100644 --- a/api/migrations/versions/2024_12_25_1137-923752d42eb6_add_auto_disabled_dataset_logs.py +++ b/api/migrations/versions/2024_12_25_1137-923752d42eb6_add_auto_disabled_dataset_logs.py @@ -44,4 +44,4 @@ def downgrade(): batch_op.drop_index('dataset_auto_disable_log_created_atx') op.drop_table('dataset_auto_disable_logs') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file 
diff --git a/api/migrations/versions/2025_01_01_2000-a91b476a53de_change_workflow_runs_total_tokens_to_.py b/api/migrations/versions/2025_01_01_2000-a91b476a53de_change_workflow_runs_total_tokens_to_.py index 798c895863..7ca9a49c4d 100644 --- a/api/migrations/versions/2025_01_01_2000-a91b476a53de_change_workflow_runs_total_tokens_to_.py +++ b/api/migrations/versions/2025_01_01_2000-a91b476a53de_change_workflow_runs_total_tokens_to_.py @@ -38,4 +38,4 @@ def downgrade(): existing_nullable=False, existing_server_default=sa.text('0')) - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/2025_01_14_0617-f051706725cc_add_rate_limit_logs.py b/api/migrations/versions/2025_01_14_0617-f051706725cc_add_rate_limit_logs.py index ef495be661..23f2fc2f22 100644 --- a/api/migrations/versions/2025_01_14_0617-f051706725cc_add_rate_limit_logs.py +++ b/api/migrations/versions/2025_01_14_0617-f051706725cc_add_rate_limit_logs.py @@ -40,4 +40,4 @@ def downgrade(): batch_op.drop_index('rate_limit_log_operation_idx') op.drop_table('rate_limit_logs') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/2025_02_27_0917-d20049ed0af6_add_metadata_function.py b/api/migrations/versions/2025_02_27_0917-d20049ed0af6_add_metadata_function.py index 877e3a5eed..70f2517864 100644 --- a/api/migrations/versions/2025_02_27_0917-d20049ed0af6_add_metadata_function.py +++ b/api/migrations/versions/2025_02_27_0917-d20049ed0af6_add_metadata_function.py @@ -87,4 +87,4 @@ def downgrade(): batch_op.drop_index('dataset_metadata_binding_dataset_idx') op.drop_table('dataset_metadata_bindings') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/2025_03_03_0304-4413929e1ec2_extend_provider_name_column.py 
b/api/migrations/versions/2025_03_03_0304-4413929e1ec2_extend_provider_name_column.py index 4a62624bb8..2484d65677 100644 --- a/api/migrations/versions/2025_03_03_0304-4413929e1ec2_extend_provider_name_column.py +++ b/api/migrations/versions/2025_03_03_0304-4413929e1ec2_extend_provider_name_column.py @@ -36,4 +36,4 @@ def downgrade(): type_=sa.VARCHAR(length=40), existing_nullable=False) - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/2025_03_03_1436-ee79d9b1c156_add_marked_name_and_marked_comment_in_.py b/api/migrations/versions/2025_03_03_1436-ee79d9b1c156_add_marked_name_and_marked_comment_in_.py index 5189de40e4..4d67b3637e 100644 --- a/api/migrations/versions/2025_03_03_1436-ee79d9b1c156_add_marked_name_and_marked_comment_in_.py +++ b/api/migrations/versions/2025_03_03_1436-ee79d9b1c156_add_marked_name_and_marked_comment_in_.py @@ -26,4 +26,4 @@ def upgrade(): def downgrade(): with op.batch_alter_table('workflows', schema=None) as batch_op: batch_op.drop_column('marked_comment') - batch_op.drop_column('marked_name') + batch_op.drop_column('marked_name') \ No newline at end of file diff --git a/api/migrations/versions/2025_03_07_0315-5511c782ee4c_extend_provider_column.py b/api/migrations/versions/2025_03_07_0315-5511c782ee4c_extend_provider_column.py index 0dc15ffd78..5a370a9808 100644 --- a/api/migrations/versions/2025_03_07_0315-5511c782ee4c_extend_provider_column.py +++ b/api/migrations/versions/2025_03_07_0315-5511c782ee4c_extend_provider_column.py @@ -61,4 +61,4 @@ def downgrade(): type_=sa.VARCHAR(length=40), existing_nullable=False) - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/2025_03_29_2227-6a9f914f656c_change_documentsegment_and_childchunk_.py b/api/migrations/versions/2025_03_29_2227-6a9f914f656c_change_documentsegment_and_childchunk_.py index 45904f0c80..78092e4489 100644 --- 
a/api/migrations/versions/2025_03_29_2227-6a9f914f656c_change_documentsegment_and_childchunk_.py +++ b/api/migrations/versions/2025_03_29_2227-6a9f914f656c_change_documentsegment_and_childchunk_.py @@ -40,4 +40,4 @@ def downgrade(): batch_op.drop_index('child_chunks_segment_idx') batch_op.drop_index('child_chunks_node_idx') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/2025_05_14_1403-d28f2004b072_add_index_for_workflow_conversation_.py b/api/migrations/versions/2025_05_14_1403-d28f2004b072_add_index_for_workflow_conversation_.py index 19f6c01655..0c6b4e1056 100644 --- a/api/migrations/versions/2025_05_14_1403-d28f2004b072_add_index_for_workflow_conversation_.py +++ b/api/migrations/versions/2025_05_14_1403-d28f2004b072_add_index_for_workflow_conversation_.py @@ -30,4 +30,4 @@ def downgrade(): with op.batch_alter_table('workflow_conversation_variables', schema=None) as batch_op: batch_op.drop_index(batch_op.f('workflow_conversation_variables_conversation_id_idx')) - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/2025_05_15_1531-2adcbe1f5dfb_add_workflowdraftvariable_model.py b/api/migrations/versions/2025_05_15_1531-2adcbe1f5dfb_add_workflowdraftvariable_model.py index 5bf394b21c..32981cec1f 100644 --- a/api/migrations/versions/2025_05_15_1531-2adcbe1f5dfb_add_workflowdraftvariable_model.py +++ b/api/migrations/versions/2025_05_15_1531-2adcbe1f5dfb_add_workflowdraftvariable_model.py @@ -48,4 +48,4 @@ def downgrade(): # Dropping `workflow_draft_variables` also drops any index associated with it. 
op.drop_table("workflow_draft_variables") - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/2025_06_06_1424-4474872b0ee6_workflow_draft_varaibles_add_node_execution_id.py b/api/migrations/versions/2025_06_06_1424-4474872b0ee6_workflow_draft_varaibles_add_node_execution_id.py index d7a5d116c9..9ac945e168 100644 --- a/api/migrations/versions/2025_06_06_1424-4474872b0ee6_workflow_draft_varaibles_add_node_execution_id.py +++ b/api/migrations/versions/2025_06_06_1424-4474872b0ee6_workflow_draft_varaibles_add_node_execution_id.py @@ -57,4 +57,4 @@ def downgrade(): with op.batch_alter_table('workflow_draft_variables', schema=None) as batch_op: batch_op.drop_column('node_execution_id') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/2025_06_19_1633-0ab65e1cc7fa_remove_sequence_number_from_workflow_.py b/api/migrations/versions/2025_06_19_1633-0ab65e1cc7fa_remove_sequence_number_from_workflow_.py index 29fef77798..92220f7bff 100644 --- a/api/migrations/versions/2025_06_19_1633-0ab65e1cc7fa_remove_sequence_number_from_workflow_.py +++ b/api/migrations/versions/2025_06_19_1633-0ab65e1cc7fa_remove_sequence_number_from_workflow_.py @@ -63,4 +63,4 @@ def downgrade(): batch_op.alter_column('sequence_number', nullable=False) batch_op.create_index(batch_op.f('workflow_run_tenant_app_sequence_idx'), ['tenant_id', 'app_id', 'sequence_number'], unique=False) - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/23db93619b9d_add_message_files_into_agent_thought.py b/api/migrations/versions/23db93619b9d_add_message_files_into_agent_thought.py index f3eef4681e..d8aa50a329 100644 --- a/api/migrations/versions/23db93619b9d_add_message_files_into_agent_thought.py +++ b/api/migrations/versions/23db93619b9d_add_message_files_into_agent_thought.py @@ -28,4 
+28,4 @@ def downgrade(): with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op: batch_op.drop_column('message_files') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/246ba09cbbdb_add_app_anntation_setting.py b/api/migrations/versions/246ba09cbbdb_add_app_anntation_setting.py index 9816e92dd1..da55a87c1f 100644 --- a/api/migrations/versions/246ba09cbbdb_add_app_anntation_setting.py +++ b/api/migrations/versions/246ba09cbbdb_add_app_anntation_setting.py @@ -47,4 +47,4 @@ def downgrade(): batch_op.drop_index('app_annotation_settings_app_idx') op.drop_table('app_annotation_settings') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/2a3aebbbf4bb_add_app_tracing.py b/api/migrations/versions/2a3aebbbf4bb_add_app_tracing.py index 99b7010612..7bd3748273 100644 --- a/api/migrations/versions/2a3aebbbf4bb_add_app_tracing.py +++ b/api/migrations/versions/2a3aebbbf4bb_add_app_tracing.py @@ -30,4 +30,4 @@ def downgrade(): with op.batch_alter_table('apps', schema=None) as batch_op: batch_op.drop_column('tracing') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/2beac44e5f5f_add_is_universal_in_apps.py b/api/migrations/versions/2beac44e5f5f_add_is_universal_in_apps.py index e933623d1c..c9ca9139d2 100644 --- a/api/migrations/versions/2beac44e5f5f_add_is_universal_in_apps.py +++ b/api/migrations/versions/2beac44e5f5f_add_is_universal_in_apps.py @@ -28,4 +28,4 @@ def downgrade(): with op.batch_alter_table('apps', schema=None) as batch_op: batch_op.drop_column('is_universal') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/2c8af9671032_add_qa_document_language.py b/api/migrations/versions/2c8af9671032_add_qa_document_language.py index 
1f0c145446..bb72ed3c5b 100644 --- a/api/migrations/versions/2c8af9671032_add_qa_document_language.py +++ b/api/migrations/versions/2c8af9671032_add_qa_document_language.py @@ -28,4 +28,4 @@ def downgrade(): with op.batch_alter_table('documents', schema=None) as batch_op: batch_op.drop_column('doc_language') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/2e9819ca5b28_add_tenant_id_in_api_token.py b/api/migrations/versions/2e9819ca5b28_add_tenant_id_in_api_token.py index b06a3530b8..3af70345e6 100644 --- a/api/migrations/versions/2e9819ca5b28_add_tenant_id_in_api_token.py +++ b/api/migrations/versions/2e9819ca5b28_add_tenant_id_in_api_token.py @@ -33,4 +33,4 @@ def downgrade(): batch_op.drop_index('api_token_tenant_idx') batch_op.drop_column('tenant_id') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/380c6aa5a70d_add_tool_labels_to_agent_thought.py b/api/migrations/versions/380c6aa5a70d_add_tool_labels_to_agent_thought.py index 6c13818463..da84253306 100644 --- a/api/migrations/versions/380c6aa5a70d_add_tool_labels_to_agent_thought.py +++ b/api/migrations/versions/380c6aa5a70d_add_tool_labels_to_agent_thought.py @@ -28,4 +28,4 @@ def downgrade(): with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op: batch_op.drop_column('tool_labels_str') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/3b18fea55204_add_tool_label_bings.py b/api/migrations/versions/3b18fea55204_add_tool_label_bings.py index bf54c247ea..ef74a06828 100644 --- a/api/migrations/versions/3b18fea55204_add_tool_label_bings.py +++ b/api/migrations/versions/3b18fea55204_add_tool_label_bings.py @@ -39,4 +39,4 @@ def downgrade(): batch_op.drop_column('privacy_policy') op.drop_table('tool_label_bindings') - # ### end Alembic commands ### + # 
### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/3c7cac9521c6_add_tags_and_binding_table.py b/api/migrations/versions/3c7cac9521c6_add_tags_and_binding_table.py index 5f11880683..88d215b354 100644 --- a/api/migrations/versions/3c7cac9521c6_add_tags_and_binding_table.py +++ b/api/migrations/versions/3c7cac9521c6_add_tags_and_binding_table.py @@ -59,4 +59,4 @@ def downgrade(): batch_op.drop_index('tag_bind_tag_id_idx') op.drop_table('tag_bindings') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/3ef9b2b6bee6_add_assistant_app.py b/api/migrations/versions/3ef9b2b6bee6_add_assistant_app.py index 4fbc570303..5c60600647 100644 --- a/api/migrations/versions/3ef9b2b6bee6_add_assistant_app.py +++ b/api/migrations/versions/3ef9b2b6bee6_add_assistant_app.py @@ -64,4 +64,4 @@ def downgrade(): op.drop_table('tool_published_apps') op.drop_table('tool_builtin_providers') op.drop_table('tool_api_providers') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/408176b91ad3_add_max_active_requests.py b/api/migrations/versions/408176b91ad3_add_max_active_requests.py index c19a68586f..bf9ee5d564 100644 --- a/api/migrations/versions/408176b91ad3_add_max_active_requests.py +++ b/api/migrations/versions/408176b91ad3_add_max_active_requests.py @@ -30,4 +30,4 @@ def downgrade(): with op.batch_alter_table('apps', schema=None) as batch_op: batch_op.drop_column('max_active_requests') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/42e85ed5564d_conversation_columns_set_nullable.py b/api/migrations/versions/42e85ed5564d_conversation_columns_set_nullable.py index f388b99b90..0fd5f442c1 100644 --- a/api/migrations/versions/42e85ed5564d_conversation_columns_set_nullable.py +++ 
b/api/migrations/versions/42e85ed5564d_conversation_columns_set_nullable.py @@ -45,4 +45,4 @@ def downgrade(): existing_type=postgresql.UUID(), nullable=False) - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/46976cc39132_add_annotation_histoiry_score.py b/api/migrations/versions/46976cc39132_add_annotation_histoiry_score.py index b47dd3c8ab..83a988ec2f 100644 --- a/api/migrations/versions/46976cc39132_add_annotation_histoiry_score.py +++ b/api/migrations/versions/46976cc39132_add_annotation_histoiry_score.py @@ -28,4 +28,4 @@ def downgrade(): with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op: batch_op.drop_column('score') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/47cc7df8c4f3_modify_default_model_name_length.py b/api/migrations/versions/47cc7df8c4f3_modify_default_model_name_length.py index b37928d3c0..a7e417758f 100644 --- a/api/migrations/versions/47cc7df8c4f3_modify_default_model_name_length.py +++ b/api/migrations/versions/47cc7df8c4f3_modify_default_model_name_length.py @@ -36,4 +36,4 @@ def downgrade(): type_=sa.VARCHAR(length=40), existing_nullable=False) - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/4823da1d26cf_add_tool_file.py b/api/migrations/versions/4823da1d26cf_add_tool_file.py index 1a473a10fe..9b2cacdc8e 100644 --- a/api/migrations/versions/4823da1d26cf_add_tool_file.py +++ b/api/migrations/versions/4823da1d26cf_add_tool_file.py @@ -34,4 +34,4 @@ def upgrade(): def downgrade(): # ### commands auto generated by Alembic - please adjust! 
### op.drop_table('tool_files') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/4829e54d2fee_change_message_chain_id_to_nullable.py b/api/migrations/versions/4829e54d2fee_change_message_chain_id_to_nullable.py index 2405021856..1025ba7a1b 100644 --- a/api/migrations/versions/4829e54d2fee_change_message_chain_id_to_nullable.py +++ b/api/migrations/versions/4829e54d2fee_change_message_chain_id_to_nullable.py @@ -32,4 +32,4 @@ def downgrade(): existing_type=postgresql.UUID(), nullable=False) - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/4bcffcd64aa4_update_dataset_model_field_null_.py b/api/migrations/versions/4bcffcd64aa4_update_dataset_model_field_null_.py index 178bd24e3c..49a3a72a43 100644 --- a/api/migrations/versions/4bcffcd64aa4_update_dataset_model_field_null_.py +++ b/api/migrations/versions/4bcffcd64aa4_update_dataset_model_field_null_.py @@ -42,4 +42,4 @@ def downgrade(): nullable=False, existing_server_default=sa.text("'text-embedding-ada-002'::character varying")) - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/4e99a8df00ff_add_load_balancing.py b/api/migrations/versions/4e99a8df00ff_add_load_balancing.py index 3be4ba4f2a..76a7b1e59f 100644 --- a/api/migrations/versions/4e99a8df00ff_add_load_balancing.py +++ b/api/migrations/versions/4e99a8df00ff_add_load_balancing.py @@ -123,4 +123,4 @@ def downgrade(): batch_op.drop_index('load_balancing_model_config_tenant_provider_model_idx') op.drop_table('load_balancing_model_configs') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/4ff534e1eb11_add_workflow_to_site.py b/api/migrations/versions/4ff534e1eb11_add_workflow_to_site.py index c09cf2af60..ab94f1d452 100644 --- 
a/api/migrations/versions/4ff534e1eb11_add_workflow_to_site.py +++ b/api/migrations/versions/4ff534e1eb11_add_workflow_to_site.py @@ -30,4 +30,4 @@ def downgrade(): with op.batch_alter_table('sites', schema=None) as batch_op: batch_op.drop_column('show_workflow_steps') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/5022897aaceb_add_model_name_in_embedding.py b/api/migrations/versions/5022897aaceb_add_model_name_in_embedding.py index c0f4af5a00..d56881bc8e 100644 --- a/api/migrations/versions/5022897aaceb_add_model_name_in_embedding.py +++ b/api/migrations/versions/5022897aaceb_add_model_name_in_embedding.py @@ -32,4 +32,4 @@ def downgrade(): batch_op.create_unique_constraint('embedding_hash_idx', ['hash']) batch_op.drop_column('model_name') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/53bf8af60645_update_model.py b/api/migrations/versions/53bf8af60645_update_model.py index 3d0928d013..27721108d2 100644 --- a/api/migrations/versions/53bf8af60645_update_model.py +++ b/api/migrations/versions/53bf8af60645_update_model.py @@ -38,4 +38,4 @@ def downgrade(): existing_nullable=False, existing_server_default=sa.text("''::character varying")) - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/563cf8bf777b_enable_tool_file_without_conversation_id.py b/api/migrations/versions/563cf8bf777b_enable_tool_file_without_conversation_id.py index 299f442de9..4a16c66d53 100644 --- a/api/migrations/versions/563cf8bf777b_enable_tool_file_without_conversation_id.py +++ b/api/migrations/versions/563cf8bf777b_enable_tool_file_without_conversation_id.py @@ -32,4 +32,4 @@ def downgrade(): existing_type=postgresql.UUID(), nullable=False) - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git 
a/api/migrations/versions/5fda94355fce_custom_disclaimer.py b/api/migrations/versions/5fda94355fce_custom_disclaimer.py index 73bcdc4500..aff2061e6e 100644 --- a/api/migrations/versions/5fda94355fce_custom_disclaimer.py +++ b/api/migrations/versions/5fda94355fce_custom_disclaimer.py @@ -42,4 +42,4 @@ def downgrade(): with op.batch_alter_table('recommended_apps', schema=None) as batch_op: batch_op.drop_column('custom_disclaimer') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/614f77cecc48_add_last_active_at.py b/api/migrations/versions/614f77cecc48_add_last_active_at.py index 182f8f89f1..8dff00c69e 100644 --- a/api/migrations/versions/614f77cecc48_add_last_active_at.py +++ b/api/migrations/versions/614f77cecc48_add_last_active_at.py @@ -28,4 +28,4 @@ def downgrade(): with op.batch_alter_table('accounts', schema=None) as batch_op: batch_op.drop_column('last_active_at') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/63f9175e515b_merge_branches.py b/api/migrations/versions/63f9175e515b_merge_branches.py index 0623659941..0eeeb78e7e 100644 --- a/api/migrations/versions/63f9175e515b_merge_branches.py +++ b/api/migrations/versions/63f9175e515b_merge_branches.py @@ -19,4 +19,4 @@ def upgrade(): def downgrade(): - pass + pass \ No newline at end of file diff --git a/api/migrations/versions/64a70a7aab8b_add_workflow_run_index.py b/api/migrations/versions/64a70a7aab8b_add_workflow_run_index.py index 73242908f4..f70d435c95 100644 --- a/api/migrations/versions/64a70a7aab8b_add_workflow_run_index.py +++ b/api/migrations/versions/64a70a7aab8b_add_workflow_run_index.py @@ -29,4 +29,4 @@ def downgrade(): with op.batch_alter_table('workflow_runs', schema=None) as batch_op: batch_op.drop_index('workflow_run_tenant_app_sequence_idx') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end 
of file diff --git a/api/migrations/versions/64b051264f32_init.py b/api/migrations/versions/64b051264f32_init.py index b0fb3deac6..bfc98f4459 100644 --- a/api/migrations/versions/64b051264f32_init.py +++ b/api/migrations/versions/64b051264f32_init.py @@ -794,4 +794,4 @@ def downgrade(): op.drop_table('account_integrates') op.execute('DROP EXTENSION IF EXISTS "uuid-ossp";') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/675b5321501b_add_node_execution_id_into_node_.py b/api/migrations/versions/675b5321501b_add_node_execution_id_into_node_.py index 55824945da..f9459112f1 100644 --- a/api/migrations/versions/675b5321501b_add_node_execution_id_into_node_.py +++ b/api/migrations/versions/675b5321501b_add_node_execution_id_into_node_.py @@ -32,4 +32,4 @@ def downgrade(): batch_op.drop_index('workflow_node_execution_id_idx') batch_op.drop_column('node_execution_id') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/6dcb43972bdc_add_dataset_retriever_resource.py b/api/migrations/versions/6dcb43972bdc_add_dataset_retriever_resource.py index da27dd4426..bcaeecc59a 100644 --- a/api/migrations/versions/6dcb43972bdc_add_dataset_retriever_resource.py +++ b/api/migrations/versions/6dcb43972bdc_add_dataset_retriever_resource.py @@ -51,4 +51,4 @@ def downgrade(): batch_op.drop_index('dataset_retriever_resource_message_id_idx') op.drop_table('dataset_retriever_resources') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/6e2cfb077b04_add_dataset_collection_binding.py b/api/migrations/versions/6e2cfb077b04_add_dataset_collection_binding.py index 4fa322f693..d8f5283a73 100644 --- a/api/migrations/versions/6e2cfb077b04_add_dataset_collection_binding.py +++ b/api/migrations/versions/6e2cfb077b04_add_dataset_collection_binding.py @@ -44,4 +44,4 
@@ def downgrade(): batch_op.drop_index('provider_model_name_idx') op.drop_table('dataset_collection_bindings') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/6e957a32015b_add_embedding_cache_created_at_index.py b/api/migrations/versions/6e957a32015b_add_embedding_cache_created_at_index.py index 7445f664cd..e4517ab10c 100644 --- a/api/migrations/versions/6e957a32015b_add_embedding_cache_created_at_index.py +++ b/api/migrations/versions/6e957a32015b_add_embedding_cache_created_at_index.py @@ -29,4 +29,4 @@ def downgrade(): with op.batch_alter_table('embeddings', schema=None) as batch_op: batch_op.drop_index('created_at_idx') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/714aafe25d39_add_anntation_history_match_response.py b/api/migrations/versions/714aafe25d39_add_anntation_history_match_response.py index 498b46e3c4..5730582ac5 100644 --- a/api/migrations/versions/714aafe25d39_add_anntation_history_match_response.py +++ b/api/migrations/versions/714aafe25d39_add_anntation_history_match_response.py @@ -30,4 +30,4 @@ def downgrade(): batch_op.drop_column('annotation_content') batch_op.drop_column('annotation_question') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/77e83833755c_add_app_config_retriever_resource.py b/api/migrations/versions/77e83833755c_add_app_config_retriever_resource.py index c5d8c3d88d..24d45c728d 100644 --- a/api/migrations/versions/77e83833755c_add_app_config_retriever_resource.py +++ b/api/migrations/versions/77e83833755c_add_app_config_retriever_resource.py @@ -28,4 +28,4 @@ def downgrade(): with op.batch_alter_table('app_model_configs', schema=None) as batch_op: batch_op.drop_column('retriever_resource') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of 
file diff --git a/api/migrations/versions/7b45942e39bb_add_api_key_auth_binding.py b/api/migrations/versions/7b45942e39bb_add_api_key_auth_binding.py index 2ba0e13caa..633095af8f 100644 --- a/api/migrations/versions/7b45942e39bb_add_api_key_auth_binding.py +++ b/api/migrations/versions/7b45942e39bb_add_api_key_auth_binding.py @@ -64,4 +64,4 @@ def downgrade(): batch_op.drop_index('data_source_api_key_auth_binding_provider_idx') op.drop_table('data_source_api_key_auth_bindings') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/7bdef072e63a_add_workflow_tool.py b/api/migrations/versions/7bdef072e63a_add_workflow_tool.py index f09a682f28..840a95b9f4 100644 --- a/api/migrations/versions/7bdef072e63a_add_workflow_tool.py +++ b/api/migrations/versions/7bdef072e63a_add_workflow_tool.py @@ -39,4 +39,4 @@ def upgrade(): def downgrade(): op.drop_table('tool_workflow_providers') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/7ce5a52e4eee_add_tool_providers.py b/api/migrations/versions/7ce5a52e4eee_add_tool_providers.py index 881ffec61d..f8dcb406bc 100644 --- a/api/migrations/versions/7ce5a52e4eee_add_tool_providers.py +++ b/api/migrations/versions/7ce5a52e4eee_add_tool_providers.py @@ -41,4 +41,4 @@ def downgrade(): batch_op.drop_column('sensitive_word_avoidance') op.drop_table('tool_providers') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/7e6a8693e07a_add_table_dataset_permissions.py b/api/migrations/versions/7e6a8693e07a_add_table_dataset_permissions.py index 865572f3a7..9a206ceb26 100644 --- a/api/migrations/versions/7e6a8693e07a_add_table_dataset_permissions.py +++ b/api/migrations/versions/7e6a8693e07a_add_table_dataset_permissions.py @@ -39,4 +39,4 @@ def downgrade(): 
batch_op.drop_index('idx_dataset_permissions_dataset_id') batch_op.drop_index('idx_dataset_permissions_account_id') op.drop_table('dataset_permissions') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/853f9b9cd3b6_add_message_price_unit.py b/api/migrations/versions/853f9b9cd3b6_add_message_price_unit.py index 5a8476501b..7bb4e9287a 100644 --- a/api/migrations/versions/853f9b9cd3b6_add_message_price_unit.py +++ b/api/migrations/versions/853f9b9cd3b6_add_message_price_unit.py @@ -39,4 +39,4 @@ def downgrade(): batch_op.drop_column('answer_price_unit') batch_op.drop_column('message_price_unit') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/88072f0caa04_add_custom_config_in_tenant.py b/api/migrations/versions/88072f0caa04_add_custom_config_in_tenant.py index f7625bff8c..dc7994fb11 100644 --- a/api/migrations/versions/88072f0caa04_add_custom_config_in_tenant.py +++ b/api/migrations/versions/88072f0caa04_add_custom_config_in_tenant.py @@ -28,4 +28,4 @@ def downgrade(): with op.batch_alter_table('tenants', schema=None) as batch_op: batch_op.drop_column('custom_config') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/89c7899ca936_.py b/api/migrations/versions/89c7899ca936_.py index 0fad39fa57..c44944f4ba 100644 --- a/api/migrations/versions/89c7899ca936_.py +++ b/api/migrations/versions/89c7899ca936_.py @@ -34,4 +34,4 @@ def downgrade(): type_=sa.VARCHAR(length=255), existing_nullable=True) - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/8ae9bc661daa_add_tool_conversation_variables_idx.py b/api/migrations/versions/8ae9bc661daa_add_tool_conversation_variables_idx.py index f4c4ebb51b..c9e60e414d 100644 --- 
a/api/migrations/versions/8ae9bc661daa_add_tool_conversation_variables_idx.py +++ b/api/migrations/versions/8ae9bc661daa_add_tool_conversation_variables_idx.py @@ -29,4 +29,4 @@ def downgrade(): batch_op.drop_index('user_id_idx') batch_op.drop_index('conversation_id_idx') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/8d2d099ceb74_add_qa_model_support.py b/api/migrations/versions/8d2d099ceb74_add_qa_model_support.py index 849103b071..d745edeb8e 100644 --- a/api/migrations/versions/8d2d099ceb74_add_qa_model_support.py +++ b/api/migrations/versions/8d2d099ceb74_add_qa_model_support.py @@ -39,4 +39,4 @@ def downgrade(): batch_op.drop_column('updated_by') batch_op.drop_column('answer') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/8e5588e6412e_add_environment_variable_to_workflow_.py b/api/migrations/versions/8e5588e6412e_add_environment_variable_to_workflow_.py index ec2336da4d..1ccac858ee 100644 --- a/api/migrations/versions/8e5588e6412e_add_environment_variable_to_workflow_.py +++ b/api/migrations/versions/8e5588e6412e_add_environment_variable_to_workflow_.py @@ -30,4 +30,4 @@ def downgrade(): with op.batch_alter_table('workflows', schema=None) as batch_op: batch_op.drop_column('environment_variables') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/8ec536f3c800_rename_api_provider_credentails.py b/api/migrations/versions/8ec536f3c800_rename_api_provider_credentails.py index 6cafc198aa..ee38be4376 100644 --- a/api/migrations/versions/8ec536f3c800_rename_api_provider_credentails.py +++ b/api/migrations/versions/8ec536f3c800_rename_api_provider_credentails.py @@ -28,4 +28,4 @@ def downgrade(): with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: batch_op.drop_column('credentials_str') - # ### end 
Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/8fe468ba0ca5_add_gpt4v_supports.py b/api/migrations/versions/8fe468ba0ca5_add_gpt4v_supports.py index 01d5631510..4488ccc992 100644 --- a/api/migrations/versions/8fe468ba0ca5_add_gpt4v_supports.py +++ b/api/migrations/versions/8fe468ba0ca5_add_gpt4v_supports.py @@ -56,4 +56,4 @@ def downgrade(): batch_op.drop_index('message_file_created_by_idx') op.drop_table('message_files') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/968fff4c0ab9_add_api_based_extension.py b/api/migrations/versions/968fff4c0ab9_add_api_based_extension.py index 207a9c841f..452db69fea 100644 --- a/api/migrations/versions/968fff4c0ab9_add_api_based_extension.py +++ b/api/migrations/versions/968fff4c0ab9_add_api_based_extension.py @@ -42,4 +42,4 @@ def downgrade(): op.drop_table('api_based_extensions') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/9e98fbaffb88_add_workflow_tool_version.py b/api/migrations/versions/9e98fbaffb88_add_workflow_tool_version.py index 92f41f0abd..dfd5baeb89 100644 --- a/api/migrations/versions/9e98fbaffb88_add_workflow_tool_version.py +++ b/api/migrations/versions/9e98fbaffb88_add_workflow_tool_version.py @@ -24,4 +24,4 @@ def upgrade(): def downgrade(): with op.batch_alter_table('tool_workflow_providers', schema=None) as batch_op: - batch_op.drop_column('version') + batch_op.drop_column('version') \ No newline at end of file diff --git a/api/migrations/versions/9f4e3427ea84_add_created_by_role.py b/api/migrations/versions/9f4e3427ea84_add_created_by_role.py index c7a98b4ac6..64ed6630f9 100644 --- a/api/migrations/versions/9f4e3427ea84_add_created_by_role.py +++ b/api/migrations/versions/9f4e3427ea84_add_created_by_role.py @@ -42,4 +42,4 @@ def downgrade(): 
batch_op.create_index('pinned_conversation_conversation_idx', ['app_id', 'conversation_id', 'created_by'], unique=False) batch_op.drop_column('created_by_role') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/9fafbd60eca1_add_message_file_belongs_to.py b/api/migrations/versions/9fafbd60eca1_add_message_file_belongs_to.py index 968906bdd7..df74772002 100644 --- a/api/migrations/versions/9fafbd60eca1_add_message_file_belongs_to.py +++ b/api/migrations/versions/9fafbd60eca1_add_message_file_belongs_to.py @@ -28,4 +28,4 @@ def downgrade(): with op.batch_alter_table('message_files', schema=None) as batch_op: batch_op.drop_column('belongs_to') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/a45f4dfde53b_add_language_to_recommend_apps.py b/api/migrations/versions/a45f4dfde53b_add_language_to_recommend_apps.py index 3014978110..a47dcb0504 100644 --- a/api/migrations/versions/a45f4dfde53b_add_language_to_recommend_apps.py +++ b/api/migrations/versions/a45f4dfde53b_add_language_to_recommend_apps.py @@ -32,4 +32,4 @@ def downgrade(): batch_op.create_index('recommended_app_is_listed_idx', ['is_listed'], unique=False) batch_op.drop_column('language') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/a5b56fb053ef_app_config_add_speech_to_text.py b/api/migrations/versions/a5b56fb053ef_app_config_add_speech_to_text.py index acb6812434..edb46593ca 100644 --- a/api/migrations/versions/a5b56fb053ef_app_config_add_speech_to_text.py +++ b/api/migrations/versions/a5b56fb053ef_app_config_add_speech_to_text.py @@ -28,4 +28,4 @@ def downgrade(): with op.batch_alter_table('app_model_configs', schema=None) as batch_op: batch_op.drop_column('speech_to_text') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file 
diff --git a/api/migrations/versions/a8d7385a7b66_add_embeddings_provider_name.py b/api/migrations/versions/a8d7385a7b66_add_embeddings_provider_name.py index 1ee01381d8..ea55074ae0 100644 --- a/api/migrations/versions/a8d7385a7b66_add_embeddings_provider_name.py +++ b/api/migrations/versions/a8d7385a7b66_add_embeddings_provider_name.py @@ -31,4 +31,4 @@ def downgrade(): batch_op.drop_constraint('embedding_hash_idx', type_='unique') batch_op.create_unique_constraint('embedding_hash_idx', ['model_name', 'hash']) batch_op.drop_column('provider_name') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/a8f9b3c45e4a_add_tenant_id_db_index.py b/api/migrations/versions/a8f9b3c45e4a_add_tenant_id_db_index.py index 62d6faeb1d..ec8dbb9933 100644 --- a/api/migrations/versions/a8f9b3c45e4a_add_tenant_id_db_index.py +++ b/api/migrations/versions/a8f9b3c45e4a_add_tenant_id_db_index.py @@ -33,4 +33,4 @@ def downgrade(): with op.batch_alter_table('document_segments', schema=None) as batch_op: batch_op.drop_index('document_segment_tenant_idx') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/a9836e3baeee_add_external_data_tools_in_app_model_.py b/api/migrations/versions/a9836e3baeee_add_external_data_tools_in_app_model_.py index 5dcb630aed..32407f37ea 100644 --- a/api/migrations/versions/a9836e3baeee_add_external_data_tools_in_app_model_.py +++ b/api/migrations/versions/a9836e3baeee_add_external_data_tools_in_app_model_.py @@ -28,4 +28,4 @@ def downgrade(): with op.batch_alter_table('app_model_configs', schema=None) as batch_op: batch_op.drop_column('external_data_tools') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/ab23c11305d4_add_dataset_query_variable_at_app_model_.py 
b/api/migrations/versions/ab23c11305d4_add_dataset_query_variable_at_app_model_.py index eee41bf4e0..a90172d575 100644 --- a/api/migrations/versions/ab23c11305d4_add_dataset_query_variable_at_app_model_.py +++ b/api/migrations/versions/ab23c11305d4_add_dataset_query_variable_at_app_model_.py @@ -28,4 +28,4 @@ def downgrade(): with op.batch_alter_table('app_model_configs', schema=None) as batch_op: batch_op.drop_column('dataset_query_variable') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/ad472b61a054_add_api_provider_icon.py b/api/migrations/versions/ad472b61a054_add_api_provider_icon.py index 0ddaf1eb0a..6a953c69ce 100644 --- a/api/migrations/versions/ad472b61a054_add_api_provider_icon.py +++ b/api/migrations/versions/ad472b61a054_add_api_provider_icon.py @@ -28,4 +28,4 @@ def downgrade(): with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: batch_op.drop_column('icon') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/b24be59fbb04_.py b/api/migrations/versions/b24be59fbb04_.py index 29ba859f2b..3f4fdbe26c 100644 --- a/api/migrations/versions/b24be59fbb04_.py +++ b/api/migrations/versions/b24be59fbb04_.py @@ -28,4 +28,4 @@ def downgrade(): with op.batch_alter_table('app_model_configs', schema=None) as batch_op: batch_op.drop_column('text_to_speech') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/b2602e131636_add_workflow_run_id_index_for_message.py b/api/migrations/versions/b2602e131636_add_workflow_run_id_index_for_message.py index c9a6a5a5a7..2592e918aa 100644 --- a/api/migrations/versions/b2602e131636_add_workflow_run_id_index_for_message.py +++ b/api/migrations/versions/b2602e131636_add_workflow_run_id_index_for_message.py @@ -29,4 +29,4 @@ def downgrade(): with 
op.batch_alter_table('messages', schema=None) as batch_op: batch_op.drop_index('message_workflow_run_id_idx') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/b289e2408ee2_add_workflow.py b/api/migrations/versions/b289e2408ee2_add_workflow.py index 966f86c05f..be074b4af3 100644 --- a/api/migrations/versions/b289e2408ee2_add_workflow.py +++ b/api/migrations/versions/b289e2408ee2_add_workflow.py @@ -139,4 +139,4 @@ def downgrade(): batch_op.drop_index('workflow_app_log_app_idx') op.drop_table('workflow_app_logs') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/b3a09c049e8e_add_advanced_prompt_templates.py b/api/migrations/versions/b3a09c049e8e_add_advanced_prompt_templates.py index 5682eff030..87994adeea 100644 --- a/api/migrations/versions/b3a09c049e8e_add_advanced_prompt_templates.py +++ b/api/migrations/versions/b3a09c049e8e_add_advanced_prompt_templates.py @@ -34,4 +34,4 @@ def downgrade(): batch_op.drop_column('chat_prompt_config') batch_op.drop_column('prompt_type') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/b5429b71023c_messages_columns_set_nullable.py b/api/migrations/versions/b5429b71023c_messages_columns_set_nullable.py index ee81fdab28..a84c807929 100644 --- a/api/migrations/versions/b5429b71023c_messages_columns_set_nullable.py +++ b/api/migrations/versions/b5429b71023c_messages_columns_set_nullable.py @@ -38,4 +38,4 @@ def downgrade(): existing_type=sa.VARCHAR(length=255), nullable=False) - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/b69ca54b9208_add_chatbot_color_theme.py b/api/migrations/versions/b69ca54b9208_add_chatbot_color_theme.py index dd5a7495e4..df6569fb3c 100644 --- 
a/api/migrations/versions/b69ca54b9208_add_chatbot_color_theme.py +++ b/api/migrations/versions/b69ca54b9208_add_chatbot_color_theme.py @@ -32,4 +32,4 @@ def downgrade(): batch_op.drop_column('chat_color_theme_inverted') batch_op.drop_column('chat_color_theme') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/bf0aec5ba2cf_add_provider_order.py b/api/migrations/versions/bf0aec5ba2cf_add_provider_order.py index dfa1517462..2ff0460bda 100644 --- a/api/migrations/versions/bf0aec5ba2cf_add_provider_order.py +++ b/api/migrations/versions/bf0aec5ba2cf_add_provider_order.py @@ -49,4 +49,4 @@ def downgrade(): batch_op.drop_index('provider_order_tenant_provider_idx') op.drop_table('provider_orders') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/c031d46af369_remove_app_model_config_trace_config_.py b/api/migrations/versions/c031d46af369_remove_app_model_config_trace_config_.py index f87819c367..5daa8b2a4a 100644 --- a/api/migrations/versions/c031d46af369_remove_app_model_config_trace_config_.py +++ b/api/migrations/versions/c031d46af369_remove_app_model_config_trace_config_.py @@ -41,4 +41,4 @@ def downgrade(): # ### commands auto generated by Alembic - please adjust! 
### op.drop_table('trace_app_config') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/c3311b089690_add_tool_meta.py b/api/migrations/versions/c3311b089690_add_tool_meta.py index e075535b0d..e9f6d8a72b 100644 --- a/api/migrations/versions/c3311b089690_add_tool_meta.py +++ b/api/migrations/versions/c3311b089690_add_tool_meta.py @@ -28,4 +28,4 @@ def downgrade(): with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op: batch_op.drop_column('tool_meta_str') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/c71211c8f604_add_tool_invoke_model_log.py b/api/migrations/versions/c71211c8f604_add_tool_invoke_model_log.py index 95fb8f5d0e..1c2f1ca67d 100644 --- a/api/migrations/versions/c71211c8f604_add_tool_invoke_model_log.py +++ b/api/migrations/versions/c71211c8f604_add_tool_invoke_model_log.py @@ -46,4 +46,4 @@ def upgrade(): def downgrade(): # ### commands auto generated by Alembic - please adjust! 
### op.drop_table('tool_model_invokes') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/cc04d0998d4d_set_model_config_column_nullable.py b/api/migrations/versions/cc04d0998d4d_set_model_config_column_nullable.py index aefbe43f14..0ea03cbaae 100644 --- a/api/migrations/versions/cc04d0998d4d_set_model_config_column_nullable.py +++ b/api/migrations/versions/cc04d0998d4d_set_model_config_column_nullable.py @@ -67,4 +67,4 @@ def downgrade(): existing_type=sa.VARCHAR(length=255), nullable=False) - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/d3d503a3471c_add_is_deleted_to_conversations.py b/api/migrations/versions/d3d503a3471c_add_is_deleted_to_conversations.py index 89355e57ad..9e33c37e9f 100644 --- a/api/migrations/versions/d3d503a3471c_add_is_deleted_to_conversations.py +++ b/api/migrations/versions/d3d503a3471c_add_is_deleted_to_conversations.py @@ -28,4 +28,4 @@ def downgrade(): with op.batch_alter_table('conversations', schema=None) as batch_op: batch_op.drop_column('is_deleted') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/de95f5c77138_migration_serpapi_api_key.py b/api/migrations/versions/de95f5c77138_migration_serpapi_api_key.py index c18126286c..e8ef1a7e9a 100644 --- a/api/migrations/versions/de95f5c77138_migration_serpapi_api_key.py +++ b/api/migrations/versions/de95f5c77138_migration_serpapi_api_key.py @@ -111,4 +111,4 @@ def upgrade(): def downgrade(): # ### commands auto generated by Alembic - please adjust! 
### pass - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/dfb3b7f477da_add_tool_index.py b/api/migrations/versions/dfb3b7f477da_add_tool_index.py index e14a65a1ff..930fc6d477 100644 --- a/api/migrations/versions/dfb3b7f477da_add_tool_index.py +++ b/api/migrations/versions/dfb3b7f477da_add_tool_index.py @@ -33,4 +33,4 @@ def downgrade(): with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: batch_op.drop_constraint('unique_api_tool_provider', type_='unique') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/e1901f623fd0_add_annotation_reply.py b/api/migrations/versions/e1901f623fd0_add_annotation_reply.py index 32902c8eb0..dcaab4c70b 100644 --- a/api/migrations/versions/e1901f623fd0_add_annotation_reply.py +++ b/api/migrations/versions/e1901f623fd0_add_annotation_reply.py @@ -76,4 +76,4 @@ def downgrade(): batch_op.drop_index('app_annotation_hit_histories_account_idx') op.drop_table('app_annotation_hit_histories') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/e2eacc9a1b63_add_status_for_message.py b/api/migrations/versions/e2eacc9a1b63_add_status_for_message.py index 08f994a41f..685a2063ea 100644 --- a/api/migrations/versions/e2eacc9a1b63_add_status_for_message.py +++ b/api/migrations/versions/e2eacc9a1b63_add_status_for_message.py @@ -40,4 +40,4 @@ def downgrade(): with op.batch_alter_table('conversations', schema=None) as batch_op: batch_op.drop_column('invoke_from') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/e32f6ccb87c6_e08af0a69ccefbb59fa80c778efee300bb780980.py b/api/migrations/versions/e32f6ccb87c6_e08af0a69ccefbb59fa80c778efee300bb780980.py index 3d7dd1fabf..bbc531b1da 100644 --- 
a/api/migrations/versions/e32f6ccb87c6_e08af0a69ccefbb59fa80c778efee300bb780980.py +++ b/api/migrations/versions/e32f6ccb87c6_e08af0a69ccefbb59fa80c778efee300bb780980.py @@ -43,4 +43,4 @@ def downgrade(): batch_op.drop_index('source_binding_tenant_id_idx') op.drop_table('data_source_bindings') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/e35ed59becda_modify_quota_limit_field_type.py b/api/migrations/versions/e35ed59becda_modify_quota_limit_field_type.py index 627366b36d..42bc2fa40a 100644 --- a/api/migrations/versions/e35ed59becda_modify_quota_limit_field_type.py +++ b/api/migrations/versions/e35ed59becda_modify_quota_limit_field_type.py @@ -42,4 +42,4 @@ def downgrade(): type_=sa.INTEGER(), existing_nullable=True) - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/e8883b0148c9_add_dataset_model_name.py b/api/migrations/versions/e8883b0148c9_add_dataset_model_name.py index 875683d68e..56dd58e356 100644 --- a/api/migrations/versions/e8883b0148c9_add_dataset_model_name.py +++ b/api/migrations/versions/e8883b0148c9_add_dataset_model_name.py @@ -30,4 +30,4 @@ def downgrade(): batch_op.drop_column('embedding_model_provider') batch_op.drop_column('embedding_model') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/eeb2e349e6ac_increase_max_model_name_length.py b/api/migrations/versions/eeb2e349e6ac_increase_max_model_name_length.py index 434531b6c8..b9522d01a7 100644 --- a/api/migrations/versions/eeb2e349e6ac_increase_max_model_name_length.py +++ b/api/migrations/versions/eeb2e349e6ac_increase_max_model_name_length.py @@ -50,4 +50,4 @@ def downgrade(): type_=sa.VARCHAR(length=40), existing_nullable=False) - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git 
a/api/migrations/versions/f25003750af4_add_created_updated_at.py b/api/migrations/versions/f25003750af4_add_created_updated_at.py index 178eaf2380..313a1f0d1d 100644 --- a/api/migrations/versions/f25003750af4_add_created_updated_at.py +++ b/api/migrations/versions/f25003750af4_add_created_updated_at.py @@ -30,4 +30,4 @@ def downgrade(): batch_op.drop_column('updated_at') batch_op.drop_column('created_at') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/f2a6fc85e260_add_anntation_history_message_id.py b/api/migrations/versions/f2a6fc85e260_add_anntation_history_message_id.py index dc9392a92c..0d84c12839 100644 --- a/api/migrations/versions/f2a6fc85e260_add_anntation_history_message_id.py +++ b/api/migrations/versions/f2a6fc85e260_add_anntation_history_message_id.py @@ -31,4 +31,4 @@ def downgrade(): batch_op.drop_index('app_annotation_hit_histories_message_idx') batch_op.drop_column('message_id') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/f9107f83abab_add_desc_for_apps.py b/api/migrations/versions/f9107f83abab_add_desc_for_apps.py index 3e5ae0d67d..dc60e7c991 100644 --- a/api/migrations/versions/f9107f83abab_add_desc_for_apps.py +++ b/api/migrations/versions/f9107f83abab_add_desc_for_apps.py @@ -28,4 +28,4 @@ def downgrade(): with op.batch_alter_table('apps', schema=None) as batch_op: batch_op.drop_column('description') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/fca025d3b60f_add_dataset_retrival_model.py b/api/migrations/versions/fca025d3b60f_add_dataset_retrival_model.py index 52495be60a..6500a79ff8 100644 --- a/api/migrations/versions/fca025d3b60f_add_dataset_retrival_model.py +++ b/api/migrations/versions/fca025d3b60f_add_dataset_retrival_model.py @@ -40,4 +40,4 @@ def downgrade(): 
sa.PrimaryKeyConstraint('id', name='sessions_pkey'), sa.UniqueConstraint('session_id', name='sessions_session_id_key') ) - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/fecff1c3da27_remove_extra_tracing_app_config_table .py b/api/migrations/versions/fecff1c3da27_remove_extra_tracing_app_config_table .py index 6f76a361d9..5057088ac4 100644 --- a/api/migrations/versions/fecff1c3da27_remove_extra_tracing_app_config_table .py +++ b/api/migrations/versions/fecff1c3da27_remove_extra_tracing_app_config_table .py @@ -47,4 +47,4 @@ def downgrade(): with op.batch_alter_table('dataset_permissions', schema=None) as batch_op: batch_op.drop_index('idx_dataset_permissions_tenant_id') - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/models/_workflow_exc.py b/api/models/_workflow_exc.py index f6271bda47..2349874b3f 100644 --- a/api/models/_workflow_exc.py +++ b/api/models/_workflow_exc.py @@ -3,7 +3,6 @@ class WorkflowDataError(Exception): """Base class for all workflow data related exceptions. - This should be used to indicate issues with workflow data integrity, such as no `graph` configuration, missing `nodes` field in `graph` configuration, or similar issues. 
diff --git a/api/models/account.py b/api/models/account.py index 7ffeefa980..eca0024445 100644 --- a/api/models/account.py +++ b/api/models/account.py @@ -83,7 +83,6 @@ class AccountStatus(enum.StrEnum): class Account(UserMixin, Base): __tablename__ = "accounts" __table_args__ = (db.PrimaryKeyConstraint("id", name="account_pkey"), db.Index("account_email_idx", "email")) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) name = db.Column(db.String(255), nullable=False) email = db.Column(db.String(255), nullable=False) @@ -138,10 +137,8 @@ class Account(UserMixin, Base): .one_or_none() ), ) - if not tenant_account_join: return - tenant, join = tenant_account_join self.role = join.role self._current_tenant = tenant @@ -195,7 +192,6 @@ class TenantStatus(enum.StrEnum): class Tenant(Base): __tablename__ = "tenants" __table_args__ = (db.PrimaryKeyConstraint("id", name="tenant_pkey"),) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) name = db.Column(db.String(255), nullable=False) encrypt_public_key = db.Column(db.Text) @@ -229,7 +225,6 @@ class TenantAccountJoin(Base): db.Index("tenant_account_join_tenant_id_idx", "tenant_id"), db.UniqueConstraint("tenant_id", "account_id", name="unique_tenant_account_join"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) account_id = db.Column(StringUUID, nullable=False) @@ -247,7 +242,6 @@ class AccountIntegrate(Base): db.UniqueConstraint("account_id", "provider", name="unique_account_provider"), db.UniqueConstraint("provider", "open_id", name="unique_provider_open_id"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) account_id = db.Column(StringUUID, nullable=False) provider = db.Column(db.String(16), nullable=False) @@ -264,7 +258,6 @@ class InvitationCode(Base): db.Index("invitation_codes_batch_idx", "batch"), db.Index("invitation_codes_code_idx", "code", 
"status"), ) - id = db.Column(db.Integer, nullable=False) batch = db.Column(db.String(255), nullable=False) code = db.Column(db.String(32), nullable=False) @@ -292,7 +285,6 @@ class TenantPluginPermission(Base): db.PrimaryKeyConstraint("id", name="account_plugin_permission_pkey"), db.UniqueConstraint("tenant_id", name="unique_tenant_plugin"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) install_permission: Mapped[InstallPermission] = mapped_column( diff --git a/api/models/api_based_extension.py b/api/models/api_based_extension.py index 5a70e18622..94fc3c6a06 100644 --- a/api/models/api_based_extension.py +++ b/api/models/api_based_extension.py @@ -20,7 +20,6 @@ class APIBasedExtension(Base): db.PrimaryKeyConstraint("id", name="api_based_extension_pkey"), db.Index("api_based_extension_tenant_idx", "tenant_id"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) name = db.Column(db.String(255), nullable=False) diff --git a/api/models/dataset.py b/api/models/dataset.py index 1ec27203a0..97480d7fdf 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -41,10 +41,8 @@ class Dataset(Base): db.Index("dataset_tenant_idx", "tenant_id"), db.Index("retrieval_model_idx", "retrieval_model", postgresql_using="gin"), ) - INDEXING_TECHNIQUE_LIST = ["high_quality", "economy", None] PROVIDER_LIST = ["vendor", "external", None] - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) name = db.Column(db.String(255), nullable=False) @@ -71,7 +69,6 @@ class Dataset(Base): ) if dataset_keyword_table: return dataset_keyword_table - return None @property @@ -176,7 +173,6 @@ class Dataset(Base): ) .all() ) - return tags or [] @property @@ -205,7 +201,6 @@ class Dataset(Base): @property def doc_metadata(self): 
dataset_metadatas = db.session.query(DatasetMetadata).filter(DatasetMetadata.dataset_id == self.id).all() - doc_metadata = [ { "id": dataset_metadata.id, @@ -264,14 +259,12 @@ class DatasetProcessRule(Base): db.PrimaryKeyConstraint("id", name="dataset_process_rule_pkey"), db.Index("dataset_process_rule_dataset_id_idx", "dataset_id"), ) - id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) dataset_id = db.Column(StringUUID, nullable=False) mode = db.Column(db.String(255), nullable=False, server_default=db.text("'automatic'::character varying")) rules = db.Column(db.Text, nullable=True) created_by = db.Column(StringUUID, nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - MODES = ["automatic", "custom", "hierarchical"] PRE_PROCESSING_RULES = ["remove_stopwords", "remove_extra_spaces", "remove_urls_emails"] AUTOMATIC_RULES: dict[str, Any] = { @@ -307,7 +300,6 @@ class Document(Base): db.Index("document_tenant_idx", "tenant_id"), db.Index("document_metadata_idx", "doc_metadata", postgresql_using="gin"), ) - # initial fields id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) @@ -322,35 +314,27 @@ class Document(Base): created_by = db.Column(StringUUID, nullable=False) created_api_request_id = db.Column(StringUUID, nullable=True) created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - # start processing processing_started_at = db.Column(db.DateTime, nullable=True) - # parsing file_id = db.Column(db.Text, nullable=True) word_count = db.Column(db.Integer, nullable=True) parsing_completed_at = db.Column(db.DateTime, nullable=True) - # cleaning cleaning_completed_at = db.Column(db.DateTime, nullable=True) - # split splitting_completed_at = db.Column(db.DateTime, nullable=True) - # indexing tokens = db.Column(db.Integer, nullable=True) indexing_latency = 
db.Column(db.Float, nullable=True) completed_at = db.Column(db.DateTime, nullable=True) - # pause is_paused = db.Column(db.Boolean, nullable=True, server_default=db.text("false")) paused_by = db.Column(StringUUID, nullable=True) paused_at = db.Column(db.DateTime, nullable=True) - # error error = db.Column(db.Text, nullable=True) stopped_at = db.Column(db.DateTime, nullable=True) - # basic fields indexing_status = db.Column(db.String(255), nullable=False, server_default=db.text("'waiting'::character varying")) enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) @@ -365,7 +349,6 @@ class Document(Base): doc_metadata = db.Column(JSONB, nullable=True) doc_form = db.Column(db.String(255), nullable=False, server_default=db.text("'text_model'::character varying")) doc_language = db.Column(db.String(255), nullable=True) - DATA_SOURCES = ["upload_file", "notion_import", "website_crawl"] @property @@ -394,7 +377,6 @@ class Document(Base): data_source_info_dict = json.loads(self.data_source_info) except JSONDecodeError: data_source_info_dict = {} - return data_source_info_dict return None @@ -488,7 +470,6 @@ class Document(Base): metadata_list.append(metadata_dict) # deal built-in fields metadata_list.extend(self.get_built_in_fields()) - return metadata_list return None @@ -650,7 +631,6 @@ class DocumentSegment(Base): db.Index("document_segment_node_dataset_idx", "index_node_id", "dataset_id"), db.Index("document_segment_tenant_idx", "tenant_id"), ) - # initial fields id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) @@ -661,12 +641,10 @@ class DocumentSegment(Base): answer = db.Column(db.Text, nullable=True) word_count = db.Column(db.Integer, nullable=False) tokens = db.Column(db.Integer, nullable=False) - # indexing fields keywords = db.Column(db.JSON, nullable=True) index_node_id = db.Column(db.String(255), nullable=True) index_node_hash = 
db.Column(db.String(255), nullable=True) - # basic fields hit_count = db.Column(db.Integer, nullable=False, default=0) enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) @@ -748,7 +726,6 @@ class DocumentSegment(Base): def get_sign_content(self): signed_urls = [] text = self.content - # For data before v0.10.0 pattern = r"/files/([a-f0-9\-]+)/image-preview" matches = re.finditer(pattern, text) @@ -760,11 +737,9 @@ class DocumentSegment(Base): secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b"" sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() encoded_sign = base64.urlsafe_b64encode(sign).decode() - params = f"timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}" signed_url = f"{match.group(0)}?{params}" signed_urls.append((match.start(), match.end(), signed_url)) - # For data after v0.10.0 pattern = r"/files/([a-f0-9\-]+)/file-preview" matches = re.finditer(pattern, text) @@ -776,17 +751,14 @@ class DocumentSegment(Base): secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b"" sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() encoded_sign = base64.urlsafe_b64encode(sign).decode() - params = f"timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}" signed_url = f"{match.group(0)}?{params}" signed_urls.append((match.start(), match.end(), signed_url)) - # Reconstruct the text with signed URLs offset = 0 for start, end, signed_url in signed_urls: text = text[: start + offset] + signed_url + text[end + offset :] offset += len(signed_url) - (end - start) - return text @@ -798,7 +770,6 @@ class ChildChunk(Base): db.Index("child_chunks_node_idx", "index_node_id", "dataset_id"), db.Index("child_chunks_segment_idx", "segment_id"), ) - # initial fields id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) @@ -839,7 +810,6 @@ class AppDatasetJoin(Base): 
db.PrimaryKeyConstraint("id", name="app_dataset_join_pkey"), db.Index("app_dataset_join_app_dataset_idx", "dataset_id", "app_id"), ) - id = db.Column(StringUUID, primary_key=True, nullable=False, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=False) dataset_id = db.Column(StringUUID, nullable=False) @@ -856,7 +826,6 @@ class DatasetQuery(Base): db.PrimaryKeyConstraint("id", name="dataset_query_pkey"), db.Index("dataset_query_dataset_id_idx", "dataset_id"), ) - id = db.Column(StringUUID, primary_key=True, nullable=False, server_default=db.text("uuid_generate_v4()")) dataset_id = db.Column(StringUUID, nullable=False) content = db.Column(db.Text, nullable=False) @@ -873,7 +842,6 @@ class DatasetKeywordTable(Base): db.PrimaryKeyConstraint("id", name="dataset_keyword_table_pkey"), db.Index("dataset_keyword_table_dataset_id_idx", "dataset_id"), ) - id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) dataset_id = db.Column(StringUUID, nullable=False, unique=True) keyword_table = db.Column(db.Text, nullable=False) @@ -919,7 +887,6 @@ class Embedding(Base): db.UniqueConstraint("model_name", "hash", "provider_name", name="embedding_hash_idx"), db.Index("created_at_idx", "created_at"), ) - id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) model_name = db.Column( db.String(255), nullable=False, server_default=db.text("'text-embedding-ada-002'::character varying") @@ -942,7 +909,6 @@ class DatasetCollectionBinding(Base): db.PrimaryKeyConstraint("id", name="dataset_collection_bindings_pkey"), db.Index("provider_model_name_idx", "provider_name", "model_name"), ) - id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) provider_name = db.Column(db.String(255), nullable=False) model_name = db.Column(db.String(255), nullable=False) @@ -991,7 +957,6 @@ class DatasetPermission(Base): db.Index("idx_dataset_permissions_account_id", 
"account_id"), db.Index("idx_dataset_permissions_tenant_id", "tenant_id"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"), primary_key=True) dataset_id = db.Column(StringUUID, nullable=False) account_id = db.Column(StringUUID, nullable=False) @@ -1007,7 +972,6 @@ class ExternalKnowledgeApis(Base): db.Index("external_knowledge_apis_tenant_idx", "tenant_id"), db.Index("external_knowledge_apis_name_idx", "name"), ) - id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) name = db.Column(db.String(255), nullable=False) description = db.Column(db.String(255), nullable=False) @@ -1049,7 +1013,6 @@ class ExternalKnowledgeApis(Base): dataset_bindings = [] for dataset in datasets: dataset_bindings.append({"id": dataset.id, "name": dataset.name}) - return dataset_bindings @@ -1062,7 +1025,6 @@ class ExternalKnowledgeBindings(Base): db.Index("external_knowledge_bindings_external_knowledge_idx", "external_knowledge_id"), db.Index("external_knowledge_bindings_external_knowledge_api_idx", "external_knowledge_api_id"), ) - id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) external_knowledge_api_id = db.Column(StringUUID, nullable=False) @@ -1082,7 +1044,6 @@ class DatasetAutoDisableLog(Base): db.Index("dataset_auto_disable_log_dataset_idx", "dataset_id"), db.Index("dataset_auto_disable_log_created_atx", "created_at"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) dataset_id = db.Column(StringUUID, nullable=False) @@ -1098,7 +1059,6 @@ class RateLimitLog(Base): db.Index("rate_limit_log_tenant_idx", "tenant_id"), db.Index("rate_limit_log_operation_idx", "operation"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) subscription_plan = db.Column(db.String(255), nullable=False) 
@@ -1113,7 +1073,6 @@ class DatasetMetadata(Base): db.Index("dataset_metadata_tenant_idx", "tenant_id"), db.Index("dataset_metadata_dataset_idx", "dataset_id"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) dataset_id = db.Column(StringUUID, nullable=False) @@ -1134,7 +1093,6 @@ class DatasetMetadataBinding(Base): db.Index("dataset_metadata_binding_metadata_idx", "metadata_id"), db.Index("dataset_metadata_binding_document_idx", "document_id"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) dataset_id = db.Column(StringUUID, nullable=False) diff --git a/api/models/engine.py b/api/models/engine.py index 05c1cacdcb..be3acc273a 100644 --- a/api/models/engine.py +++ b/api/models/engine.py @@ -8,9 +8,7 @@ POSTGRES_INDEXES_NAMING_CONVENTION = { "fk": "%(table_name)s_%(column_0_name)s_fkey", "pk": "%(table_name)s_pkey", } - metadata = MetaData(naming_convention=POSTGRES_INDEXES_NAMING_CONVENTION) - # ****** IMPORTANT NOTICE ****** # # NOTE(QuantumGhost): Avoid directly importing and using `db` in modules outside of the diff --git a/api/models/model.py b/api/models/model.py index 93737043d5..39072fc466 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -13,7 +13,6 @@ from core.workflow.entities.workflow_execution import WorkflowExecutionStatus if TYPE_CHECKING: from models.workflow import Workflow - import sqlalchemy as sa from flask import request from flask_login import UserMixin @@ -39,7 +38,6 @@ if TYPE_CHECKING: class DifySetup(Base): __tablename__ = "dify_setups" __table_args__ = (db.PrimaryKeyConstraint("version", name="dify_setup_pkey"),) - version = db.Column(db.String(255), nullable=False) setup_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @@ -56,7 +54,6 @@ class AppMode(StrEnum): def value_of(cls, value: str) -> "AppMode": """ Get value of given mode. 
- :param value: mode value :return: mode """ @@ -74,7 +71,6 @@ class IconType(Enum): class App(Base): __tablename__ = "apps" __table_args__ = (db.PrimaryKeyConstraint("id", name="app_pkey"), db.Index("app_tenant_id_idx", "tenant_id")) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False) name = db.Column(db.String(255), nullable=False) @@ -121,7 +117,6 @@ class App(Base): def app_model_config(self): if self.app_model_config_id: return db.session.query(AppModelConfig).filter(AppModelConfig.id == self.app_model_config_id).first() - return None @property @@ -130,7 +125,6 @@ class App(Base): from .workflow import Workflow return db.session.query(Workflow).filter(Workflow.id == self.workflow_id).first() - return None @property @@ -149,7 +143,6 @@ class App(Base): return False if not app_model_config.agent_mode: return False - if app_model_config.agent_mode_dict.get("enabled", False) and app_model_config.agent_mode_dict.get( "strategy", "" ) in {"function_call", "react"}: @@ -162,7 +155,6 @@ class App(Base): def mode_compatible_with_agent(self) -> str: if self.mode == AppMode.CHAT.value and self.is_agent: return AppMode.AGENT_CHAT.value - return str(self.mode) @property @@ -174,16 +166,12 @@ class App(Base): app_model_config = self.app_model_config if not app_model_config: return [] - if not app_model_config.agent_mode: return [] - agent_mode = app_model_config.agent_mode_dict tools = agent_mode.get("tools", []) - api_provider_ids: list[str] = [] builtin_provider_ids: list[GenericProviderID] = [] - for tool in tools: keys = list(tool.keys()) if len(keys) >= 4: @@ -203,16 +191,12 @@ class App(Base): is_hardcoded = True except Exception: is_hardcoded = False - provider_id = GenericProviderID(provider_id, is_hardcoded) except Exception: continue - builtin_provider_ids.append(provider_id) - if not api_provider_ids and not builtin_provider_ids: return [] - with Session(db.engine) as session: 
if api_provider_ids: existing_api_providers = [ @@ -224,7 +208,6 @@ class App(Base): ] else: existing_api_providers = [] - if builtin_provider_ids: # get the non-hardcoded builtin providers non_hardcoded_builtin_providers = [ @@ -241,19 +224,15 @@ class App(Base): ] else: existence = [] - existing_builtin_providers = { provider_id.provider_name: existence[i] for i, provider_id in enumerate(builtin_provider_ids) } - deleted_tools = [] - for tool in tools: keys = list(tool.keys()) if len(keys) >= 4: provider_type = tool.get("provider_type", "") provider_id = tool.get("provider_id", "") - if provider_type == ToolProviderType.API.value: if uuid.UUID(provider_id) not in existing_api_providers: deleted_tools.append( @@ -263,10 +242,8 @@ class App(Base): "provider_id": provider_id, } ) - if provider_type == ToolProviderType.BUILT_IN.value: generic_provider_id = GenericProviderID(provider_id) - if not existing_builtin_providers[generic_provider_id.provider_name]: deleted_tools.append( { @@ -275,7 +252,6 @@ class App(Base): "provider_id": provider_id, # use the original one } ) - return deleted_tools @property @@ -291,7 +267,6 @@ class App(Base): ) .all() ) - return tags or [] @property @@ -300,14 +275,12 @@ class App(Base): account = db.session.query(Account).filter(Account.id == self.created_by).first() if account: return account.name - return None class AppModelConfig(Base): __tablename__ = "app_model_configs" __table_args__ = (db.PrimaryKeyConstraint("id", name="app_model_config_pkey"), db.Index("app_app_id_idx", "app_id")) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=False) provider = db.Column(db.String(255), nullable=True) @@ -379,7 +352,6 @@ class AppModelConfig(Base): collection_binding_detail = annotation_setting.collection_binding_detail if not collection_binding_detail: raise ValueError("Collection binding detail not found") - return { "id": annotation_setting.id, "enabled": True, @@ -389,7 
+361,6 @@ class AppModelConfig(Base): "embedding_model_name": collection_binding_detail.model_name, }, } - else: return {"enabled": False} @@ -550,7 +521,6 @@ class AppModelConfig(Base): dataset_configs=self.dataset_configs, file_upload=self.file_upload, ) - return new_app_model_config @@ -561,7 +531,6 @@ class RecommendedApp(Base): db.Index("recommended_app_app_id_idx", "app_id"), db.Index("recommended_app_is_listed_idx", "is_listed", "language"), ) - id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=False) description = db.Column(db.JSON, nullable=False) @@ -590,7 +559,6 @@ class InstalledApp(Base): db.Index("installed_app_app_id_idx", "app_id"), db.UniqueConstraint("tenant_id", "app_id", name="unique_tenant_app"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) app_id = db.Column(StringUUID, nullable=False) @@ -625,7 +593,6 @@ class Conversation(Base): db.PrimaryKeyConstraint("id", name="conversation_pkey"), db.Index("conversation_app_from_user_idx", "app_id", "from_source", "from_end_user_id"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=False) app_model_config_id = db.Column(StringUUID, nullable=True) @@ -640,13 +607,11 @@ class Conversation(Base): system_instruction = db.Column(db.Text) system_instruction_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0")) status = db.Column(db.String(255), nullable=False) - # The `invoke_from` records how the conversation is created. # # Its value corresponds to the members of `InvokeFrom`. # (api/core/app/entities/app_invoke_entities.py) invoke_from = db.Column(db.String(255), nullable=True) - # ref: ConversationSource. 
from_source = db.Column(db.String(255), nullable=False) from_end_user_id = db.Column(StringUUID) @@ -656,18 +621,15 @@ class Conversation(Base): dialogue_count: Mapped[int] = mapped_column(default=0) created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - messages = db.relationship("Message", backref="conversation", lazy="select", passive_deletes="all") message_annotations = db.relationship( "MessageAnnotation", backref="conversation", lazy="select", passive_deletes="all" ) - is_deleted = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) @property def inputs(self): inputs = self._inputs.copy() - # Convert file mapping to File object for key, value in inputs.items(): # NOTE: It's not the best way to implement this, but it's the only way to avoid circular import for now. @@ -689,7 +651,6 @@ class Conversation(Base): elif item["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]: item["upload_file_id"] = item["related_id"] inputs[key].append(file_factory.build_from_mapping(mapping=item, tenant_id=item["tenant_id"])) - return inputs @inputs.setter @@ -706,7 +667,6 @@ class Conversation(Base): def model_config(self): model_config = {} app_model_config: Optional[AppModelConfig] = None - if self.mode == AppMode.ADVANCED_CHAT.value: if self.override_model_configs: override_model_configs = json.loads(self.override_model_configs) @@ -714,7 +674,6 @@ class Conversation(Base): else: if self.override_model_configs: override_model_configs = json.loads(self.override_model_configs) - if "model" in override_model_configs: app_model_config = AppModelConfig() app_model_config = app_model_config.from_model_config_dict(override_model_configs) @@ -727,10 +686,8 @@ class Conversation(Base): ) if app_model_config: model_config = app_model_config.to_dict() - model_config["model_id"] = self.model_id 
model_config["provider"] = self.model_provider - return model_config @property @@ -767,7 +724,6 @@ class Conversation(Base): ) .count() ) - dislike = ( db.session.query(MessageFeedback) .filter( @@ -777,7 +733,6 @@ class Conversation(Base): ) .count() ) - return {"like": like, "dislike": dislike} @property @@ -791,7 +746,6 @@ class Conversation(Base): ) .count() ) - dislike = ( db.session.query(MessageFeedback) .filter( @@ -801,7 +755,6 @@ class Conversation(Base): ) .count() ) - return {"like": like, "dislike": dislike} @property @@ -814,11 +767,9 @@ class Conversation(Base): WorkflowExecutionStatus.STOPPED: 0, WorkflowExecutionStatus.PARTIAL_SUCCEEDED: 0, } - for message in messages: if message.workflow_run: status_counts[WorkflowExecutionStatus(message.workflow_run.status)] += 1 - return ( { "success": status_counts[WorkflowExecutionStatus.SUCCEEDED], @@ -848,7 +799,6 @@ class Conversation(Base): end_user = db.session.query(EndUser).filter(EndUser.id == self.from_end_user_id).first() if end_user: return end_user.session_id - return None @property @@ -857,7 +807,6 @@ class Conversation(Base): account = db.session.query(Account).filter(Account.id == self.from_account_id).first() if account: return account.name - return None @property @@ -903,7 +852,6 @@ class Message(Base): Index("message_workflow_run_id_idx", "conversation_id", "workflow_run_id"), Index("message_created_at_idx", "created_at"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=False) model_provider = db.Column(db.String(255), nullable=True) @@ -975,21 +923,15 @@ class Message(Base): def re_sign_file_url_answer(self) -> str: if not self.answer: return self.answer - pattern = r"\[!?.*?\]\((((http|https):\/\/.+)?\/files\/(tools\/)?[\w-]+.*?timestamp=.*&nonce=.*&sign=.*)\)" matches = re.findall(pattern, self.answer) - if not matches: return self.answer - urls = [match[0] for match in matches] - # remove duplicate urls 
urls = list(set(urls)) - if not urls: return self.answer - re_sign_file_url_answer = self.answer for url in urls: if "files/tools" in url: @@ -998,9 +940,7 @@ class Message(Base): result = re.search(tool_file_id_pattern, url) if not result: continue - tool_file_id = result.group(1) - # get extension if "." in tool_file_id: split_result = tool_file_id.split(".") @@ -1010,10 +950,8 @@ class Message(Base): tool_file_id = split_result[0] else: extension = ".bin" - if not tool_file_id: continue - sign_url = sign_tool_file(tool_file_id=tool_file_id, extension=extension) elif "file-preview" in url: # get upload file id @@ -1021,7 +959,6 @@ class Message(Base): result = re.search(upload_file_id_pattern, url) if not result: continue - upload_file_id = result.group(1) if not upload_file_id: continue @@ -1042,7 +979,6 @@ class Message(Base): if "as_attachment" in url: sign_url += "&as_attachment=true" re_sign_file_url_answer = re_sign_file_url_answer.replace(url, sign_url) - return re_sign_file_url_answer @property @@ -1094,7 +1030,6 @@ class Message(Base): return ( db.session.query(AppModelConfig).filter(AppModelConfig.id == conversation.app_model_config_id).first() ) - return None @property @@ -1126,7 +1061,6 @@ class Message(Base): current_app = db.session.query(App).filter(App.id == self.app_id).first() if not current_app: raise ValueError(f"App {self.app_id} not found") - files = [] for message_file in message_files: if message_file.transfer_method == FileTransferMethod.LOCAL_FILE.value: @@ -1173,12 +1107,10 @@ class Message(Base): f"MessageFile {message_file.id} has an invalid transfer_method {message_file.transfer_method}" ) files.append(file) - result = [ {"belongs_to": message_file.belongs_to, "upload_file_id": message_file.upload_file_id, **file.to_dict()} for (file, message_file) in zip(files, message_files) ] - db.session.commit() return result @@ -1188,7 +1120,6 @@ class Message(Base): from .workflow import WorkflowRun return 
db.session.query(WorkflowRun).filter(WorkflowRun.id == self.workflow_run_id).first() - return None def to_dict(self) -> dict: @@ -1247,7 +1178,6 @@ class MessageFeedback(Base): db.Index("message_feedback_message_idx", "message_id", "from_source"), db.Index("message_feedback_conversation_idx", "conversation_id", "from_source", "rating"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=False) conversation_id = db.Column(StringUUID, nullable=False) @@ -1330,7 +1260,6 @@ class MessageAnnotation(Base): db.Index("message_annotation_conversation_idx", "conversation_id"), db.Index("message_annotation_message_idx", "message_id"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=False) conversation_id = db.Column(StringUUID, db.ForeignKey("conversations.id"), nullable=True) @@ -1362,7 +1291,6 @@ class AppAnnotationHitHistory(Base): db.Index("app_annotation_hit_histories_annotation_idx", "annotation_id"), db.Index("app_annotation_hit_histories_message_idx", "message_id"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=False) annotation_id: Mapped[str] = db.Column(StringUUID, nullable=False) @@ -1397,7 +1325,6 @@ class AppAnnotationSetting(Base): db.PrimaryKeyConstraint("id", name="app_annotation_settings_pkey"), db.Index("app_annotation_settings_app_idx", "app_id"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=False) score_threshold = db.Column(Float, nullable=False, server_default=db.text("0")) @@ -1425,7 +1352,6 @@ class OperationLog(Base): db.PrimaryKeyConstraint("id", name="operation_log_pkey"), db.Index("operation_log_account_action_idx", "tenant_id", "account_id", "action"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, 
nullable=False) account_id = db.Column(StringUUID, nullable=False) @@ -1443,7 +1369,6 @@ class EndUser(Base, UserMixin): db.Index("end_user_session_id_idx", "session_id", "type"), db.Index("end_user_tenant_session_id_idx", "tenant_id", "session_id", "type"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False) app_id = db.Column(StringUUID, nullable=True) @@ -1463,7 +1388,6 @@ class Site(Base): db.Index("site_app_id_idx", "app_id"), db.Index("site_code_idx", "code", "status"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=False) title = db.Column(db.String(255), nullable=False) @@ -1505,7 +1429,6 @@ class Site(Base): result = generate_string(n) while db.session.query(Site).filter(Site.code == result).count() > 0: result = generate_string(n) - return result @property @@ -1521,7 +1444,6 @@ class ApiToken(Base): db.Index("api_token_token_idx", "token", "type"), db.Index("api_token_tenant_idx", "tenant_id", "type"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=True) tenant_id = db.Column(StringUUID, nullable=True) @@ -1545,7 +1467,6 @@ class UploadFile(Base): db.PrimaryKeyConstraint("id", name="upload_file_pkey"), db.Index("upload_file_tenant_idx", "tenant_id"), ) - id: Mapped[str] = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False) storage_type: Mapped[str] = db.Column(db.String(255), nullable=False) @@ -1607,7 +1528,6 @@ class ApiRequest(Base): db.PrimaryKeyConstraint("id", name="api_request_pkey"), db.Index("api_request_token_idx", "tenant_id", "api_token_id"), ) - id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) api_token_id = db.Column(StringUUID, nullable=False) @@ 
-1624,7 +1544,6 @@ class MessageChain(Base): db.PrimaryKeyConstraint("id", name="message_chain_pkey"), db.Index("message_chain_message_id_idx", "message_id"), ) - id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) message_id = db.Column(StringUUID, nullable=False) type = db.Column(db.String(255), nullable=False) @@ -1640,7 +1559,6 @@ class MessageAgentThought(Base): db.Index("message_agent_thought_message_id_idx", "message_id"), db.Index("message_agent_thought_message_chain_id_idx", "message_chain_id"), ) - id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) message_id = db.Column(StringUUID, nullable=False) message_chain_id = db.Column(StringUUID, nullable=True) @@ -1753,7 +1671,6 @@ class DatasetRetrieverResource(Base): db.PrimaryKeyConstraint("id", name="dataset_retriever_resource_pkey"), db.Index("dataset_retriever_resource_message_id_idx", "message_id"), ) - id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) message_id = db.Column(StringUUID, nullable=False) position = db.Column(db.Integer, nullable=False) @@ -1781,9 +1698,7 @@ class Tag(Base): db.Index("tag_type_idx", "type"), db.Index("tag_name_idx", "name"), ) - TAG_TYPE_LIST = ["knowledge", "app"] - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=True) type = db.Column(db.String(16), nullable=False) @@ -1799,7 +1714,6 @@ class TagBinding(Base): db.Index("tag_bind_target_id_idx", "target_id"), db.Index("tag_bind_tag_id_idx", "tag_id"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=True) tag_id = db.Column(StringUUID, nullable=True) @@ -1814,7 +1728,6 @@ class TraceAppConfig(Base): db.PrimaryKeyConstraint("id", name="tracing_app_config_pkey"), db.Index("trace_app_config_app_id_idx", "app_id"), ) - id = db.Column(StringUUID, 
server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=False) tracing_provider = db.Column(db.String(255), nullable=True) diff --git a/api/models/provider.py b/api/models/provider.py index 1e25f0c90f..cf1b992a96 100644 --- a/api/models/provider.py +++ b/api/models/provider.py @@ -25,10 +25,8 @@ class ProviderType(Enum): class ProviderQuotaType(Enum): PAID = "paid" """hosted paid quota""" - FREE = "free" """third-party free quota""" - TRIAL = "trial" """hosted trial quota""" @@ -53,7 +51,6 @@ class Provider(Base): "tenant_id", "provider_name", "provider_type", "quota_type", name="unique_provider_name_type_quota" ), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) @@ -63,13 +60,11 @@ class Provider(Base): encrypted_config: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True) is_valid: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("false")) last_used: Mapped[Optional[datetime]] = mapped_column(db.DateTime, nullable=True) - quota_type: Mapped[Optional[str]] = mapped_column( db.String(40), nullable=True, server_default=text("''::character varying") ) quota_limit: Mapped[Optional[int]] = mapped_column(db.BigInteger, nullable=True) quota_used: Mapped[Optional[int]] = mapped_column(db.BigInteger, default=0) - created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @@ -110,7 +105,6 @@ class ProviderModel(Base): "tenant_id", "provider_name", "model_name", "model_type", name="unique_provider_model_name" ), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID, 
nullable=False) provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) @@ -128,7 +122,6 @@ class TenantDefaultModel(Base): db.PrimaryKeyConstraint("id", name="tenant_default_model_pkey"), db.Index("tenant_default_model_tenant_id_provider_type_idx", "tenant_id", "provider_name", "model_type"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) @@ -144,7 +137,6 @@ class TenantPreferredModelProvider(Base): db.PrimaryKeyConstraint("id", name="tenant_preferred_model_provider_pkey"), db.Index("tenant_preferred_model_provider_tenant_provider_idx", "tenant_id", "provider_name"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) @@ -159,7 +151,6 @@ class ProviderOrder(Base): db.PrimaryKeyConstraint("id", name="provider_order_pkey"), db.Index("provider_order_tenant_provider_idx", "tenant_id", "provider_name"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) @@ -190,7 +181,6 @@ class ProviderModelSetting(Base): db.PrimaryKeyConstraint("id", name="provider_model_setting_pkey"), db.Index("provider_model_setting_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) @@ -212,7 +202,6 @@ class LoadBalancingModelConfig(Base): db.PrimaryKeyConstraint("id", 
name="load_balancing_model_config_pkey"), db.Index("load_balancing_model_config_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) diff --git a/api/models/source.py b/api/models/source.py index f6e0900ae6..632af9643b 100644 --- a/api/models/source.py +++ b/api/models/source.py @@ -16,7 +16,6 @@ class DataSourceOauthBinding(Base): db.Index("source_binding_tenant_id_idx", "tenant_id"), db.Index("source_info_idx", "source_info", postgresql_using="gin"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) access_token = db.Column(db.String(255), nullable=False) @@ -34,7 +33,6 @@ class DataSourceApiKeyAuthBinding(Base): db.Index("data_source_api_key_auth_binding_tenant_id_idx", "tenant_id"), db.Index("data_source_api_key_auth_binding_provider_idx", "provider"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) category = db.Column(db.String(255), nullable=False) diff --git a/api/models/task.py b/api/models/task.py index d853c1dd9a..971105dfc5 100644 --- a/api/models/task.py +++ b/api/models/task.py @@ -11,7 +11,6 @@ class CeleryTask(Base): """Task result/status.""" __tablename__ = "celery_taskmeta" - id = db.Column(db.Integer, db.Sequence("task_id_sequence"), primary_key=True, autoincrement=True) task_id = db.Column(db.String(155), unique=True) status = db.Column(db.String(50), default=states.PENDING) @@ -35,7 +34,6 @@ class CeleryTaskSet(Base): """TaskSet result.""" __tablename__ = "celery_tasksetmeta" - id = db.Column(db.Integer, db.Sequence("taskset_id_sequence"), autoincrement=True, primary_key=True) taskset_id = db.Column(db.String(155), unique=True) result = 
db.Column(db.PickleType, nullable=True) diff --git a/api/models/tools.py b/api/models/tools.py index 03fbc3acb1..7234de191c 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -28,7 +28,6 @@ class BuiltinToolProvider(Base): # one tenant can only have one tool provider with the same name db.UniqueConstraint("tenant_id", "provider", name="unique_builtin_tool_provider"), ) - # id of the tool provider id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) # id of the tenant @@ -61,7 +60,6 @@ class ApiToolProvider(Base): db.PrimaryKeyConstraint("id", name="tool_api_provider_pkey"), db.UniqueConstraint("name", "tenant_id", name="unique_api_tool_provider"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) # name of the api provider name = db.Column(db.String(255), nullable=False) @@ -84,7 +82,6 @@ class ApiToolProvider(Base): privacy_policy = db.Column(db.String(255), nullable=True) # custom_disclaimer custom_disclaimer: Mapped[str] = mapped_column(sa.TEXT, default="") - created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @@ -121,7 +118,6 @@ class ToolLabelBinding(Base): db.PrimaryKeyConstraint("id", name="tool_label_bind_pkey"), db.UniqueConstraint("tool_id", "label_name", name="unique_tool_label_bind"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) # tool id tool_id: Mapped[str] = mapped_column(db.String(64), nullable=False) @@ -142,7 +138,6 @@ class WorkflowToolProvider(Base): db.UniqueConstraint("name", "tenant_id", name="unique_workflow_tool_provider"), db.UniqueConstraint("tenant_id", "app_id", name="unique_workflow_tool_provider_app_id"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) # name of the workflow provider name: 
Mapped[str] = mapped_column(db.String(255), nullable=False) @@ -164,7 +159,6 @@ class WorkflowToolProvider(Base): parameter_configuration: Mapped[str] = mapped_column(db.Text, nullable=False, server_default="[]") # privacy policy privacy_policy: Mapped[str] = mapped_column(db.String(255), nullable=True, server_default="") - created_at: Mapped[datetime] = mapped_column( db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") ) @@ -196,7 +190,6 @@ class ToolModelInvoke(Base): __tablename__ = "tool_model_invokes" __table_args__ = (db.PrimaryKeyConstraint("id", name="tool_model_invoke_pkey"),) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) # who invoke this tool user_id = db.Column(StringUUID, nullable=False) @@ -214,7 +207,6 @@ class ToolModelInvoke(Base): prompt_messages = db.Column(db.Text, nullable=False) # invoke response model_response = db.Column(db.Text, nullable=False) - prompt_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0")) answer_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0")) answer_unit_price = db.Column(db.Numeric(10, 4), nullable=False) @@ -239,7 +231,6 @@ class ToolConversationVariables(Base): db.Index("user_id_idx", "user_id"), db.Index("conversation_id_idx", "conversation_id"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) # conversation user id user_id = db.Column(StringUUID, nullable=False) @@ -249,7 +240,6 @@ class ToolConversationVariables(Base): conversation_id = db.Column(StringUUID, nullable=False) # variables pool variables_str = db.Column(db.Text, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @@ -268,7 +258,6 @@ class ToolFile(Base): db.PrimaryKeyConstraint("id", name="tool_file_pkey"), db.Index("tool_file_conversation_id_idx", "conversation_id"), ) - id: 
Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) # conversation user id user_id: Mapped[str] = mapped_column(StringUUID) @@ -299,11 +288,9 @@ class DeprecatedPublishedAppTool(Base): db.PrimaryKeyConstraint("id", name="published_app_tool_pkey"), db.UniqueConstraint("app_id", "user_id", name="unique_published_app_tool"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) # id of the app app_id = db.Column(StringUUID, ForeignKey("apps.id"), nullable=False) - user_id: Mapped[str] = db.Column(StringUUID, nullable=False) # who published this tool description = db.Column(db.Text, nullable=False) diff --git a/api/models/types.py b/api/models/types.py index e5581c3ab0..8885e632af 100644 --- a/api/models/types.py +++ b/api/models/types.py @@ -35,7 +35,6 @@ _E = TypeVar("_E", bound=enum.StrEnum) class EnumText(TypeDecorator, Generic[_E]): impl = VARCHAR cache_ok = True - _length: int _enum_class: type[_E] diff --git a/api/models/web.py b/api/models/web.py index fe2f0c47f8..41b4769533 100644 --- a/api/models/web.py +++ b/api/models/web.py @@ -14,7 +14,6 @@ class SavedMessage(Base): db.PrimaryKeyConstraint("id", name="saved_message_pkey"), db.Index("saved_message_message_idx", "app_id", "message_id", "created_by_role", "created_by"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=False) message_id = db.Column(StringUUID, nullable=False) @@ -33,7 +32,6 @@ class PinnedConversation(Base): db.PrimaryKeyConstraint("id", name="pinned_conversation_pkey"), db.Index("pinned_conversation_conversation_idx", "app_id", "conversation_id", "created_by_role", "created_by"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=False) conversation_id: Mapped[str] = mapped_column(StringUUID) diff --git a/api/models/workflow.py b/api/models/workflow.py index 7f01135af3..f9c552d024 100644 --- 
a/api/models/workflow.py +++ b/api/models/workflow.py @@ -20,7 +20,6 @@ from ._workflow_exc import NodeNotFoundError, WorkflowDataError if TYPE_CHECKING: from models.model import AppMode - import sqlalchemy as sa from sqlalchemy import Index, PrimaryKeyConstraint, UniqueConstraint, func from sqlalchemy.orm import Mapped, declared_attr, mapped_column @@ -38,7 +37,6 @@ from .enums import CreatorUserRole, DraftVariableType from .types import EnumText, StringUUID _logger = logging.getLogger(__name__) - if TYPE_CHECKING: from models.model import AppMode @@ -55,7 +53,6 @@ class WorkflowType(Enum): def value_of(cls, value: str) -> "WorkflowType": """ Get value of given mode. - :param value: mode value :return: mode """ @@ -68,7 +65,6 @@ class WorkflowType(Enum): def from_app_mode(cls, app_mode: Union[str, "AppMode"]) -> "WorkflowType": """ Get workflow type from app mode. - :param app_mode: app mode :return: workflow type """ @@ -85,30 +81,19 @@ class _InvalidGraphDefinitionError(Exception): class Workflow(Base): """ Workflow, for `Workflow App` and `Chat App workflow mode`. 
- Attributes: - - id (uuid) Workflow ID, pk - tenant_id (uuid) Workspace ID - app_id (uuid) App ID - type (string) Workflow type - `workflow` for `Workflow App` - `chat` for `Chat App workflow mode` - - version (string) Version - `draft` for draft version (only one for each app), other for version number (redundant) - - graph (text) Workflow canvas configuration (JSON) - The entire canvas configuration JSON, including Node, Edge, and other configurations - - nodes (array[object]) Node list, see Node Schema - - edges (array[object]) Edge list, see Edge Schema - - created_by (uuid) Creator ID - created_at (timestamp) Creation time - updated_by (uuid) `optional` Last updater ID @@ -120,7 +105,6 @@ class Workflow(Base): db.PrimaryKeyConstraint("id", name="workflow_pkey"), db.Index("workflow_version_idx", "tenant_id", "app_id", "version"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) @@ -145,7 +129,6 @@ class Workflow(Base): _conversation_variables: Mapped[str] = mapped_column( "conversation_variables", db.Text, nullable=False, server_default="{}" ) - VERSION_DRAFT = "draft" @classmethod @@ -217,14 +200,11 @@ class Workflow(Base): the node's id, title, and its data as a dict. 
""" workflow_graph = self.graph_dict - if not workflow_graph: raise WorkflowDataError(f"workflow graph not found, workflow_id={self.id}") - nodes = workflow_graph.get("nodes") if not nodes: raise WorkflowDataError("nodes not found in workflow graph") - try: node_config = next(filter(lambda node: node["id"] == node_id, nodes)) except StopIteration: @@ -264,7 +244,6 @@ class Workflow(Base): """ if not self._features: return self._features - features = json.loads(self._features) if features.get("file_upload", {}).get("image", {}).get("enabled", False): image_enabled = True @@ -295,36 +274,28 @@ class Workflow(Base): # get start node from graph if not self.graph: return [] - graph_dict = self.graph_dict if "nodes" not in graph_dict: return [] - start_node = next((node for node in graph_dict["nodes"] if node["data"]["type"] == "start"), None) if not start_node: return [] - # get user_input_form from start node variables: list[Any] = start_node.get("data", {}).get("variables", []) - if to_old_structure: old_structure_variables = [] for variable in variables: old_structure_variables.append({variable["type"]: variable}) - return old_structure_variables - return variables @property def unique_hash(self) -> str: """ Get hash of workflow. - :return: hash """ entity = {"graph": self.graph_dict, "features": self.features_dict} - return helper.generate_text_hash(json.dumps(entity, sort_keys=True)) @property @@ -333,7 +304,6 @@ class Workflow(Base): DEPRECATED: This property is not accurate for determining if a workflow is published as a tool. It only checks if there's a WorkflowToolProvider for the app, not if this specific workflow version is the one being used by the tool. - For accurate checking, use a direct query with tenant_id, app_id, and version. """ from models.tools import WorkflowToolProvider @@ -350,7 +320,6 @@ class Workflow(Base): # TODO: find some way to init `self._environment_variables` when instance created. 
if self._environment_variables is None: self._environment_variables = "{}" - # Get tenant_id from current_user (Account or EndUser) if isinstance(current_user, Account): # Account user @@ -358,10 +327,8 @@ class Workflow(Base): else: # EndUser tenant_id = current_user.tenant_id - if not tenant_id: return [] - environment_variables_dict: dict[str, Any] = json.loads(self._environment_variables) results = [ variable_factory.build_environment_variable_from_mapping(v) for v in environment_variables_dict.values() @@ -382,7 +349,6 @@ class Workflow(Base): if not value: self._environment_variables = "{}" return - # Get tenant_id from current_user (Account or EndUser) if isinstance(current_user, Account): # Account user @@ -390,15 +356,12 @@ class Workflow(Base): else: # EndUser tenant_id = current_user.tenant_id - if not tenant_id: self._environment_variables = "{}" return - value = list(value) if any(var for var in value if not var.id): raise ValueError("environment variable require a unique id") - # Compare inputs and origin variables, # if the value is HIDDEN_VALUE, use the origin variable value (only update `name`). origin_variables_dictionary = {var.id: var for var in self.environment_variables} @@ -426,7 +389,6 @@ class Workflow(Base): v if not isinstance(v, SecretVariable) or include_secret else v.model_copy(update={"value": ""}) for v in environment_variables ] - result = { "graph": self.graph_dict, "features": self.features_dict, @@ -440,7 +402,6 @@ class Workflow(Base): # TODO: find some way to init `self._conversation_variables` when instance created. 
if self._conversation_variables is None: self._conversation_variables = "{}" - variables_dict: dict[str, Any] = json.loads(self._conversation_variables) results = [variable_factory.build_conversation_variable_from_mapping(v) for v in variables_dict.values()] return results @@ -460,21 +421,15 @@ class Workflow(Base): class WorkflowRun(Base): """ Workflow Run - Attributes: - - id (uuid) Run ID - tenant_id (uuid) Workspace ID - app_id (uuid) App ID - - workflow_id (uuid) Workflow ID - type (string) Workflow type - triggered_from (string) Trigger source - `debugging` for canvas debugging - `app-run` for (published) app execution - - version (string) Version - graph (text) Workflow canvas configuration (JSON) - inputs (text) Input parameters @@ -485,11 +440,8 @@ class WorkflowRun(Base): - total_tokens (int) `optional` Total tokens used - total_steps (int) Total steps (redundant), default 0 - created_by_role (string) Creator role - - `account` Console account - - `end_user` End user - - created_by (uuid) Runner ID - created_at (timestamp) Run time - finished_at (timestamp) End time @@ -500,11 +452,9 @@ class WorkflowRun(Base): db.PrimaryKeyConstraint("id", name="workflow_run_pkey"), db.Index("workflow_run_triggerd_from_idx", "tenant_id", "app_id", "triggered_from"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID) app_id: Mapped[str] = mapped_column(StringUUID) - workflow_id: Mapped[str] = mapped_column(StringUUID) type: Mapped[str] = mapped_column(db.String(255)) triggered_from: Mapped[str] = mapped_column(db.String(255)) @@ -621,21 +571,15 @@ class WorkflowNodeExecutionTriggeredFrom(StrEnum): class WorkflowNodeExecutionModel(Base): """ Workflow Node Execution - - id (uuid) Execution ID - tenant_id (uuid) Workspace ID - app_id (uuid) App ID - workflow_id (uuid) Workflow ID - triggered_from (string) Trigger source - `single-step` for single-step debugging - `workflow-run` for 
workflow execution (debugging / user execution) - - workflow_run_id (uuid) `optional` Workflow run ID - Null for single-step debugging. - - index (int) Execution sequence number, used for displaying Tracing Node order - predecessor_node_id (string) `optional` Predecessor node ID, used for displaying execution path - node_id (string) Node ID @@ -648,20 +592,13 @@ class WorkflowNodeExecutionModel(Base): - error (string) `optional` Error reason - elapsed_time (float) `optional` Time consumption (s) - execution_metadata (text) Metadata - - total_tokens (int) `optional` Total tokens used - - total_price (decimal) `optional` Total cost - - currency (string) `optional` Currency, such as USD / RMB - - created_at (timestamp) Run time - created_by_role (string) Creator role - - `account` Console account - - `end_user` End user - - created_by (uuid) Runner ID - finished_at (timestamp) End time """ @@ -782,7 +719,6 @@ class WorkflowNodeExecutionModel(Base): provider_type=tool_info["provider_type"], provider_id=tool_info["provider_id"], ) - return extras @@ -799,7 +735,6 @@ class WorkflowAppLogCreatedFrom(Enum): def value_of(cls, value: str) -> "WorkflowAppLogCreatedFrom": """ Get value of given mode. - :param value: mode value :return: mode """ @@ -812,28 +747,19 @@ class WorkflowAppLogCreatedFrom(Enum): class WorkflowAppLog(Base): """ Workflow App execution log, excluding workflow debugging records. 
- Attributes: - - id (uuid) run ID - tenant_id (uuid) Workspace ID - app_id (uuid) App ID - workflow_id (uuid) Associated Workflow ID - workflow_run_id (uuid) Associated Workflow Run ID - created_from (string) Creation source - `service-api` App Execution OpenAPI - `web-app` WebApp - `installed-app` Installed App - - created_by_role (string) Creator role - - `account` Console account - - `end_user` End user - - created_by (uuid) Creator ID, depends on the user table according to created_by_role - created_at (timestamp) Creation time """ @@ -843,7 +769,6 @@ class WorkflowAppLog(Base): db.PrimaryKeyConstraint("id", name="workflow_app_log_pkey"), db.Index("workflow_app_log_app_idx", "tenant_id", "app_id"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID) app_id: Mapped[str] = mapped_column(StringUUID) @@ -873,7 +798,6 @@ class WorkflowAppLog(Base): class ConversationVariable(Base): __tablename__ = "workflow_conversation_variables" - id: Mapped[str] = mapped_column(StringUUID, primary_key=True) conversation_id: Mapped[str] = mapped_column(StringUUID, nullable=False, primary_key=True, index=True) app_id: Mapped[str] = mapped_column(StringUUID, nullable=False, index=True) @@ -917,10 +841,8 @@ def _naive_utc_datetime(): class WorkflowDraftVariable(Base): """`WorkflowDraftVariable` record variables and outputs generated during debugging worfklow or chatflow. - IMPORTANT: This model maintains multiple invariant rules that must be preserved. Do not instantiate this class directly with the constructor. - Instead, use the factory methods (`new_conversation_variable`, `new_sys_variable`, `new_node_variable`) defined below to ensure all invariants are properly maintained. """ @@ -937,17 +859,14 @@ class WorkflowDraftVariable(Base): __table_args__ = (UniqueConstraint(*unique_app_id_node_id_name()),) # Required for instance variable annotation. 
__allow_unmapped__ = True - # id is the unique identifier of a draft variable. id: Mapped[str] = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) - created_at: Mapped[datetime] = mapped_column( db.DateTime, nullable=False, default=_naive_utc_datetime, server_default=func.current_timestamp(), ) - updated_at: Mapped[datetime] = mapped_column( db.DateTime, nullable=False, @@ -955,10 +874,8 @@ class WorkflowDraftVariable(Base): server_default=func.current_timestamp(), onupdate=func.current_timestamp(), ) - # "`app_id` maps to the `id` field in the `model.App` model." app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - # `last_edited_at` records when the value of a given draft variable # is edited. # @@ -968,7 +885,6 @@ class WorkflowDraftVariable(Base): nullable=True, default=None, ) - # The `node_id` field is special. # # If the variable is a conversation variable or a system variable, then the value of `node_id` @@ -980,7 +896,6 @@ class WorkflowDraftVariable(Base): # However, there's one caveat. The id of the first "Answer" node in chatflow is "answer". (Other # "Answer" node conform the rules above.) node_id: Mapped[str] = mapped_column(sa.String(255), nullable=False, name="node_id") - # From `VARIABLE_PATTERN`, we may conclude that the length of a top level variable is less than # 80 chars. 
# @@ -991,21 +906,15 @@ class WorkflowDraftVariable(Base): default="", nullable=False, ) - selector: Mapped[str] = mapped_column(sa.String(255), nullable=False, name="selector") - # The data type of this variable's value value_type: Mapped[SegmentType] = mapped_column(EnumText(SegmentType, length=20)) - # The variable's value serialized as a JSON string value: Mapped[str] = mapped_column(sa.Text, nullable=False, name="value") - # Controls whether the variable should be displayed in the variable inspection panel visible: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, default=True) - # Determines whether this variable can be modified by users editable: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, default=False) - # The `node_execution_id` field identifies the workflow node execution that created this variable. # It corresponds to the `id` field in the `WorkflowNodeExecutionModel` model. # @@ -1016,7 +925,6 @@ class WorkflowDraftVariable(Base): nullable=True, default=None, ) - # Cache for deserialized value # # NOTE(QuantumGhost): This field serves two purposes: @@ -1035,7 +943,6 @@ class WorkflowDraftVariable(Base): The constructor of `WorkflowDraftVariable` is not intended for direct use outside this file. Its solo purpose is setup private state used by the model instance. - Please use the factory methods (`new_conversation_variable`, `new_sys_variable`, `new_node_variable`) defined below to create instances of this class. @@ -1107,25 +1014,19 @@ class WorkflowDraftVariable(Base): raise TypeMismatchError(f"expected list for ArrayFileSegment, got {type(value)}") file_list = cls.rebuild_file_types(value) return build_segment_with_type(segment_type=segment_type, value=file_list) - return build_segment_with_type(segment_type=segment_type, value=value) def get_value(self) -> Segment: """Decode the serialized value into its corresponding `Segment` object. 
- This method caches the result, so repeated calls will return the same object instance without re-parsing the serialized data. - If you need to modify the returned `Segment`, use `value.model_copy()` to create a copy first to avoid affecting the cached instance. - For more information about the caching mechanism, see the documentation of the `__value` field. - Returns: Segment: The deserialized value as a Segment object. """ - if self.__value is not None: return self.__value value = self._loads_value() @@ -1138,10 +1039,8 @@ class WorkflowDraftVariable(Base): def set_value(self, value: Segment): """Updates the `value` and corresponding `value_type` fields in the database model. - This method also stores the provided Segment object in the deserialized cache without creating a copy, allowing for efficient value access. - Args: value: The Segment object to store as the variable's value. """ diff --git a/api/pyproject.toml b/api/pyproject.toml index d33806d0ae..4816d53bfa 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -90,6 +90,18 @@ dependencies = [ [tool.setuptools] packages = [] +[tool.ruff] +line-length = 120 +extend-exclude = [ + "**/migrations/versions/*.py", +] + +[tool.ruff.format] +# The following option is used to control whether to add a trailing comma to the last +# item of a multi-line expression. It's set to true to avoid ruff and editorconfig's +# conflict on the final newline. 
+skip-magic-trailing-comma = true + [tool.uv] default-groups = ["storage", "tools", "vdb"] package = false diff --git a/api/schedule/clean_embedding_cache_task.py b/api/schedule/clean_embedding_cache_task.py index 9efe120b7a..fb8f86515c 100644 --- a/api/schedule/clean_embedding_cache_task.py +++ b/api/schedule/clean_embedding_cache_task.py @@ -34,7 +34,6 @@ def clean_embedding_cache_task(): db.session.execute( text("DELETE FROM embeddings WHERE id = :embedding_id"), {"embedding_id": embedding_id} ) - db.session.commit() else: break diff --git a/api/schedule/clean_messages.py b/api/schedule/clean_messages.py index d02bc81f33..fbee3a0a32 100644 --- a/api/schedule/clean_messages.py +++ b/api/schedule/clean_messages.py @@ -41,7 +41,6 @@ def clean_messages(): .limit(100) .all() ) - except NotFound: break if not messages: diff --git a/api/schedule/clean_unused_datasets_task.py b/api/schedule/clean_unused_datasets_task.py index c0cd42a226..0938600295 100644 --- a/api/schedule/clean_unused_datasets_task.py +++ b/api/schedule/clean_unused_datasets_task.py @@ -36,7 +36,6 @@ def clean_unused_datasets_task(): .group_by(Document.dataset_id) .subquery() ) - # Subquery for counting old documents document_subquery_old = ( db.session.query(Document.dataset_id, func.count(Document.id).label("document_count")) @@ -49,7 +48,6 @@ def clean_unused_datasets_task(): .group_by(Document.dataset_id) .subquery() ) - # Main query with join and filter stmt = ( select(Dataset) @@ -62,9 +60,7 @@ def clean_unused_datasets_task(): ) .order_by(Dataset.created_at.desc()) ) - datasets = db.paginate(stmt, page=1, per_page=50) - except NotFound: break if datasets.items is None or len(datasets.items) == 0: @@ -97,10 +93,8 @@ def clean_unused_datasets_task(): # remove index index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor() index_processor.clean(dataset, None) - # update document update_params = {Document.enabled: False} - 
db.session.query(Document).filter_by(dataset_id=dataset.id).update(update_params) db.session.commit() click.echo(click.style("Cleaned unused dataset {} from db success!".format(dataset.id), fg="green")) @@ -122,7 +116,6 @@ def clean_unused_datasets_task(): .group_by(Document.dataset_id) .subquery() ) - # Subquery for counting old documents document_subquery_old = ( db.session.query(Document.dataset_id, func.count(Document.id).label("document_count")) @@ -135,7 +128,6 @@ def clean_unused_datasets_task(): .group_by(Document.dataset_id) .subquery() ) - # Main query with join and filter stmt = ( select(Dataset) @@ -149,7 +141,6 @@ def clean_unused_datasets_task(): .order_by(Dataset.created_at.desc()) ) datasets = db.paginate(stmt, page=1, per_page=50) - except NotFound: break if datasets.items is None or len(datasets.items) == 0: @@ -174,10 +165,8 @@ def clean_unused_datasets_task(): # remove index index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor() index_processor.clean(dataset, None) - # update document update_params = {Document.enabled: False} - db.session.query(Document).filter_by(dataset_id=dataset.id).update(update_params) db.session.commit() click.echo( diff --git a/api/schedule/create_tidb_serverless_task.py b/api/schedule/create_tidb_serverless_task.py index 8a02278de8..9d5e9809aa 100644 --- a/api/schedule/create_tidb_serverless_task.py +++ b/api/schedule/create_tidb_serverless_task.py @@ -27,11 +27,9 @@ def create_tidb_serverless_task(): # create tidb serverless iterations_per_thread = 20 create_clusters(iterations_per_thread) - except Exception as e: click.echo(click.style(f"Error: {e}", fg="red")) break - end_at = time.perf_counter() click.echo(click.style("Create tidb serverless task success latency: {}".format(end_at - start_at), fg="green")) diff --git a/api/schedule/mail_clean_document_notify_task.py b/api/schedule/mail_clean_document_notify_task.py index 5ee813e1de..315e5c68a9 100644 --- 
a/api/schedule/mail_clean_document_notify_task.py +++ b/api/schedule/mail_clean_document_notify_task.py @@ -18,15 +18,12 @@ from services.feature_service import FeatureService def mail_clean_document_notify_task(): """ Async Send document clean notify mail - Usage: mail_clean_document_notify_task.delay() """ if not mail.is_inited(): return - logging.info(click.style("Start send document clean notify mail", fg="green")) start_at = time.perf_counter() - # send document clean notify mail try: dataset_auto_disable_logs = ( @@ -57,7 +54,6 @@ def mail_clean_document_notify_task(): account = db.session.query(Account).filter(Account.id == current_owner_join.account_id).first() if not account: continue - dataset_auto_dataset_map = {} # type: ignore for dataset_auto_disable_log in tenant_dataset_auto_disable_logs: if dataset_auto_disable_log.dataset_id not in dataset_auto_dataset_map: @@ -65,7 +61,6 @@ def mail_clean_document_notify_task(): dataset_auto_dataset_map[dataset_auto_disable_log.dataset_id].append( dataset_auto_disable_log.document_id ) - for dataset_id, document_ids in dataset_auto_dataset_map.items(): dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() if dataset: @@ -81,7 +76,6 @@ def mail_clean_document_notify_task(): mail.send( to=account.email, subject="Dify Knowledge base auto disable notification", html=html_content ) - # update notified to True for dataset_auto_disable_log in tenant_dataset_auto_disable_logs: dataset_auto_disable_log.notified = True diff --git a/api/schedule/queue_monitor_task.py b/api/schedule/queue_monitor_task.py index e3a7021b9d..6dd276fc7d 100644 --- a/api/schedule/queue_monitor_task.py +++ b/api/schedule/queue_monitor_task.py @@ -13,13 +13,11 @@ from extensions.ext_mail import mail # Create a dedicated Redis connection (using the same configuration as Celery) celery_broker_url = dify_config.CELERY_BROKER_URL - parsed = urlparse(celery_broker_url) host = parsed.hostname or "localhost" port = parsed.port or 
6379 password = parsed.password or None redis_db = parsed.path.strip("/") or "1" # type: ignore - celery_redis = Redis(host=host, port=port, password=password, db=redis_db) @@ -27,12 +25,10 @@ celery_redis = Redis(host=host, port=port, password=password, db=redis_db) def queue_monitor_task(): queue_name = "dataset" threshold = dify_config.QUEUE_MONITOR_THRESHOLD - try: queue_length = celery_redis.llen(f"{queue_name}") logging.info(click.style(f"Start monitor {queue_name}", fg="green")) logging.info(click.style(f"Queue length: {queue_length}", fg="green")) - if queue_length >= threshold: warning_msg = f"Queue {queue_name} task count exceeded the limit.: {queue_length}/{threshold}" logging.warning(click.style(warning_msg, fg="red")) @@ -54,7 +50,6 @@ def queue_monitor_task(): ) except Exception as e: logging.exception(click.style("Exception occurred during sending email", fg="red")) - except Exception as e: logging.exception(click.style("Exception occurred during queue monitoring", fg="red")) finally: diff --git a/api/schedule/update_tidb_serverless_status_task.py b/api/schedule/update_tidb_serverless_status_task.py index ce4ecb6e7c..54b226801e 100644 --- a/api/schedule/update_tidb_serverless_status_task.py +++ b/api/schedule/update_tidb_serverless_status_task.py @@ -24,10 +24,8 @@ def update_tidb_serverless_status_task(): return # update tidb serverless status update_clusters(tidb_serverless_list) - except Exception as e: click.echo(click.style(f"Error: {e}", fg="red")) - end_at = time.perf_counter() click.echo( click.style("Update tidb serverless status task success latency: {}".format(end_at - start_at), fg="green") diff --git a/api/services/account_service.py b/api/services/account_service.py index 3fdbda48a6..4961d79407 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -103,10 +103,8 @@ class AccountService: account = db.session.query(Account).filter_by(id=user_id).first() if not account: return None - if account.status == 
AccountStatus.BANNED.value: raise Unauthorized("Account is banned.") - current_tenant = db.session.query(TenantAccountJoin).filter_by(account_id=account.id, current=True).first() if current_tenant: account.set_tenant_id(current_tenant.tenant_id) @@ -119,15 +117,12 @@ class AccountService: ) if not available_ta: return None - account.set_tenant_id(available_ta.tenant_id) available_ta.current = True db.session.commit() - if datetime.now(UTC).replace(tzinfo=None) - account.last_active_at > timedelta(minutes=10): account.last_active_at = datetime.now(UTC).replace(tzinfo=None) db.session.commit() - return cast(Account, account) @staticmethod @@ -140,21 +135,17 @@ class AccountService: "iss": dify_config.EDITION, "sub": "Console API Passport", } - token: str = PassportService().issue(payload) return token @staticmethod def authenticate(email: str, password: str, invite_token: Optional[str] = None) -> Account: """authenticate account with email and password""" - account = db.session.query(Account).filter_by(email=email).first() if not account: raise AccountNotFoundError() - if account.status == AccountStatus.BANNED.value: raise AccountLoginError("Account is banned.") - if password and invite_token and account.password is None: # if invite_token is valid, set password and password_salt salt = secrets.token_bytes(16) @@ -163,16 +154,12 @@ class AccountService: base64_password_hashed = base64.b64encode(password_hashed).decode() account.password = base64_password_hashed account.password_salt = base64_salt - if account.password is None or not compare_password(password, account.password, account.password_salt): raise AccountPasswordError("Invalid email or password.") - if account.status == AccountStatus.PENDING.value: account.status = AccountStatus.ACTIVE.value account.initialized_at = datetime.now(UTC).replace(tzinfo=None) - db.session.commit() - return cast(Account, account) @staticmethod @@ -180,14 +167,11 @@ class AccountService: """update account password""" if 
account.password and not compare_password(password, account.password, account.password_salt): raise CurrentPasswordIncorrectError("Current password is incorrect.") - # may be raised valid_password(new_password) - # generate password salt salt = secrets.token_bytes(16) base64_salt = base64.b64encode(salt).decode() - # encrypt password with salt password_hashed = hash_password(new_password, salt) base64_password_hashed = base64.b64encode(password_hashed).decode() @@ -210,7 +194,6 @@ class AccountService: from controllers.console.error import AccountNotFound raise AccountNotFound() - if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(email): raise AccountRegisterError( description=( @@ -218,29 +201,22 @@ class AccountService: "30 days and is temporarily unavailable for new account registration" ) ) - account = Account() account.email = email account.name = name - if password: # generate password salt salt = secrets.token_bytes(16) base64_salt = base64.b64encode(salt).decode() - # encrypt password with salt password_hashed = hash_password(password, salt) base64_password_hashed = base64.b64encode(password_hashed).decode() - account.password = base64_password_hashed account.password_salt = base64_salt - account.interface_language = interface_language account.interface_theme = interface_theme - # Set timezone based on language account.timezone = language_timezone_mapping.get(interface_language, "UTC") - db.session.add(account) db.session.commit() return account @@ -253,9 +229,7 @@ class AccountService: account = AccountService.create_account( email=email, name=name, interface_language=interface_language, password=password ) - TenantService.create_owner_tenant_if_not_exist(account=account) - return account @staticmethod @@ -273,9 +247,7 @@ class AccountService: from controllers.console.auth.error import EmailCodeAccountDeletionRateLimitExceededError raise EmailCodeAccountDeletionRateLimitExceededError() - 
send_account_deletion_verification_code.delay(to=email, code=code) - cls.email_code_account_deletion_rate_limiter.increment_rate_limit(email) @staticmethod @@ -283,10 +255,8 @@ class AccountService: token_data = TokenManager.get_token_data(token, "account_deletion") if token_data is None: return False - if token_data["code"] != code: return False - return True @staticmethod @@ -302,7 +272,6 @@ class AccountService: account_integrate: Optional[AccountIntegrate] = ( db.session.query(AccountIntegrate).filter_by(account_id=account.id, provider=provider).first() ) - if account_integrate: # If it exists, update the record account_integrate.open_id = open_id @@ -314,7 +283,6 @@ class AccountService: account_id=account.id, provider=provider, open_id=open_id, encrypted_token="" ) db.session.add(account_integrate) - db.session.commit() logging.info(f"Account {account.id} linked {provider} account {open_id}.") except Exception as e: @@ -335,7 +303,6 @@ class AccountService: setattr(account, field, value) else: raise AttributeError(f"Invalid field: {field}") - db.session.commit() return account @@ -351,16 +318,12 @@ class AccountService: def login(account: Account, *, ip_address: Optional[str] = None) -> TokenPair: if ip_address: AccountService.update_login_info(account=account, ip_address=ip_address) - if account.status == AccountStatus.PENDING.value: account.status = AccountStatus.ACTIVE.value db.session.commit() - access_token = AccountService.get_account_jwt_token(account=account) refresh_token = _generate_refresh_token() - AccountService._store_refresh_token(refresh_token, account.id) - return TokenPair(access_token=access_token, refresh_token=refresh_token) @staticmethod @@ -375,18 +338,14 @@ class AccountService: account_id = redis_client.get(AccountService._get_refresh_token_key(refresh_token)) if not account_id: raise ValueError("Invalid refresh token") - account = AccountService.load_user(account_id.decode("utf-8")) if not account: raise ValueError("Invalid account") 
- # Generate new access token and refresh token new_access_token = AccountService.get_account_jwt_token(account) new_refresh_token = _generate_refresh_token() - AccountService._delete_refresh_token(refresh_token, account.id) AccountService._store_refresh_token(new_refresh_token, account.id) - return TokenPair(access_token=new_access_token, refresh_token=new_refresh_token) @staticmethod @@ -403,14 +362,11 @@ class AccountService: account_email = account.email if account else email if account_email is None: raise ValueError("Email must be provided.") - if cls.reset_password_rate_limiter.is_rate_limited(account_email): from controllers.console.auth.error import PasswordResetRateLimitExceededError raise PasswordResetRateLimitExceededError() - code, token = cls.generate_reset_password_token(account_email, account) - send_reset_password_mail_task.delay( language=language, to=account_email, @@ -454,7 +410,6 @@ class AccountService: from controllers.console.auth.error import EmailCodeLoginRateLimitExceededError raise EmailCodeLoginRateLimitExceededError() - code = "".join([str(secrets.randbelow(exclusive_upper_bound=10)) for _ in range(6)]) token = TokenManager.generate_token( account=account, email=email, token_type="email_code_login", additional_data={"code": code} @@ -484,14 +439,11 @@ class AccountService: "30 days and is temporarily unavailable for new account registration" ) ) - account = db.session.query(Account).filter(Account.email == email).first() if not account: return None - if account.status == AccountStatus.BANNED.value: raise Unauthorized("Account is banned.") - return account @staticmethod @@ -509,7 +461,6 @@ class AccountService: count = redis_client.get(key) if count is None: return False - count = int(count) if count > AccountService.LOGIN_MAX_ERROR_LIMITS: return True @@ -535,7 +486,6 @@ class AccountService: count = redis_client.get(key) if count is None: return False - count = int(count) if count > AccountService.FORGOT_PASSWORD_MAX_ERROR_LIMITS: 
return True @@ -551,39 +501,31 @@ class AccountService: minute_key = f"email_send_ip_limit_minute:{ip_address}" freeze_key = f"email_send_ip_limit_freeze:{ip_address}" hour_limit_key = f"email_send_ip_limit_hour:{ip_address}" - # check ip is frozen if redis_client.get(freeze_key): return True - # check current minute count current_minute_count = redis_client.get(minute_key) if current_minute_count is None: current_minute_count = 0 current_minute_count = int(current_minute_count) - # check current hour count if current_minute_count > dify_config.EMAIL_SEND_IP_LIMIT_PER_MINUTE: hour_limit_count = redis_client.get(hour_limit_key) if hour_limit_count is None: hour_limit_count = 0 hour_limit_count = int(hour_limit_count) - if hour_limit_count >= 1: redis_client.setex(freeze_key, 60 * 60, 1) return True else: redis_client.setex(hour_limit_key, 60 * 10, hour_limit_count + 1) # first time limit 10 minutes - # add hour limit count redis_client.incr(hour_limit_key) redis_client.expire(hour_limit_key, 60 * 60) - return True - redis_client.setex(minute_key, 60, current_minute_count + 1) redis_client.expire(minute_key, 60) - return False @@ -600,10 +542,8 @@ class TenantService: raise NotAllowedCreateWorkspace() tenant = Tenant(name=name) - db.session.add(tenant) db.session.commit() - tenant.encrypt_public_key = generate_key_pair(tenant.id) db.session.commit() return tenant @@ -619,18 +559,14 @@ class TenantService: .order_by(TenantAccountJoin.id.asc()) .first() ) - if available_ta: return - """Create owner tenant if not exist""" if not FeatureService.get_system_features().is_allow_create_workspace and not is_setup: raise WorkSpaceNotAllowedCreateError() - workspaces = FeatureService.get_system_features().license.workspaces if not workspaces.is_available(): raise WorkspacesLimitExceededError() - if name: tenant = TenantService.create_tenant(name=name, is_setup=is_setup) else: @@ -647,14 +583,12 @@ class TenantService: if TenantService.has_roles(tenant, 
[TenantAccountRole.OWNER]): logging.error(f"Tenant {tenant.id} has already an owner.") raise Exception("Tenant already has an owner.") - ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first() if ta: ta.role = role else: ta = TenantAccountJoin(tenant_id=tenant.id, account_id=account.id, role=role) db.session.add(ta) - db.session.commit() return ta @@ -674,7 +608,6 @@ class TenantService: tenant = account.current_tenant if not tenant: raise TenantNotFoundError("Tenant not found.") - ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first() if ta: tenant.role = ta.role @@ -685,11 +618,9 @@ class TenantService: @staticmethod def switch_tenant(account: Account, tenant_id: Optional[str] = None) -> None: """Switch the current workspace for the account""" - # Ensure tenant_id is provided if tenant_id is None: raise ValueError("Tenant ID must be provided.") - tenant_account_join = ( db.session.query(TenantAccountJoin) .join(Tenant, TenantAccountJoin.tenant_id == Tenant.id) @@ -700,7 +631,6 @@ class TenantService: ) .first() ) - if not tenant_account_join: raise AccountNotLinkTenantError("Tenant not found or account is not a member of the tenant.") else: @@ -721,14 +651,11 @@ class TenantService: .join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id) .filter(TenantAccountJoin.tenant_id == tenant.id) ) - # Initialize an empty list to store the updated accounts updated_accounts = [] - for account, role in query: account.role = role updated_accounts.append(account) - return updated_accounts @staticmethod @@ -741,14 +668,11 @@ class TenantService: .filter(TenantAccountJoin.tenant_id == tenant.id) .filter(TenantAccountJoin.role == "dataset_operator") ) - # Initialize an empty list to store the updated accounts updated_accounts = [] - for account, role in query: account.role = role updated_accounts.append(account) - return updated_accounts @staticmethod @@ -756,7 +680,6 @@ 
class TenantService: """Check if user has any of the given roles for a tenant""" if not all(isinstance(role, TenantAccountRole) for role in roles): raise ValueError("all roles must be TenantAccountRole") - return ( db.session.query(TenantAccountJoin) .filter( @@ -791,13 +714,10 @@ class TenantService: } if action not in {"add", "remove", "update"}: raise InvalidActionError("Invalid action.") - if member: if operator.id == member.id: raise CannotOperateSelfError("Cannot operate self.") - ta_operator = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=operator.id).first() - if not ta_operator or ta_operator.role not in perms[action]: raise NoPermissionError(f"No permission to {action} member.") @@ -806,13 +726,10 @@ class TenantService: """Remove member from tenant""" if operator.id == account.id: raise CannotOperateSelfError("Cannot operate self.") - TenantService.check_member_permission(tenant, operator, account, "remove") - ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first() if not ta: raise MemberNotInTenantError("Member not in tenant.") - db.session.delete(ta) db.session.commit() @@ -820,17 +737,13 @@ class TenantService: def update_member_role(tenant: Tenant, member: Account, new_role: str, operator: Account) -> None: """Update member role""" TenantService.check_member_permission(tenant, operator, member, "update") - target_member_join = ( db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=member.id).first() ) - if not target_member_join: raise MemberNotInTenantError("Member not in tenant.") - if target_member_join.role == new_role: raise RoleAlreadyAssignedError("The provided role is already assigned to the member.") - if new_role == "owner": # Find the current owner and change their role to 'admin' current_owner_join = ( @@ -838,7 +751,6 @@ class TenantService: ) if current_owner_join: current_owner_join.role = "admin" - # Update the role of the target 
member target_member_join.role = new_role db.session.commit() @@ -855,7 +767,6 @@ class TenantService: @staticmethod def get_custom_config(tenant_id: str) -> dict: tenant = db.get_or_404(Tenant, tenant_id) - return cast(dict, tenant.custom_config_dict) @@ -868,7 +779,6 @@ class RegisterService: def setup(cls, email: str, name: str, password: str, ip_address: str) -> None: """ Setup dify - :param email: email :param name: username :param password: password @@ -883,12 +793,9 @@ class RegisterService: password=password, is_setup=True, ) - account.last_login_ip = ip_address account.initialized_at = datetime.now(UTC).replace(tzinfo=None) - TenantService.create_owner_tenant_if_not_exist(account=account, is_setup=True) - dify_setup = DifySetup(version=dify_config.project.version) db.session.add(dify_setup) db.session.commit() @@ -898,7 +805,6 @@ class RegisterService: db.session.query(Account).delete() db.session.query(Tenant).delete() db.session.commit() - logging.exception(f"Setup account failed, email: {email}, name: {name}") raise ValueError(f"Setup failed: {e}") @@ -927,10 +833,8 @@ class RegisterService: ) account.status = AccountStatus.ACTIVE.value if not status else status.value account.initialized_at = datetime.now(UTC).replace(tzinfo=None) - if open_id is not None and provider is not None: AccountService.link_account_integrate(provider, open_id, account) - if ( FeatureService.get_system_features().is_allow_create_workspace and create_workspace_required @@ -940,7 +844,6 @@ class RegisterService: TenantService.create_tenant_member(tenant, account, role="owner") account.current_tenant = tenant tenant_was_created.send(tenant) - db.session.commit() except WorkSpaceNotAllowedCreateError: db.session.rollback() @@ -954,7 +857,6 @@ class RegisterService: db.session.rollback() logging.exception("Register failed") raise AccountRegisterError(f"Registration failed: {e}") from e - return account @classmethod @@ -963,15 +865,12 @@ class RegisterService: ) -> str: if not 
inviter: raise ValueError("Inviter is required") - """Invite new member""" with Session(db.engine) as session: account = session.query(Account).filter_by(email=email).first() - if not account: TenantService.check_member_permission(tenant, inviter, None, "add") name = email.split("@")[0] - account = cls.register( email=email, name=name, language=language, status=AccountStatus.PENDING, is_setup=True ) @@ -981,16 +880,12 @@ class RegisterService: else: TenantService.check_member_permission(tenant, inviter, account, "add") ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first() - if not ta: TenantService.create_tenant_member(tenant, account, role) - # Support resend invitation email when the account is pending status if account.status != AccountStatus.PENDING.value: raise AccountAlreadyInTenantError("Account already in tenant.") - token = cls.generate_invite_token(tenant, account) - # send email send_invite_member_mail_task.delay( language=account.interface_language, @@ -999,7 +894,6 @@ class RegisterService: inviter_name=inviter.name if inviter else "Dify", workspace_name=tenant.name, ) - return token @classmethod @@ -1035,33 +929,26 @@ class RegisterService: invitation_data = cls._get_invitation_by_token(token, workspace_id, email) if not invitation_data: return None - tenant = ( db.session.query(Tenant) .filter(Tenant.id == invitation_data["workspace_id"], Tenant.status == "normal") .first() ) - if not tenant: return None - tenant_account = ( db.session.query(Account, TenantAccountJoin.role) .join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id) .filter(Account.email == invitation_data["email"], TenantAccountJoin.tenant_id == tenant.id) .first() ) - if not tenant_account: return None - account = tenant_account[0] if not account: return None - if invitation_data["account_id"] != str(account.id): return None - return { "account": account, "data": invitation_data, @@ -1076,10 +963,8 @@ class RegisterService: 
email_hash = sha256(email.encode()).hexdigest() cache_key = f"member_invite_token:{workspace_id}, {email_hash}:{token}" account_id = redis_client.get(cache_key) - if not account_id: return None - return { "account_id": account_id.decode("utf-8"), "email": email, @@ -1089,7 +974,6 @@ class RegisterService: data = redis_client.get(cls._get_invitation_token_key(token)) if not data: return None - invitation: dict = json.loads(data) return invitation diff --git a/api/services/advanced_prompt_template_service.py b/api/services/advanced_prompt_template_service.py index 6dc1affa11..5fd690eb04 100644 --- a/api/services/advanced_prompt_template_service.py +++ b/api/services/advanced_prompt_template_service.py @@ -22,7 +22,6 @@ class AdvancedPromptTemplateService: model_mode = args["model_mode"] model_name = args["model_name"] has_context = args["has_context"] - if "baichuan" in model_name.lower(): return cls.get_baichuan_prompt(app_mode, model_mode, has_context) else: @@ -31,7 +30,6 @@ class AdvancedPromptTemplateService: @classmethod def get_common_prompt(cls, app_mode: str, model_mode: str, has_context: str) -> dict: context_prompt = copy.deepcopy(CONTEXT) - if app_mode == AppMode.CHAT.value: if model_mode == "completion": return cls.get_completion_prompt( @@ -57,7 +55,6 @@ class AdvancedPromptTemplateService: prompt_template["completion_prompt_config"]["prompt"]["text"] = ( context + prompt_template["completion_prompt_config"]["prompt"]["text"] ) - return prompt_template @classmethod @@ -66,13 +63,11 @@ class AdvancedPromptTemplateService: prompt_template["chat_prompt_config"]["prompt"][0]["text"] = ( context + prompt_template["chat_prompt_config"]["prompt"][0]["text"] ) - return prompt_template @classmethod def get_baichuan_prompt(cls, app_mode: str, model_mode: str, has_context: str) -> dict: baichuan_context_prompt = copy.deepcopy(BAICHUAN_CONTEXT) - if app_mode == AppMode.CHAT.value: if model_mode == "completion": return cls.get_completion_prompt( diff --git 
a/api/services/agent_service.py b/api/services/agent_service.py index 503b31ede2..42af919c85 100644 --- a/api/services/agent_service.py +++ b/api/services/agent_service.py @@ -22,7 +22,6 @@ class AgentService: """ contexts.plugin_tool_providers.set({}) contexts.plugin_tool_providers_lock.set(threading.Lock()) - conversation: Conversation | None = ( db.session.query(Conversation) .filter( @@ -31,10 +30,8 @@ class AgentService: ) .first() ) - if not conversation: raise ValueError(f"Conversation not found: {conversation_id}") - message: Optional[Message] = ( db.session.query(Message) .filter( @@ -43,12 +40,9 @@ class AgentService: ) .first() ) - if not message: raise ValueError(f"Message not found: {message_id}") - agent_thoughts: list[MessageAgentThought] = message.agent_thoughts - if conversation.from_end_user_id: # only select name field executor = ( @@ -58,18 +52,14 @@ class AgentService: executor = ( db.session.query(Account, Account.name).filter(Account.id == conversation.from_account_id).first() ) - if executor: executor = executor.name else: executor = "Unknown" - timezone = pytz.timezone(current_user.timezone) - app_model_config = app_model.app_model_config if not app_model_config: raise ValueError("App model config not found") - result = { "meta": { "status": "success", @@ -83,11 +73,9 @@ class AgentService: "iterations": [], "files": message.message_files, } - agent_config = AgentConfigManager.convert(app_model_config.to_dict()) if not agent_config: raise ValueError("Agent config not found") - agent_tools = agent_config.tools or [] def find_agent_tool(tool_name: str): @@ -125,7 +113,6 @@ class AgentService: ) else: tool_icon = "" - tool_calls.append( { "status": "success" if not tool_meta_data.get("error") else "error", @@ -139,7 +126,6 @@ class AgentService: "tool_icon": tool_icon, } ) - result["iterations"].append( { "tokens": agent_thought.tokens, @@ -153,7 +139,6 @@ class AgentService: "files": agent_thought.files, } ) - return result @classmethod diff 
--git a/api/services/annotation_service.py b/api/services/annotation_service.py index 8c950abc24..699fc646a4 100644 --- a/api/services/annotation_service.py +++ b/api/services/annotation_service.py @@ -29,17 +29,14 @@ class AppAnnotationService: .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") .first() ) - if not app: raise NotFound("App not found") if args.get("message_id"): message_id = str(args["message_id"]) # get message info message = db.session.query(Message).filter(Message.id == message_id, Message.app_id == app.id).first() - if not message: raise NotFound("Message Not Exists.") - annotation = message.annotation # save the message annotation if annotation: @@ -80,7 +77,6 @@ class AppAnnotationService: cache_result = redis_client.get(enable_app_annotation_key) if cache_result is not None: return {"job_id": cache_result, "job_status": "processing"} - # async job job_id = str(uuid.uuid4()) enable_app_annotation_job_key = "enable_app_annotation_job_{}".format(str(job_id)) @@ -103,7 +99,6 @@ class AppAnnotationService: cache_result = redis_client.get(disable_app_annotation_key) if cache_result is not None: return {"job_id": cache_result, "job_status": "processing"} - # async job job_id = str(uuid.uuid4()) disable_app_annotation_job_key = "disable_app_annotation_job_{}".format(str(job_id)) @@ -120,7 +115,6 @@ class AppAnnotationService: .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") .first() ) - if not app: raise NotFound("App not found") if keyword: @@ -152,7 +146,6 @@ class AppAnnotationService: .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") .first() ) - if not app: raise NotFound("App not found") annotations = ( @@ -171,10 +164,8 @@ class AppAnnotationService: .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") .first() ) - if not app: raise NotFound("App not 
found") - annotation = MessageAnnotation( app_id=app.id, content=args["answer"], question=args["question"], account_id=current_user.id ) @@ -202,24 +193,18 @@ class AppAnnotationService: .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") .first() ) - if not app: raise NotFound("App not found") - annotation = db.session.query(MessageAnnotation).filter(MessageAnnotation.id == annotation_id).first() - if not annotation: raise NotFound("Annotation not found") - annotation.content = args["answer"] annotation.question = args["question"] - db.session.commit() # if annotation reply is enabled , add annotation to index app_annotation_setting = ( db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first() ) - if app_annotation_setting: update_annotation_to_index_task.delay( annotation.id, @@ -228,7 +213,6 @@ class AppAnnotationService: app_id, app_annotation_setting.collection_binding_id, ) - return annotation @classmethod @@ -239,17 +223,12 @@ class AppAnnotationService: .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") .first() ) - if not app: raise NotFound("App not found") - annotation = db.session.query(MessageAnnotation).filter(MessageAnnotation.id == annotation_id).first() - if not annotation: raise NotFound("Annotation not found") - db.session.delete(annotation) - annotation_hit_histories = ( db.session.query(AppAnnotationHitHistory) .filter(AppAnnotationHitHistory.annotation_id == annotation_id) @@ -258,13 +237,11 @@ class AppAnnotationService: if annotation_hit_histories: for annotation_hit_history in annotation_hit_histories: db.session.delete(annotation_hit_history) - db.session.commit() # if annotation reply is enabled , delete annotation index app_annotation_setting = ( db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first() ) - if app_annotation_setting: delete_annotation_index_task.delay( 
annotation.id, app_id, current_user.current_tenant_id, app_annotation_setting.collection_binding_id @@ -278,10 +255,8 @@ class AppAnnotationService: .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") .first() ) - if not app: raise NotFound("App not found") - try: # Skip the first row df = pd.read_csv(file) @@ -317,15 +292,11 @@ class AppAnnotationService: .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") .first() ) - if not app: raise NotFound("App not found") - annotation = db.session.query(MessageAnnotation).filter(MessageAnnotation.id == annotation_id).first() - if not annotation: raise NotFound("Annotation not found") - stmt = ( select(AppAnnotationHitHistory) .filter( @@ -342,7 +313,6 @@ class AppAnnotationService: @classmethod def get_annotation_by_id(cls, annotation_id: str) -> MessageAnnotation | None: annotation = db.session.query(MessageAnnotation).filter(MessageAnnotation.id == annotation_id).first() - if not annotation: return None return annotation @@ -364,7 +334,6 @@ class AppAnnotationService: db.session.query(MessageAnnotation).filter(MessageAnnotation.id == annotation_id).update( {MessageAnnotation.hit_count: MessageAnnotation.hit_count + 1}, synchronize_session=False ) - annotation_hit_history = AppAnnotationHitHistory( annotation_id=annotation_id, app_id=app_id, @@ -387,10 +356,8 @@ class AppAnnotationService: .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") .first() ) - if not app: raise NotFound("App not found") - annotation_setting = ( db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first() ) @@ -415,10 +382,8 @@ class AppAnnotationService: .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") .first() ) - if not app: raise NotFound("App not found") - annotation_setting = ( db.session.query(AppAnnotationSetting) 
.filter( @@ -434,9 +399,7 @@ class AppAnnotationService: annotation_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) db.session.add(annotation_setting) db.session.commit() - collection_binding_detail = annotation_setting.collection_binding_detail - return { "id": annotation_setting.id, "enabled": True, diff --git a/api/services/api_based_extension_service.py b/api/services/api_based_extension_service.py index 601d67d2fb..d1380d3b95 100644 --- a/api/services/api_based_extension_service.py +++ b/api/services/api_based_extension_service.py @@ -13,18 +13,14 @@ class APIBasedExtensionService: .order_by(APIBasedExtension.created_at.desc()) .all() ) - for extension in extension_list: extension.api_key = decrypt_token(extension.tenant_id, extension.api_key) - return extension_list @classmethod def save(cls, extension_data: APIBasedExtension) -> APIBasedExtension: cls._validation(extension_data) - extension_data.api_key = encrypt_token(extension_data.tenant_id, extension_data.api_key) - db.session.add(extension_data) db.session.commit() return extension_data @@ -42,12 +38,9 @@ class APIBasedExtensionService: .filter_by(id=api_based_extension_id) .first() ) - if not extension: raise ValueError("API based extension is not found") - extension.api_key = decrypt_token(extension.tenant_id, extension.api_key) - return extension @classmethod @@ -55,7 +48,6 @@ class APIBasedExtensionService: # name if not extension_data.name: raise ValueError("name must not be empty") - if not extension_data.id: # case one: check new data, name must be unique is_name_existed = ( @@ -64,7 +56,6 @@ class APIBasedExtensionService: .filter_by(name=extension_data.name) .first() ) - if is_name_existed: raise ValueError("name must be unique, it is already existed") else: @@ -76,21 +67,16 @@ class APIBasedExtensionService: .filter(APIBasedExtension.id != extension_data.id) .first() ) - if is_name_existed: raise ValueError("name must be unique, it is already existed") - # 
api_endpoint if not extension_data.api_endpoint: raise ValueError("api_endpoint must not be empty") - # api_key if not extension_data.api_key: raise ValueError("api_key must not be empty") - if len(extension_data.api_key) < 5: raise ValueError("api_key must be at least 5 characters") - # check endpoint cls._ping_connection(extension_data) diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index 20257fa345..b5f22110a6 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -36,7 +36,6 @@ from services.workflow_draft_variable_service import WorkflowDraftVariableServic from services.workflow_service import WorkflowService logger = logging.getLogger(__name__) - IMPORT_INFO_REDIS_KEY_PREFIX = "app_import_info:" CHECK_DEPENDENCIES_REDIS_KEY_PREFIX = "app_check_dependencies:" IMPORT_INFO_REDIS_EXPIRY = 10 * 60 # 10 minutes @@ -77,19 +76,15 @@ def _check_version_compatibility(imported_version: str) -> ImportStatus: imported_ver = version.parse(imported_version) except version.InvalidVersion: return ImportStatus.FAILED - # If imported version is newer than current, always return PENDING if imported_ver > current_ver: return ImportStatus.PENDING - # If imported version is older than current's major, return PENDING if imported_ver.major < current_ver.major: return ImportStatus.PENDING - # If imported version is older than current's minor, return COMPLETED_WITH_WARNINGS if imported_ver.minor < current_ver.minor: return ImportStatus.COMPLETED_WITH_WARNINGS - # If imported version equals or is older than current's micro, return COMPLETED return ImportStatus.COMPLETED @@ -130,13 +125,11 @@ class AppDslService: ) -> Import: """Import an app from YAML content or URL.""" import_id = str(uuid.uuid4()) - # Validate import mode try: mode = ImportMode(import_mode) except ValueError: raise ValueError(f"Invalid import_mode: {import_mode}") - # Get YAML content content: str = "" if mode == ImportMode.YAML_URL: @@ -158,14 +151,12 @@ 
class AppDslService: response = ssrf_proxy.get(yaml_url.strip(), follow_redirects=True, timeout=(10, 10)) response.raise_for_status() content = response.content.decode() - if len(content) > DSL_MAX_SIZE: return Import( id=import_id, status=ImportStatus.FAILED, error="File size exceeds the limit of 10MB", ) - if not content: return Import( id=import_id, @@ -186,7 +177,6 @@ class AppDslService: error="yaml_content is required when import_mode is yaml-content", ) content = yaml_content - # Process YAML content try: # Parse YAML to validate format @@ -197,19 +187,16 @@ class AppDslService: status=ImportStatus.FAILED, error="Invalid YAML format: content must be a mapping", ) - # Validate and fix DSL version if not data.get("version"): data["version"] = "0.1.0" if not data.get("kind") or data.get("kind") != "app": data["kind"] = "app" - imported_version = data.get("version", "0.1.0") # check if imported_version is a float-like string if not isinstance(imported_version, str): raise ValueError(f"Invalid version type, expected str, got {type(imported_version)}") status = _check_version_compatibility(imported_version) - # Extract app data app_data = data.get("app") if not app_data: @@ -218,27 +205,23 @@ class AppDslService: status=ImportStatus.FAILED, error="Missing app data in YAML content", ) - # If app_id is provided, check if it exists app = None if app_id: stmt = select(App).where(App.id == app_id, App.tenant_id == account.current_tenant_id) app = self._session.scalar(stmt) - if not app: return Import( id=import_id, status=ImportStatus.FAILED, error="App not found", ) - if app.mode not in [AppMode.WORKFLOW, AppMode.ADVANCED_CHAT]: return Import( id=import_id, status=ImportStatus.FAILED, error="Only workflow or advanced chat apps can be overwritten", ) - # If major version mismatch, store import info in Redis if status == ImportStatus.PENDING: pending_data = PendingData( @@ -256,14 +239,12 @@ class AppDslService: IMPORT_INFO_REDIS_EXPIRY, pending_data.model_dump_json(), 
) - return Import( id=import_id, status=status, app_id=app_id, imported_dsl_version=imported_version, ) - # Extract dependencies dependencies = data.get("dependencies", []) check_dependencies_pending_data = None @@ -275,11 +256,9 @@ class AppDslService: dependencies_list = self._extract_dependencies_from_workflow_graph(graph) else: dependencies_list = self._extract_dependencies_from_model_config(data.get("model_config", {})) - check_dependencies_pending_data = DependenciesAnalysisService.generate_latest_dependencies( dependencies_list ) - # Create or update app app = self._create_or_update_app( app=app, @@ -292,7 +271,6 @@ class AppDslService: icon_background=icon_background, dependencies=check_dependencies_pending_data, ) - draft_var_srv = WorkflowDraftVariableService(session=self._session) draft_var_srv.delete_workflow_variables(app_id=app.id) return Import( @@ -302,14 +280,12 @@ class AppDslService: app_mode=app.mode, imported_dsl_version=imported_version, ) - except yaml.YAMLError as e: return Import( id=import_id, status=ImportStatus.FAILED, error=f"Invalid YAML format: {str(e)}", ) - except Exception as e: logger.exception("Failed to import app") return Import( @@ -324,14 +300,12 @@ class AppDslService: """ redis_key = f"{IMPORT_INFO_REDIS_KEY_PREFIX}{import_id}" pending_data = redis_client.get(redis_key) - if not pending_data: return Import( id=import_id, status=ImportStatus.FAILED, error="Import information expired or does not exist", ) - try: if not isinstance(pending_data, str | bytes): return Import( @@ -341,12 +315,10 @@ class AppDslService: ) pending_data = PendingData.model_validate_json(pending_data) data = yaml.safe_load(pending_data.yaml_content) - app = None if pending_data.app_id: stmt = select(App).where(App.id == pending_data.app_id, App.tenant_id == account.current_tenant_id) app = self._session.scalar(stmt) - # Create or update app app = self._create_or_update_app( app=app, @@ -358,10 +330,8 @@ class AppDslService: icon=pending_data.icon, 
icon_background=pending_data.icon_background, ) - # Delete import info from Redis redis_client.delete(redis_key) - return Import( id=import_id, status=ImportStatus.COMPLETED, @@ -370,7 +340,6 @@ class AppDslService: current_dsl_version=CURRENT_DSL_VERSION, imported_dsl_version=data.get("version", "0.1.0"), ) - except Exception as e: logger.exception("Error confirming import") return Import( @@ -390,10 +359,8 @@ class AppDslService: dependencies = redis_client.get(redis_key) if not dependencies: return CheckDependenciesResult() - # Extract dependencies dependencies = CheckDependenciesPendingData.model_validate_json(dependencies) - # Get leaked dependencies leaked_dependencies = DependenciesAnalysisService.get_leaked_dependencies( tenant_id=app_model.tenant_id, dependencies=dependencies.dependencies @@ -421,7 +388,6 @@ class AppDslService: if not app_mode: raise ValueError("loss app mode") app_mode = AppMode(app_mode) - # Set icon type icon_type_value = icon_type or app_data.get("icon_type") if icon_type_value in ["emoji", "link", "image"]: @@ -429,7 +395,6 @@ class AppDslService: else: icon_type = "emoji" icon = icon or str(app_data.get("icon", "")) - if app: # Update existing app app.name = name or app_data.get("name", app.name) @@ -441,7 +406,6 @@ class AppDslService: else: if account.current_tenant_id is None: raise ValueError("Current tenant is not set") - # Create new app app = App() app.id = str(uuid4()) @@ -457,11 +421,9 @@ class AppDslService: app.use_icon_as_answer_icon = app_data.get("use_icon_as_answer_icon", False) app.created_by = account.id app.updated_by = account.id - self._session.add(app) self._session.commit() app_was_created.send(app, account=account) - # save dependencies if dependencies: redis_client.setex( @@ -469,13 +431,11 @@ class AppDslService: IMPORT_INFO_REDIS_EXPIRY, CheckDependenciesPendingData(app_id=app.id, dependencies=dependencies).model_dump_json(), ) - # Initialize app based on mode if app_mode in {AppMode.ADVANCED_CHAT, 
AppMode.WORKFLOW}: workflow_data = data.get("workflow") if not workflow_data or not isinstance(workflow_data, dict): raise ValueError("Missing workflow data for workflow/advanced chat app") - environment_variables_list = workflow_data.get("environment_variables", []) environment_variables = [ variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list @@ -484,7 +444,6 @@ class AppDslService: conversation_variables = [ variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list ] - workflow_service = WorkflowService() current_draft_workflow = workflow_service.get_draft_workflow(app_model=app) if current_draft_workflow: @@ -521,9 +480,7 @@ class AppDslService: app_model_config.app_id = app.id app_model_config.created_by = account.id app_model_config.updated_by = account.id - app.app_model_config_id = app_model_config.id - self._session.add(app_model_config) app_model_config_was_updated.send(app, app_model_config=app_model_config) else: @@ -539,7 +496,6 @@ class AppDslService: :return: """ app_mode = AppMode.value_of(app_model.mode) - export_data = { "version": CURRENT_DSL_VERSION, "kind": "app", @@ -552,14 +508,12 @@ class AppDslService: "use_icon_as_answer_icon": app_model.use_icon_as_answer_icon, }, } - if app_mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: cls._append_workflow_export_data( export_data=export_data, app_model=app_model, include_secret=include_secret ) else: cls._append_model_config_export_data(export_data, app_model) - return yaml.dump(export_data, allow_unicode=True) # type: ignore @classmethod @@ -573,7 +527,6 @@ class AppDslService: workflow = workflow_service.get_draft_workflow(app_model) if not workflow: raise ValueError("Missing draft workflow configuration, please check.") - workflow_dict = workflow.to_dict(include_secret=include_secret) for node in workflow_dict.get("graph", {}).get("nodes", []): if node.get("data", {}).get("type", "") == 
NodeType.KNOWLEDGE_RETRIEVAL.value: @@ -601,7 +554,6 @@ class AppDslService: app_model_config = app_model.app_model_config if not app_model_config: raise ValueError("Missing app configuration, please check.") - export_data["model_config"] = app_model_config.to_dict() dependencies = cls._extract_dependencies_from_model_config(app_model_config.to_dict()) export_data["dependencies"] = [ @@ -698,7 +650,6 @@ class AppDslService: pass except Exception as e: logger.exception("Error extracting node dependency", exc_info=e) - return dependencies @classmethod @@ -709,7 +660,6 @@ class AppDslService: :return: dependencies list format like ["langgenius/google"] """ dependencies = [] - try: # completion model model_dict = model_config.get("model", {}) @@ -717,7 +667,6 @@ class AppDslService: dependencies.append( DependenciesAnalysisService.analyze_model_provider_dependency(model_dict.get("provider", "")) ) - # reranking model dataset_configs = model_config.get("dataset_configs", {}) if dataset_configs: @@ -730,7 +679,6 @@ class AppDslService: .get("provider") ) ) - # tools agent_configs = model_config.get("agent_mode", {}) if agent_configs: @@ -738,10 +686,8 @@ class AppDslService: dependencies.append( DependenciesAnalysisService.analyze_tool_dependency(agent_config.get("provider_id")) ) - except Exception as e: logger.exception("Error extracting model config dependency", exc_info=e) - return dependencies @classmethod @@ -752,7 +698,6 @@ class AppDslService: dependencies = [PluginDependency(**dep) for dep in dsl_dependencies] if not dependencies: return [] - return DependenciesAnalysisService.get_leaked_dependencies(tenant_id=tenant_id, dependencies=dependencies) @staticmethod diff --git a/api/services/app_generate_service.py b/api/services/app_generate_service.py index 245c123a04..3d60a20302 100644 --- a/api/services/app_generate_service.py +++ b/api/services/app_generate_service.py @@ -51,7 +51,6 @@ class AppGenerateService: f"or your RPD was 
{dify_config.APP_DAILY_RATE_LIMIT} requests/day" ) cls.system_rate_limiter.increment_rate_limit(app_model.tenant_id) - # app level rate limiter max_active_request = AppGenerateService._get_max_active_requests(app_model) rate_limit = RateLimit(app_model.id, max_active_request) @@ -207,14 +206,11 @@ class AppGenerateService: if invoke_from == InvokeFrom.DEBUGGER: # fetch draft workflow by app_model workflow = workflow_service.get_draft_workflow(app_model=app_model) - if not workflow: raise ValueError("Workflow not initialized") else: # fetch published workflow by app_model workflow = workflow_service.get_published_workflow(app_model=app_model) - if not workflow: raise ValueError("Workflow not published") - return workflow diff --git a/api/services/app_service.py b/api/services/app_service.py index d08462d001..58a7febf10 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -36,7 +36,6 @@ class AppService: :return: """ filters = [App.tenant_id == tenant_id, App.is_universal == False] - if args["mode"] == "workflow": filters.append(App.mode == AppMode.WORKFLOW.value) elif args["mode"] == "completion": @@ -49,7 +48,6 @@ class AppService: filters.append(App.mode == AppMode.AGENT_CHAT.value) elif args["mode"] == "channel": filters.append(App.mode == AppMode.CHANNEL.value) - if args.get("is_created_by_me", False): filters.append(App.created_by == user_id) if args.get("name"): @@ -61,14 +59,12 @@ class AppService: filters.append(App.id.in_(target_ids)) else: return None - app_models = db.paginate( db.select(App).where(*filters).order_by(App.created_at.desc()), page=args["page"], per_page=args["limit"], error_out=False, ) - return app_models def create_app(self, tenant_id: str, args: dict, account: Account) -> App: @@ -80,14 +76,12 @@ class AppService: """ app_mode = AppMode.value_of(args["mode"]) app_template = default_app_templates[app_mode] - # get model config default_model_config = app_template.get("model_config") default_model_config = 
default_model_config.copy() if default_model_config else None if default_model_config and "model" in default_model_config: # get model provider model_manager = ModelManager() - # get default model instance try: model_instance = model_manager.get_default_model_instance( @@ -98,7 +92,6 @@ class AppService: except Exception as e: logging.exception(f"Get default model instance failed, tenant_id: {tenant_id}") model_instance = None - if model_instance: if ( model_instance.model == default_model_config["model"]["name"] @@ -110,7 +103,6 @@ class AppService: model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials) if model_schema is None: raise ValueError(f"model schema not found for model {model_instance.model}") - default_model_dict = { "provider": model_instance.provider, "name": model_instance.model, @@ -124,9 +116,7 @@ class AppService: default_model_config["model"]["provider"] = provider default_model_config["model"]["name"] = model default_model_dict = default_model_config["model"] - default_model_config["model"] = json.dumps(default_model_dict) - app = App(**app_template["app"]) app.name = args["name"] app.description = args.get("description", "") @@ -139,10 +129,8 @@ class AppService: app.api_rpm = args.get("api_rpm", 0) app.created_by = account.id app.updated_by = account.id - db.session.add(app) db.session.flush() - if default_model_config: app_model_config = AppModelConfig(**default_model_config) app_model_config.app_id = app.id @@ -150,17 +138,12 @@ class AppService: app_model_config.updated_by = account.id db.session.add(app_model_config) db.session.flush() - app.app_model_config_id = app_model_config.id - db.session.commit() - app_was_created.send(app, account=account) - if FeatureService.get_system_features().webapp_auth.enabled: # update web app setting as private EnterpriseService.WebAppAuth.update_app_access_mode(app.id, "private") - return app def get_app(self, app: App) -> App: @@ -190,19 +173,16 @@ class AppService: 
provider_type=agent_tool_entity.provider_type, identity_id=f"AGENT.{app.id}", ) - # get decrypted parameters if agent_tool_entity.tool_parameters: parameters = manager.decrypt_tool_parameters(agent_tool_entity.tool_parameters or {}) masked_parameter = manager.mask_tool_parameters(parameters or {}) else: masked_parameter = {} - # override tool parameters tool["tool_parameters"] = masked_parameter except Exception as e: pass - # override agent mode model_config.agent_mode = json.dumps(agent_mode) @@ -219,7 +199,6 @@ class AppService: return model_config app = ModifiedApp(app) - return app def update_app(self, app: App, args: dict) -> App: @@ -238,7 +217,6 @@ class AppService: app.updated_by = current_user.id app.updated_at = datetime.now(UTC).replace(tzinfo=None) db.session.commit() - return app def update_app_name(self, app: App, name: str) -> App: @@ -252,7 +230,6 @@ class AppService: app.updated_by = current_user.id app.updated_at = datetime.now(UTC).replace(tzinfo=None) db.session.commit() - return app def update_app_icon(self, app: App, icon: str, icon_background: str) -> App: @@ -268,7 +245,6 @@ class AppService: app.updated_by = current_user.id app.updated_at = datetime.now(UTC).replace(tzinfo=None) db.session.commit() - return app def update_app_site_status(self, app: App, enable_site: bool) -> App: @@ -280,12 +256,10 @@ class AppService: """ if enable_site == app.enable_site: return app - app.enable_site = enable_site app.updated_by = current_user.id app.updated_at = datetime.now(UTC).replace(tzinfo=None) db.session.commit() - return app def update_app_api_status(self, app: App, enable_api: bool) -> App: @@ -297,12 +271,10 @@ class AppService: """ if enable_api == app.enable_api: return app - app.enable_api = enable_api app.updated_by = current_user.id app.updated_at = datetime.now(UTC).replace(tzinfo=None) db.session.commit() - return app def delete_app(self, app: App) -> None: @@ -312,11 +284,9 @@ class AppService: """ db.session.delete(app) 
db.session.commit() - # clean up web app settings if FeatureService.get_system_features().webapp_auth.enabled: EnterpriseService.WebAppAuth.cleanup_webapp(app.id) - # Trigger asynchronous deletion of app and related data remove_app_and_related_data_task.delay(tenant_id=app.tenant_id, app_id=app.id) @@ -327,14 +297,11 @@ class AppService: :return: """ app_mode = AppMode.value_of(app_model.mode) - meta: dict = {"tool_icons": {}} - if app_mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: workflow = app_model.workflow if workflow is None: return meta - graph = workflow.graph_dict nodes = graph.get("nodes", []) tools = [] @@ -351,17 +318,12 @@ class AppService: ) else: app_model_config: Optional[AppModelConfig] = app_model.app_model_config - if not app_model_config: return meta - agent_config = app_model_config.agent_mode_dict - # get all tools tools = agent_config.get("tools", []) - url_prefix = dify_config.CONSOLE_API_URL + "/console/api/workspaces/current/tool-provider/builtin/" - for tool in tools: keys = list(tool.keys()) if len(keys) >= 4: @@ -381,7 +343,6 @@ class AppService: meta["tool_icons"][tool_name] = json.loads(provider.icon) except: meta["tool_icons"][tool_name] = {"background": "#252525", "content": "\ud83d\ude01"} - return meta @staticmethod diff --git a/api/services/audio_service.py b/api/services/audio_service.py index e8923eb51b..90ec8848df 100644 --- a/api/services/audio_service.py +++ b/api/services/audio_service.py @@ -24,7 +24,6 @@ from services.workflow_service import WorkflowService FILE_SIZE = 30 FILE_SIZE_LIMIT = FILE_SIZE * 1024 * 1024 - logger = logging.getLogger(__name__) @@ -35,40 +34,31 @@ class AudioService: workflow = app_model.workflow if workflow is None: raise ValueError("Speech to text is not enabled") - features_dict = workflow.features_dict if "speech_to_text" not in features_dict or not features_dict["speech_to_text"].get("enabled"): raise ValueError("Speech to text is not enabled") else: app_model_config: AppModelConfig = 
app_model.app_model_config - if not app_model_config.speech_to_text_dict["enabled"]: raise ValueError("Speech to text is not enabled") - if file is None: raise NoAudioUploadedServiceError() - extension = file.mimetype if extension not in [f"audio/{ext}" for ext in AUDIO_EXTENSIONS]: raise UnsupportedAudioTypeServiceError() - file_content = file.read() file_size = len(file_content) - if file_size > FILE_SIZE_LIMIT: message = f"Audio size larger than {FILE_SIZE} mb" raise AudioTooLargeServiceError(message) - model_manager = ModelManager() model_instance = model_manager.get_default_model_instance( tenant_id=app_model.tenant_id, model_type=ModelType.SPEECH2TEXT ) if model_instance is None: raise ProviderNotSupportSpeechToTextServiceError() - buffer = io.BytesIO(file_content) buffer.name = "temp.mp3" - return {"text": model_instance.invoke_speech2text(file=buffer, user=end_user)} @classmethod @@ -97,19 +87,15 @@ class AudioService: or not workflow.features_dict["text_to_speech"].get("enabled") ): raise ValueError("TTS is not enabled") - voice = workflow.features_dict["text_to_speech"].get("voice") else: if not is_draft: if app_model.app_model_config is None: raise ValueError("AppModelConfig not found") text_to_speech_dict = app_model.app_model_config.text_to_speech_dict - if not text_to_speech_dict.get("enabled"): raise ValueError("TTS is not enabled") - voice = text_to_speech_dict.get("voice") - model_manager = ModelManager() model_instance = model_manager.get_default_model_instance( tenant_id=app_model.tenant_id, model_type=ModelType.TTS @@ -123,7 +109,6 @@ class AudioService: raise ValueError("Sorry, no voice available.") else: raise ValueError("Sorry, no voice available.") - return model_instance.invoke_tts( content_text=text_content.strip(), user=end_user, tenant_id=app_model.tenant_id, voice=voice ) @@ -140,7 +125,6 @@ class AudioService: return None if message.answer == "" and message.status == MessageStatus.NORMAL: return None - else: response = 
invoke_tts(text_content=message.answer, app_model=app_model, voice=voice, is_draft=is_draft) if isinstance(response, Generator): @@ -160,7 +144,6 @@ class AudioService: model_instance = model_manager.get_default_model_instance(tenant_id=tenant_id, model_type=ModelType.TTS) if model_instance is None: raise ProviderNotSupportTextToSpeechServiceError() - try: return model_instance.get_tts_voices(language) except Exception as e: diff --git a/api/services/auth/api_key_auth_service.py b/api/services/auth/api_key_auth_service.py index e5f4a3ef6e..5ca7189aa6 100644 --- a/api/services/auth/api_key_auth_service.py +++ b/api/services/auth/api_key_auth_service.py @@ -23,7 +23,6 @@ class ApiKeyAuthService: # Encrypt the api key api_key = encrypter.encrypt_token(tenant_id, args["credentials"]["config"]["api_key"]) args["credentials"]["config"]["api_key"] = api_key - data_source_api_key_binding = DataSourceApiKeyAuthBinding() data_source_api_key_binding.tenant_id = tenant_id data_source_api_key_binding.category = args["category"] diff --git a/api/services/auth/firecrawl/firecrawl.py b/api/services/auth/firecrawl/firecrawl.py index 6ef034f292..0c708b55b7 100644 --- a/api/services/auth/firecrawl/firecrawl.py +++ b/api/services/auth/firecrawl/firecrawl.py @@ -13,7 +13,6 @@ class FirecrawlAuth(ApiKeyAuthBase): raise ValueError("Invalid auth type, Firecrawl auth type must be Bearer") self.api_key = credentials.get("config", {}).get("api_key", None) self.base_url = credentials.get("config", {}).get("base_url", "https://api.firecrawl.dev") - if not self.api_key: raise ValueError("No API key provided") diff --git a/api/services/auth/jina.py b/api/services/auth/jina.py index 6100e9afc8..e253937558 100644 --- a/api/services/auth/jina.py +++ b/api/services/auth/jina.py @@ -12,7 +12,6 @@ class JinaAuth(ApiKeyAuthBase): if auth_type != "bearer": raise ValueError("Invalid auth type, Jina Reader auth type must be Bearer") self.api_key = credentials.get("config", {}).get("api_key", None) - if 
not self.api_key: raise ValueError("No API key provided") diff --git a/api/services/auth/jina/jina.py b/api/services/auth/jina/jina.py index 6100e9afc8..e253937558 100644 --- a/api/services/auth/jina/jina.py +++ b/api/services/auth/jina/jina.py @@ -12,7 +12,6 @@ class JinaAuth(ApiKeyAuthBase): if auth_type != "bearer": raise ValueError("Invalid auth type, Jina Reader auth type must be Bearer") self.api_key = credentials.get("config", {}).get("api_key", None) - if not self.api_key: raise ValueError("No API key provided") diff --git a/api/services/auth/watercrawl/watercrawl.py b/api/services/auth/watercrawl/watercrawl.py index 153ab5ba75..4d95dcbeac 100644 --- a/api/services/auth/watercrawl/watercrawl.py +++ b/api/services/auth/watercrawl/watercrawl.py @@ -14,7 +14,6 @@ class WatercrawlAuth(ApiKeyAuthBase): raise ValueError("Invalid auth type, WaterCrawl auth type must be x-api-key") self.api_key = credentials.get("config", {}).get("api_key", None) self.base_url = credentials.get("config", {}).get("base_url", "https://app.watercrawl.dev") - if not self.api_key: raise ValueError("No API key provided") diff --git a/api/services/billing_service.py b/api/services/billing_service.py index d44483ad89..8b5af56455 100644 --- a/api/services/billing_service.py +++ b/api/services/billing_service.py @@ -12,22 +12,18 @@ from models.account import Account, TenantAccountJoin, TenantAccountRole class BillingService: base_url = os.environ.get("BILLING_API_URL", "BILLING_API_URL") secret_key = os.environ.get("BILLING_API_SECRET_KEY", "BILLING_API_SECRET_KEY") - compliance_download_rate_limiter = RateLimiter("compliance_download_rate_limiter", 4, 60) @classmethod def get_info(cls, tenant_id: str): params = {"tenant_id": tenant_id} - billing_info = cls._send_request("GET", "/subscription/info", params=params) return billing_info @classmethod def get_knowledge_rate_limit(cls, tenant_id: str): params = {"tenant_id": tenant_id} - knowledge_rate_limit = cls._send_request("GET", 
"/subscription/knowledge-rate-limit", params=params) - return { "limit": knowledge_rate_limit.get("limit", 10), "subscription_plan": knowledge_rate_limit.get("subscription_plan", "sandbox"), @@ -62,7 +58,6 @@ class BillingService: ) def _send_request(cls, method: Literal["GET", "POST", "DELETE"], endpoint: str, json=None, params=None): headers = {"Content-Type": "application/json", "Billing-Api-Secret-Key": cls.secret_key} - url = f"{cls.base_url}{endpoint}" response = httpx.request(method, url, json=json, params=params, headers=headers) if method == "GET" and response.status_code != httpx.codes.OK: @@ -72,16 +67,13 @@ class BillingService: @staticmethod def is_tenant_owner_or_admin(current_user): tenant_id = current_user.current_tenant_id - join: Optional[TenantAccountJoin] = ( db.session.query(TenantAccountJoin) .filter(TenantAccountJoin.tenant_id == tenant_id, TenantAccountJoin.account_id == current_user.id) .first() ) - if not join: raise ValueError("Tenant account join not found") - if not TenantAccountRole.is_privileged_role(join.role): raise ValueError("Only team owner or team admin can perform this action") @@ -116,9 +108,7 @@ class BillingService: from controllers.console.error import EducationVerifyLimitError raise EducationVerifyLimitError() - cls.verification_rate_limit.increment_rate_limit(account_email) - params = {"account_id": account_id} return BillingService._send_request("GET", "/education/verify", params=params) @@ -133,7 +123,6 @@ class BillingService: from controllers.console.error import EducationActivateLimitError raise EducationActivateLimitError() - cls.activation_rate_limit.increment_rate_limit(account.email) params = {"account_id": account.id, "curr_tenant_id": account.current_tenant_id} json = { @@ -162,7 +151,6 @@ class BillingService: from controllers.console.error import CompilanceRateLimitError raise CompilanceRateLimitError() - json = { "doc_name": doc_name, "account_id": account_id, diff --git 
a/api/services/clear_free_plan_tenant_expired_logs.py b/api/services/clear_free_plan_tenant_expired_logs.py index 1fd560d581..1a8224df51 100644 --- a/api/services/clear_free_plan_tenant_expired_logs.py +++ b/api/services/clear_free_plan_tenant_expired_logs.py @@ -39,7 +39,6 @@ class ClearFreePlanTenantExpiredLogs: ) if len(messages) == 0: break - storage.save( f"free_plan_tenant_expired_logs/" f"{tenant_id}/messages/{datetime.datetime.now().strftime('%Y-%m-%d')}" @@ -50,22 +49,17 @@ class ClearFreePlanTenantExpiredLogs: ), ).encode("utf-8"), ) - message_ids = [message.id for message in messages] - # delete messages session.query(Message).filter( Message.id.in_(message_ids), ).delete(synchronize_session=False) - session.commit() - click.echo( click.style( f"[{datetime.datetime.now()}] Processed {len(message_ids)} messages for tenant {tenant_id} " ) ) - while True: with Session(db.engine).no_autoflush as session: conversations = ( @@ -77,10 +71,8 @@ class ClearFreePlanTenantExpiredLogs: .limit(batch) .all() ) - if len(conversations) == 0: break - storage.save( f"free_plan_tenant_expired_logs/" f"{tenant_id}/conversations/{datetime.datetime.now().strftime('%Y-%m-%d')}" @@ -91,20 +83,17 @@ class ClearFreePlanTenantExpiredLogs: ), ).encode("utf-8"), ) - conversation_ids = [conversation.id for conversation in conversations] session.query(Conversation).filter( Conversation.id.in_(conversation_ids), ).delete(synchronize_session=False) session.commit() - click.echo( click.style( f"[{datetime.datetime.now()}] Processed {len(conversation_ids)}" f" conversations for tenant {tenant_id}" ) ) - while True: with Session(db.engine).no_autoflush as session: workflow_node_executions = ( @@ -117,10 +106,8 @@ class ClearFreePlanTenantExpiredLogs: .limit(batch) .all() ) - if len(workflow_node_executions) == 0: break - # save workflow node executions storage.save( f"free_plan_tenant_expired_logs/" @@ -130,24 +117,20 @@ class ClearFreePlanTenantExpiredLogs: 
jsonable_encoder(workflow_node_executions), ).encode("utf-8"), ) - workflow_node_execution_ids = [ workflow_node_execution.id for workflow_node_execution in workflow_node_executions ] - # delete workflow node executions session.query(WorkflowNodeExecutionModel).filter( WorkflowNodeExecutionModel.id.in_(workflow_node_execution_ids), ).delete(synchronize_session=False) session.commit() - click.echo( click.style( f"[{datetime.datetime.now()}] Processed {len(workflow_node_execution_ids)}" f" workflow node executions for tenant {tenant_id}" ) ) - while True: with Session(db.engine).no_autoflush as session: workflow_runs = ( @@ -159,12 +142,9 @@ class ClearFreePlanTenantExpiredLogs: .limit(batch) .all() ) - if len(workflow_runs) == 0: break - # save workflow runs - storage.save( f"free_plan_tenant_expired_logs/" f"{tenant_id}/workflow_runs/{datetime.datetime.now().strftime('%Y-%m-%d')}" @@ -175,9 +155,7 @@ class ClearFreePlanTenantExpiredLogs: ), ).encode("utf-8"), ) - workflow_run_ids = [workflow_run.id for workflow_run in workflow_runs] - # delete workflow runs session.query(WorkflowRun).filter( WorkflowRun.id.in_(workflow_run_ids), @@ -189,19 +167,14 @@ class ClearFreePlanTenantExpiredLogs: """ Clear free plan tenant expired logs. 
""" - click.echo(click.style("Clearing free plan tenant expired logs", fg="white")) ended_at = datetime.datetime.now() started_at = datetime.datetime(2023, 4, 3, 8, 59, 24) current_time = started_at - with Session(db.engine) as session: total_tenant_count = session.query(Tenant.id).count() - click.echo(click.style(f"Total tenant count: {total_tenant_count}", fg="white")) - handled_tenant_count = 0 - thread_pool = ThreadPoolExecutor(max_workers=10) def process_tenant(flask_app: Flask, tenant_id: str) -> None: @@ -229,7 +202,6 @@ class ClearFreePlanTenantExpiredLogs: ) futures = [] - if tenant_ids: for tenant_id in tenant_ids: futures.append( @@ -257,7 +229,6 @@ class ClearFreePlanTenantExpiredLogs: datetime.timedelta(hours=3), datetime.timedelta(hours=1), ] - for test_interval in test_intervals: tenant_count = ( session.query(Tenant.id) @@ -270,7 +241,6 @@ class ClearFreePlanTenantExpiredLogs: else: # If all intervals have too many tenants, use minimum interval interval = datetime.timedelta(hours=1) - # Adjust interval to target ~100 tenants per batch if tenant_count > 0: # Scale interval based on ratio to target count @@ -281,15 +251,12 @@ class ClearFreePlanTenantExpiredLogs: interval * (100 / tenant_count), # Scale to target 100 ), ) - batch_end = min(current_time + interval, ended_at) - rs = ( session.query(Tenant.id) .filter(Tenant.created_at.between(current_time, batch_end)) .order_by(Tenant.created_at) ) - tenants = [] for row in rs: tenant_id = str(row.id) @@ -298,7 +265,6 @@ class ClearFreePlanTenantExpiredLogs: except Exception: logger.exception(f"Failed to process tenant {tenant_id}") continue - futures.append( thread_pool.submit( process_tenant, @@ -306,9 +272,7 @@ class ClearFreePlanTenantExpiredLogs: tenant_id, ) ) - current_time = batch_end - # wait for all threads to finish for future in futures: future.result() diff --git a/api/services/conversation_service.py b/api/services/conversation_service.py index afdaa49465..4189c9db8f 100644 --- 
a/api/services/conversation_service.py +++ b/api/services/conversation_service.py @@ -37,7 +37,6 @@ class ConversationService: ) -> InfiniteScrollPagination: if not user: return InfiniteScrollPagination(data=[], limit=limit, has_more=False) - stmt = select(Conversation).where( Conversation.is_deleted == False, Conversation.app_id == app_model.id, @@ -50,15 +49,12 @@ class ConversationService: stmt = stmt.where(Conversation.id.in_(include_ids)) if exclude_ids is not None: stmt = stmt.where(~Conversation.id.in_(exclude_ids)) - # define sort fields and directions sort_field, sort_direction = cls._get_sort_params(sort_by) - if last_id: last_conversation = session.scalar(stmt.where(Conversation.id == last_id)) if not last_conversation: raise LastConversationNotExistsError() - # build filters based on sorting filter_condition = cls._build_filter_condition( sort_field=sort_field, @@ -68,7 +64,6 @@ class ConversationService: stmt = stmt.where(filter_condition) query_stmt = stmt.order_by(sort_direction(getattr(Conversation, sort_field))).limit(limit) conversations = session.scalars(query_stmt).all() - has_more = False if len(conversations) == limit: current_page_last_conversation = conversations[-1] @@ -81,7 +76,6 @@ class ConversationService: rest_count = session.scalar(count_stmt) or 0 if rest_count > 0: has_more = True - return InfiniteScrollPagination(data=conversations, limit=limit, has_more=has_more) @classmethod @@ -108,14 +102,12 @@ class ConversationService: auto_generate: bool, ): conversation = cls.get_conversation(app_model, conversation_id, user) - if auto_generate: return cls.auto_generate_name(app_model, conversation) else: conversation.name = name conversation.updated_at = datetime.now(UTC).replace(tzinfo=None) db.session.commit() - return conversation @classmethod @@ -127,10 +119,8 @@ class ConversationService: .order_by(Message.created_at.asc()) .first() ) - if not message: raise MessageNotExistsError() - # generate conversation name try: name = 
LLMGenerator.generate_conversation_name( @@ -139,9 +129,7 @@ class ConversationService: conversation.name = name except: pass - db.session.commit() - return conversation @classmethod @@ -158,16 +146,13 @@ class ConversationService: ) .first() ) - if not conversation: raise ConversationNotExistsError() - return conversation @classmethod def delete(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]): conversation = cls.get_conversation(app_model, conversation_id, user) - conversation.is_deleted = True conversation.updated_at = datetime.now(UTC).replace(tzinfo=None) db.session.commit() @@ -182,32 +167,26 @@ class ConversationService: last_id: Optional[str], ) -> InfiniteScrollPagination: conversation = cls.get_conversation(app_model, conversation_id, user) - stmt = ( select(ConversationVariable) .where(ConversationVariable.app_id == app_model.id) .where(ConversationVariable.conversation_id == conversation.id) .order_by(ConversationVariable.created_at) ) - with Session(db.engine) as session: if last_id: last_variable = session.scalar(stmt.where(ConversationVariable.id == last_id)) if not last_variable: raise ConversationVariableNotExistsError() - # Filter for variables created after the last_id stmt = stmt.where(ConversationVariable.created_at > last_variable.created_at) - # Apply limit to query query_stmt = stmt.limit(limit) # Get one extra to check if there are more rows = session.scalars(query_stmt).all() - has_more = False if len(rows) > limit: has_more = True rows = rows[:limit] # Remove the extra item - variables = [ { "created_at": row.created_at, @@ -216,5 +195,4 @@ class ConversationService: } for row in rows ] - return InfiniteScrollPagination(variables, limit, has_more) diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index e42b5ace75..d199226929 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -80,14 +80,12 @@ class DatasetService: @staticmethod def 
get_datasets(page, per_page, tenant_id=None, user=None, search=None, tag_ids=None, include_all=False): query = select(Dataset).filter(Dataset.tenant_id == tenant_id).order_by(Dataset.created_at.desc()) - if user: # get permitted dataset ids dataset_permission = ( db.session.query(DatasetPermission).filter_by(account_id=user.id, tenant_id=tenant_id).all() ) permitted_dataset_ids = {dp.dataset_id for dp in dataset_permission} if dataset_permission else None - if user.current_role == TenantAccountRole.DATASET_OPERATOR: # only show datasets that the user has permission to access if permitted_dataset_ids: @@ -122,19 +120,15 @@ class DatasetService: else: # if no user, only show datasets that are shared with all team members query = query.filter(Dataset.permission == DatasetPermissionEnum.ALL_TEAM) - if search: query = query.filter(Dataset.name.ilike(f"%{search}%")) - if tag_ids: target_ids = TagService.get_target_ids_by_tag_ids("knowledge", tenant_id, tag_ids) if target_ids: query = query.filter(Dataset.id.in_(target_ids)) else: return [], 0 - datasets = db.paginate(select=query, page=page, per_page=per_page, max_per_page=100, error_out=False) - return datasets.items, datasets.total @staticmethod @@ -158,9 +152,7 @@ class DatasetService: @staticmethod def get_datasets_by_ids(ids, tenant_id): stmt = select(Dataset).filter(Dataset.id.in_(ids), Dataset.tenant_id == tenant_id) - datasets = db.paginate(select=stmt, page=1, per_page=len(ids), max_per_page=len(ids), error_out=False) - return datasets.items, datasets.total @staticmethod @@ -221,7 +213,6 @@ class DatasetService: dataset.provider = provider db.session.add(dataset) db.session.flush() - if provider == "external" and external_knowledge_api_id: external_knowledge_api = ExternalDatasetService.get_external_knowledge_api(external_knowledge_api_id) if not external_knowledge_api: @@ -234,7 +225,6 @@ class DatasetService: created_by=account.id, ) db.session.add(external_knowledge_binding) - db.session.commit() return 
dataset @@ -299,15 +289,12 @@ class DatasetService: def update_dataset(dataset_id, data, user): """ Update dataset configuration and settings. - Args: dataset_id: The unique identifier of the dataset to update data: Dictionary containing the update data user: The user performing the update operation - Returns: Dataset: The updated dataset object - Raises: ValueError: If dataset not found or validation fails NoPermissionError: If user lacks permission to update the dataset @@ -316,10 +303,8 @@ class DatasetService: dataset = DatasetService.get_dataset(dataset_id) if not dataset: raise ValueError("Dataset not found") - # Verify user has permission to update this dataset DatasetService.check_dataset_permission(dataset, user) - # Handle external dataset updates if dataset.provider == "external": return DatasetService._update_external_dataset(dataset, data, user) @@ -330,12 +315,10 @@ class DatasetService: def _update_external_dataset(dataset, data, user): """ Update external dataset configuration. 
- Args: dataset: The dataset object to update data: Update data dictionary user: User performing the update - Returns: Dataset: Updated dataset object """ @@ -343,20 +326,16 @@ class DatasetService: external_retrieval_model = data.get("external_retrieval_model", None) if external_retrieval_model: dataset.retrieval_model = external_retrieval_model - # Update basic dataset properties dataset.name = data.get("name", dataset.name) dataset.description = data.get("description", dataset.description) - # Update permission if provided permission = data.get("permission") if permission: dataset.permission = permission - # Validate and update external knowledge configuration external_knowledge_id = data.get("external_knowledge_id", None) external_knowledge_api_id = data.get("external_knowledge_api_id", None) - if not external_knowledge_id: raise ValueError("External knowledge id is required.") if not external_knowledge_api_id: @@ -365,20 +344,16 @@ class DatasetService: dataset.updated_by = user.id if user else None dataset.updated_at = datetime.datetime.utcnow() db.session.add(dataset) - # Update external knowledge binding DatasetService._update_external_knowledge_binding(dataset.id, external_knowledge_id, external_knowledge_api_id) - # Commit changes to database db.session.commit() - return dataset @staticmethod def _update_external_knowledge_binding(dataset_id, external_knowledge_id, external_knowledge_api_id): """ Update external knowledge binding configuration. 
- Args: dataset_id: Dataset identifier external_knowledge_id: External knowledge identifier @@ -388,10 +363,8 @@ class DatasetService: external_knowledge_binding = ( session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id).first() ) - if not external_knowledge_binding: raise ValueError("External knowledge binding not found.") - # Update binding if values have changed if ( external_knowledge_binding.external_knowledge_id != external_knowledge_id @@ -405,12 +378,10 @@ class DatasetService: def _update_internal_dataset(dataset, data, user): """ Update internal dataset configuration. - Args: dataset: The dataset object to update data: Update data dictionary user: User performing the update - Returns: Dataset: Updated dataset object """ @@ -419,39 +390,31 @@ class DatasetService: data.pop("external_knowledge_api_id", None) data.pop("external_knowledge_id", None) data.pop("external_retrieval_model", None) - # Filter out None values except for description field filtered_data = {k: v for k, v in data.items() if v is not None or k == "description"} - # Handle indexing technique changes and embedding model updates action = DatasetService._handle_indexing_technique_change(dataset, data, filtered_data) - # Add metadata fields filtered_data["updated_by"] = user.id filtered_data["updated_at"] = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) # update Retrieval model filtered_data["retrieval_model"] = data["retrieval_model"] - # Update dataset in database db.session.query(Dataset).filter_by(id=dataset.id).update(filtered_data) db.session.commit() - # Trigger vector index task if indexing technique changed if action: deal_dataset_vector_index_task.delay(dataset.id, action) - return dataset @staticmethod def _handle_indexing_technique_change(dataset, data, filtered_data): """ Handle changes in indexing technique and configure embedding models accordingly. 
- Args: dataset: Current dataset object data: Update data dictionary filtered_data: Filtered update data - Returns: str: Action to perform ('add', 'remove', 'update', or None) """ @@ -475,7 +438,6 @@ class DatasetService: def _configure_embedding_model_for_high_quality(data, filtered_data): """ Configure embedding model settings for high quality indexing. - Args: data: Update data dictionary filtered_data: Filtered update data to modify @@ -505,12 +467,10 @@ class DatasetService: def _handle_embedding_model_update_when_technique_unchanged(dataset, data, filtered_data): """ Handle embedding model updates when indexing technique remains the same. - Args: dataset: Current dataset object data: Update data dictionary filtered_data: Filtered update data to modify - Returns: str: Action to perform ('update' or None) """ @@ -530,7 +490,6 @@ class DatasetService: def _preserve_existing_embedding_settings(dataset, filtered_data): """ Preserve existing embedding model settings when not provided in update. - Args: dataset: Current dataset object filtered_data: Filtered update data to modify @@ -553,12 +512,10 @@ class DatasetService: def _update_embedding_model_settings(dataset, data, filtered_data): """ Update embedding model settings with new values. - Args: dataset: Current dataset object data: Update data dictionary filtered_data: Filtered update data to modify - Returns: str: Action to perform ('update' or None) """ @@ -570,7 +527,6 @@ class DatasetService: new_provider_str = ( str(ModelProviderID(data["embedding_model_provider"])) if data["embedding_model_provider"] else None ) - # Only update if values are different if current_provider_str != new_provider_str or data["embedding_model"] != dataset.embedding_model: DatasetService._apply_new_embedding_settings(dataset, data, filtered_data) @@ -587,7 +543,6 @@ class DatasetService: def _apply_new_embedding_settings(dataset, data, filtered_data): """ Apply new embedding model settings to the dataset. 
- Args: dataset: Current dataset object data: Update data dictionary @@ -614,7 +569,6 @@ class DatasetService: filtered_data["collection_binding_id"] = dataset.collection_binding_id # Skip the rest of the embedding model update return - # Apply new embedding model settings filtered_data["embedding_model"] = embedding_model.model filtered_data["embedding_model_provider"] = embedding_model.provider @@ -626,14 +580,10 @@ class DatasetService: @staticmethod def delete_dataset(dataset_id, user): dataset = DatasetService.get_dataset(dataset_id) - if dataset is None: return False - DatasetService.check_dataset_permission(dataset, user) - dataset_was_deleted.send(dataset) - db.session.delete(dataset) db.session.commit() return True @@ -668,15 +618,12 @@ class DatasetService: def check_dataset_operator_permission(user: Optional[Account] = None, dataset: Optional[Dataset] = None): if not dataset: raise ValueError("Dataset not found") - if not user: raise ValueError("User not found") - if user.current_role != TenantAccountRole.OWNER: if dataset.permission == DatasetPermissionEnum.ONLY_ME: if dataset.created_by != user.id: raise NoPermissionError("You do not have permission to access this dataset.") - elif dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM: if not any( dp.dataset_id == dataset.id @@ -687,9 +634,7 @@ class DatasetService: @staticmethod def get_dataset_queries(dataset_id: str, page: int, per_page: int): stmt = select(DatasetQuery).filter_by(dataset_id=dataset_id).order_by(db.desc(DatasetQuery.created_at)) - dataset_queries = db.paginate(select=stmt, page=page, per_page=per_page, max_per_page=100, error_out=False) - return dataset_queries.items, dataset_queries.total @staticmethod @@ -744,7 +689,6 @@ class DocumentService: "indexing_max_segmentation_tokens_length": dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH, }, } - DOCUMENT_METADATA_SCHEMA: dict[str, Any] = { "book": { "title": str, @@ -851,7 +795,6 @@ class DocumentService: @staticmethod def 
get_document_by_id(document_id: str) -> Optional[Document]: document = db.session.query(Document).filter(Document.id == document_id).first() - return document @staticmethod @@ -878,7 +821,6 @@ class DocumentService: ) .all() ) - return documents @staticmethod @@ -893,7 +835,6 @@ class DocumentService: ) .all() ) - return documents @staticmethod @@ -916,7 +857,6 @@ class DocumentService: ) .all() ) - return documents @staticmethod @@ -943,7 +883,6 @@ class DocumentService: document_was_deleted.send( document.id, dataset_id=document.dataset_id, doc_form=document.doc_form, file_id=file_id ) - db.session.delete(document) db.session.commit() @@ -956,7 +895,6 @@ class DocumentService: if document.data_source_type == "upload_file" ] batch_clean_document_task.delay(document_ids, dataset.id, dataset.doc_form, file_ids) - for document in documents: db.session.delete(document) db.session.commit() @@ -966,25 +904,19 @@ class DocumentService: dataset = DatasetService.get_dataset(dataset_id) if not dataset: raise ValueError("Dataset not found.") - document = DocumentService.get_document(dataset_id, document_id) - if not document: raise ValueError("Document not found.") - if document.tenant_id != current_user.current_tenant_id: raise ValueError("No permission.") - if dataset.built_in_field_enabled: if document.doc_metadata: doc_metadata = copy.deepcopy(document.doc_metadata) doc_metadata[BuiltInField.document_name.value] = name document.doc_metadata = doc_metadata - document.name = name db.session.add(document) db.session.commit() - return document @staticmethod @@ -995,7 +927,6 @@ class DocumentService: document.is_paused = True document.paused_by = current_user.id document.paused_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) - db.session.add(document) db.session.commit() # set document paused flag @@ -1010,7 +941,6 @@ class DocumentService: document.is_paused = False document.paused_by = None document.paused_at = None - db.session.add(document) 
db.session.commit() # delete paused flag @@ -1031,7 +961,6 @@ class DocumentService: document.indexing_status = "waiting" db.session.add(document) db.session.commit() - redis_client.setex(retry_indexing_cache_key, 600, 1) # trigger async task document_ids = [document.id for document in documents] @@ -1051,9 +980,7 @@ class DocumentService: document.data_source_info = json.dumps(data_source_info, ensure_ascii=False) db.session.add(document) db.session.commit() - redis_client.setex(sync_indexing_cache_key, 600, 1) - sync_website_document_indexing_task.delay(dataset_id, document.id) @staticmethod @@ -1076,7 +1003,6 @@ class DocumentService: ): # check document limit features = FeatureService.get_features(current_user.current_tenant_id) - if features.billing.enabled: if not knowledge_config.original_document_id: count = 0 @@ -1092,22 +1018,17 @@ class DocumentService: website_info = knowledge_config.data_source.info_list.website_info_list count = len(website_info.urls) # type: ignore batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT) - if features.billing.subscription.plan == "sandbox" and count > 1: raise ValueError("Your current plan does not support batch upload, please upgrade your plan.") if count > batch_upload_limit: raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.") - DocumentService.check_documents_upload_quota(count, features) - # if dataset is empty, update dataset data_source_type if not dataset.data_source_type: dataset.data_source_type = knowledge_config.data_source.info_list.data_source_type # type: ignore - if not dataset.indexing_technique: if knowledge_config.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST: raise ValueError("Indexing technique is invalid") - dataset.indexing_technique = knowledge_config.indexing_technique if knowledge_config.indexing_technique == "high_quality": model_manager = ModelManager() @@ -1134,13 +1055,11 @@ class DocumentService: "top_k": 2, "score_threshold_enabled": False, 
} - dataset.retrieval_model = ( knowledge_config.retrieval_model.model_dump() if knowledge_config.retrieval_model else default_retrieval_model ) # type: ignore - documents = [] if knowledge_config.original_document_id: document = DocumentService.update_document_with_dataset_id(dataset, knowledge_config, account) @@ -1191,11 +1110,9 @@ class DocumentService: .filter(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id) .first() ) - # raise error if file not found if not file: raise FileNotExistsError() - file_name = file.name data_source_info = { "upload_file_id": file_id, @@ -1350,13 +1267,11 @@ class DocumentService: documents.append(document) position += 1 db.session.commit() - # trigger async task if document_ids: document_indexing_task.delay(dataset.id, document_ids) if duplicate_document_ids: duplicate_document_indexing_task.delay(dataset.id, duplicate_document_ids) - return documents, batch @staticmethod @@ -1471,11 +1386,9 @@ class DocumentService: .filter(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id) .first() ) - # raise error if file not found if not file: raise FileNotExistsError() - file_name = file.name data_source_info = { "upload_file_id": file_id, @@ -1522,7 +1435,6 @@ class DocumentService: document.data_source_type = document_data.data_source.info_list.data_source_type document.data_source_info = json.dumps(data_source_info) document.name = file_name - # update document name if document_data.name: document.name = document_data.name @@ -1549,7 +1461,6 @@ class DocumentService: @staticmethod def save_document_without_dataset_id(tenant_id: str, knowledge_config: KnowledgeConfig, account: Account): features = FeatureService.get_features(current_user.current_tenant_id) - if features.billing.enabled: count = 0 if knowledge_config.data_source.info_list.data_source_type == "upload_file": # type: ignore @@ -1573,9 +1484,7 @@ class DocumentService: batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT) if count > 
batch_upload_limit: raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.") - DocumentService.check_documents_upload_quota(count, features) - dataset_collection_binding_id = None retrieval_model = None if knowledge_config.indexing_technique == "high_quality": @@ -1606,18 +1515,14 @@ class DocumentService: collection_binding_id=dataset_collection_binding_id, retrieval_model=retrieval_model.model_dump() if retrieval_model else None, ) - db.session.add(dataset) # type: ignore db.session.flush() - documents, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) - cut_length = 18 cut_name = documents[0].name[:cut_length] dataset.name = cut_name + "..." dataset.description = "useful for when you want to answer queries about the " + documents[0].name db.session.commit() - return dataset, documents, batch @classmethod @@ -1634,13 +1539,10 @@ class DocumentService: def data_source_args_validate(cls, knowledge_config: KnowledgeConfig): if not knowledge_config.data_source: raise ValueError("Data source is required") - if knowledge_config.data_source.info_list.data_source_type not in Document.DATA_SOURCES: raise ValueError("Data source type is invalid") - if not knowledge_config.data_source.info_list: raise ValueError("Data source info is required") - if knowledge_config.data_source.info_list.data_source_type == "upload_file": if not knowledge_config.data_source.info_list.file_info_list: raise ValueError("File source info is required") @@ -1655,50 +1557,37 @@ class DocumentService: def process_rule_args_validate(cls, knowledge_config: KnowledgeConfig): if not knowledge_config.process_rule: raise ValueError("Process rule is required") - if not knowledge_config.process_rule.mode: raise ValueError("Process rule mode is required") - if knowledge_config.process_rule.mode not in DatasetProcessRule.MODES: raise ValueError("Process rule mode is invalid") - if knowledge_config.process_rule.mode == "automatic": 
knowledge_config.process_rule.rules = None else: if not knowledge_config.process_rule.rules: raise ValueError("Process rule rules is required") - if knowledge_config.process_rule.rules.pre_processing_rules is None: raise ValueError("Process rule pre_processing_rules is required") - unique_pre_processing_rule_dicts = {} for pre_processing_rule in knowledge_config.process_rule.rules.pre_processing_rules: if not pre_processing_rule.id: raise ValueError("Process rule pre_processing_rules id is required") - if not isinstance(pre_processing_rule.enabled, bool): raise ValueError("Process rule pre_processing_rules enabled is invalid") - unique_pre_processing_rule_dicts[pre_processing_rule.id] = pre_processing_rule - knowledge_config.process_rule.rules.pre_processing_rules = list(unique_pre_processing_rule_dicts.values()) - if not knowledge_config.process_rule.rules.segmentation: raise ValueError("Process rule segmentation is required") - if not knowledge_config.process_rule.rules.segmentation.separator: raise ValueError("Process rule segmentation separator is required") - if not isinstance(knowledge_config.process_rule.rules.segmentation.separator, str): raise ValueError("Process rule segmentation separator is invalid") - if not ( knowledge_config.process_rule.mode == "hierarchical" and knowledge_config.process_rule.rules.parent_mode == "full-doc" ): if not knowledge_config.process_rule.rules.segmentation.max_tokens: raise ValueError("Process rule segmentation max_tokens is required") - if not isinstance(knowledge_config.process_rule.rules.segmentation.max_tokens, int): raise ValueError("Process rule segmentation max_tokens is invalid") @@ -1706,82 +1595,61 @@ class DocumentService: def estimate_args_validate(cls, args: dict): if "info_list" not in args or not args["info_list"]: raise ValueError("Data source info is required") - if not isinstance(args["info_list"], dict): raise ValueError("Data info is invalid") - if "process_rule" not in args or not args["process_rule"]: 
raise ValueError("Process rule is required") - if not isinstance(args["process_rule"], dict): raise ValueError("Process rule is invalid") - if "mode" not in args["process_rule"] or not args["process_rule"]["mode"]: raise ValueError("Process rule mode is required") - if args["process_rule"]["mode"] not in DatasetProcessRule.MODES: raise ValueError("Process rule mode is invalid") - if args["process_rule"]["mode"] == "automatic": args["process_rule"]["rules"] = {} else: if "rules" not in args["process_rule"] or not args["process_rule"]["rules"]: raise ValueError("Process rule rules is required") - if not isinstance(args["process_rule"]["rules"], dict): raise ValueError("Process rule rules is invalid") - if ( "pre_processing_rules" not in args["process_rule"]["rules"] or args["process_rule"]["rules"]["pre_processing_rules"] is None ): raise ValueError("Process rule pre_processing_rules is required") - if not isinstance(args["process_rule"]["rules"]["pre_processing_rules"], list): raise ValueError("Process rule pre_processing_rules is invalid") - unique_pre_processing_rule_dicts = {} for pre_processing_rule in args["process_rule"]["rules"]["pre_processing_rules"]: if "id" not in pre_processing_rule or not pre_processing_rule["id"]: raise ValueError("Process rule pre_processing_rules id is required") - if pre_processing_rule["id"] not in DatasetProcessRule.PRE_PROCESSING_RULES: raise ValueError("Process rule pre_processing_rules id is invalid") - if "enabled" not in pre_processing_rule or pre_processing_rule["enabled"] is None: raise ValueError("Process rule pre_processing_rules enabled is required") - if not isinstance(pre_processing_rule["enabled"], bool): raise ValueError("Process rule pre_processing_rules enabled is invalid") - unique_pre_processing_rule_dicts[pre_processing_rule["id"]] = pre_processing_rule - args["process_rule"]["rules"]["pre_processing_rules"] = list(unique_pre_processing_rule_dicts.values()) - if ( "segmentation" not in 
args["process_rule"]["rules"] or args["process_rule"]["rules"]["segmentation"] is None ): raise ValueError("Process rule segmentation is required") - if not isinstance(args["process_rule"]["rules"]["segmentation"], dict): raise ValueError("Process rule segmentation is invalid") - if ( "separator" not in args["process_rule"]["rules"]["segmentation"] or not args["process_rule"]["rules"]["segmentation"]["separator"] ): raise ValueError("Process rule segmentation separator is required") - if not isinstance(args["process_rule"]["rules"]["segmentation"]["separator"], str): raise ValueError("Process rule segmentation separator is invalid") - if ( "max_tokens" not in args["process_rule"]["rules"]["segmentation"] or not args["process_rule"]["rules"]["segmentation"]["max_tokens"] ): raise ValueError("Process rule segmentation max_tokens is required") - if not isinstance(args["process_rule"]["rules"]["segmentation"]["max_tokens"], int): raise ValueError("Process rule segmentation max_tokens is invalid") @@ -1789,57 +1657,46 @@ class DocumentService: def batch_update_document_status(dataset: Dataset, document_ids: list[str], action: str, user): """ Batch update document status. - Args: dataset (Dataset): The dataset object document_ids (list[str]): List of document IDs to update action (str): Action to perform (enable, disable, archive, un_archive) user: Current user performing the action - Raises: DocumentIndexingError: If document is being indexed or not in correct state ValueError: If action is invalid """ if not document_ids: return - # Early validation of action parameter valid_actions = ["enable", "disable", "archive", "un_archive"] if action not in valid_actions: raise ValueError(f"Invalid action: {action}. 
Must be one of {valid_actions}") - documents_to_update = [] - # First pass: validate all documents and prepare updates for document_id in document_ids: document = DocumentService.get_document(dataset.id, document_id) if not document: continue - # Check if document is being indexed indexing_cache_key = f"document_{document.id}_indexing" cache_result = redis_client.get(indexing_cache_key) if cache_result is not None: raise DocumentIndexingError(f"Document:{document.name} is being indexed, please try again later") - # Prepare update based on action update_info = DocumentService._prepare_document_status_update(document, action, user) if update_info: documents_to_update.append(update_info) - # Second pass: apply all updates in a single transaction if documents_to_update: try: for update_info in documents_to_update: document = update_info["document"] updates = update_info["updates"] - # Apply updates to the document for field, value in updates.items(): setattr(document, field, value) - db.session.add(document) - # Batch commit all changes db.session.commit() except Exception as e: @@ -1879,17 +1736,14 @@ class DocumentService: def _prepare_document_status_update(document, action: str, user): """ Prepare document status update information. 
- Args: document: Document object to update action: Action to perform user: Current user - Returns: dict: Update information or None if no update needed """ now = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) - if action == "enable": return DocumentService._prepare_enable_update(document, now) elif action == "disable": @@ -1898,7 +1752,6 @@ class DocumentService: return DocumentService._prepare_archive_update(document, user, now) elif action == "un_archive": return DocumentService._prepare_unarchive_update(document, now) - return None @staticmethod @@ -1906,7 +1759,6 @@ class DocumentService: """Prepare updates for enabling a document.""" if document.enabled: return None - return { "document": document, "updates": {"enabled": True, "disabled_at": None, "disabled_by": None, "updated_at": now}, @@ -1919,10 +1771,8 @@ class DocumentService: """Prepare updates for disabling a document.""" if not document.completed_at or document.indexing_status != "completed": raise DocumentIndexingError(f"Document: {document.name} is not completed.") - if not document.enabled: return None - return { "document": document, "updates": {"enabled": False, "disabled_at": now, "disabled_by": user.id, "updated_at": now}, @@ -1935,19 +1785,16 @@ class DocumentService: """Prepare updates for archiving a document.""" if document.archived: return None - update_info = { "document": document, "updates": {"archived": True, "archived_at": now, "archived_by": user.id, "updated_at": now}, "async_task": None, "set_cache": False, } - # Only set async task and cache if document is currently enabled if document.enabled: update_info["async_task"] = {"function": remove_document_from_index_task, "args": [document.id]} update_info["set_cache"] = True - return update_info @staticmethod @@ -1955,19 +1802,16 @@ class DocumentService: """Prepare updates for unarchiving a document.""" if not document.archived: return None - update_info = { "document": document, "updates": {"archived": False, 
"archived_at": None, "archived_by": None, "updated_at": now}, "async_task": None, "set_cache": False, } - # Only re-index if the document is currently enabled if document.enabled: update_info["async_task"] = {"function": add_document_to_index_task, "args": [document.id]} update_info["set_cache"] = True - return update_info @@ -2023,13 +1867,11 @@ class SegmentService: if document.doc_form == "qa_model": segment_document.word_count += len(args["answer"]) segment_document.answer = args["answer"] - db.session.add(segment_document) # update document word count document.word_count += segment_document.word_count db.session.add(document) db.session.commit() - # save vector index try: VectorService.create_segments_vector([args["keywords"]], [segment_document], dataset, document.doc_form) @@ -2079,7 +1921,6 @@ class SegmentService: )[0] else: tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])[0] - segment_document = DocumentSegment( tenant_id=current_user.current_tenant_id, dataset_id=document.dataset_id, @@ -2103,7 +1944,6 @@ class SegmentService: db.session.add(segment_document) segment_data_list.append(segment_document) position += 1 - pre_segment_data_list.append(segment_document) if "keywords" in segment_item: keywords_list.append(segment_item["keywords"]) @@ -2180,7 +2020,6 @@ class SegmentService: if dataset.indexing_technique == "high_quality": # check embedding model setting model_manager = ModelManager() - if dataset.embedding_model_provider: embedding_model_instance = model_manager.get_model_instance( tenant_id=dataset.tenant_id, @@ -2221,7 +2060,6 @@ class SegmentService: model_type=ModelType.TEXT_EMBEDDING, model=dataset.embedding_model, ) - # calc embedding use tokens if document.doc_form == "qa_model": segment.answer = args.answer @@ -2255,7 +2093,6 @@ class SegmentService: if dataset.indexing_technique == "high_quality": # check embedding model setting model_manager = ModelManager() - if dataset.embedding_model_provider: 
embedding_model_instance = model_manager.get_model_instance( tenant_id=dataset.tenant_id, @@ -2284,7 +2121,6 @@ class SegmentService: elif document.doc_form in (IndexType.PARAGRAPH_INDEX, IndexType.QA_INDEX): # update segment vector index VectorService.update_segment_vector(args.keywords, segment, dataset) - except Exception as e: logging.exception("update segment index failed") segment.enabled = False @@ -2301,7 +2137,6 @@ class SegmentService: cache_result = redis_client.get(indexing_cache_key) if cache_result is not None: raise ValueError("Segment is deleting.") - # enabled segment need to delete index if segment.enabled: # send delete segment index task @@ -2327,7 +2162,6 @@ class SegmentService: .all() ) index_node_ids = [index_node_id[0] for index_node_id in index_node_ids] - delete_segment_from_index_task.delay(index_node_ids, dataset.id, document.id) db.session.query(DocumentSegment).filter(DocumentSegment.id.in_(segment_ids)).delete() db.session.commit() @@ -2359,7 +2193,6 @@ class SegmentService: db.session.add(segment) real_deal_segmment_ids.append(segment.id) db.session.commit() - enable_segments_to_index_task.delay(real_deal_segmment_ids, dataset.id, document.id) elif action == "disable": segments = ( @@ -2386,7 +2219,6 @@ class SegmentService: db.session.add(segment) real_deal_segmment_ids.append(segment.id) db.session.commit() - disable_segments_from_index_task.delay(real_deal_segmment_ids, dataset.id, document.id) else: raise InvalidActionError() @@ -2441,7 +2273,6 @@ class SegmentService: db.session.rollback() raise ChildChunkIndexingError(str(e)) db.session.commit() - return child_chunk @classmethod @@ -2462,9 +2293,7 @@ class SegmentService: .all() ) child_chunks_map = {chunk.id: chunk for chunk in child_chunks} - new_child_chunks, update_child_chunks, delete_child_chunks, new_child_chunks_args = [], [], [], [] - for child_chunk_update_args in child_chunks_update_args: if child_chunk_update_args.id: child_chunk = 
child_chunks_map.pop(child_chunk_update_args.id, None) @@ -2483,7 +2312,6 @@ class SegmentService: try: if update_child_chunks: db.session.bulk_save_objects(update_child_chunks) - if delete_child_chunks: for child_chunk in delete_child_chunks: db.session.delete(child_chunk) @@ -2505,7 +2333,6 @@ class SegmentService: type="customized", created_by=current_user.id, ) - db.session.add(child_chunk) db.session.flush() new_child_chunks.append(child_chunk) @@ -2594,16 +2421,12 @@ class SegmentService: query = select(DocumentSegment).filter( DocumentSegment.document_id == document_id, DocumentSegment.tenant_id == tenant_id ) - if status_list: query = query.filter(DocumentSegment.status.in_(status_list)) - if keyword: query = query.filter(DocumentSegment.content.ilike(f"%{keyword}%")) - query = query.order_by(DocumentSegment.position.asc()) paginated_segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False) - return paginated_segments.items, paginated_segments.total @classmethod @@ -2615,15 +2438,12 @@ class SegmentService: dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: raise NotFound("Dataset not found.") - # check user's model setting DatasetService.check_dataset_model_setting(dataset) - # check document document = DocumentService.get_document(dataset_id, document_id) if not document: raise NotFound("Document not found.") - # check embedding model setting if high quality if dataset.indexing_technique == "high_quality": try: @@ -2640,7 +2460,6 @@ class SegmentService: ) except ProviderTokenNotInitError as ex: raise ValueError(ex.description) - # check segment segment = ( db.session.query(DocumentSegment) @@ -2649,11 +2468,9 @@ class SegmentService: ) if not segment: raise NotFound("Segment not found.") - # validate and update segment cls.segment_create_args_validate(segment_data, document) updated_segment = 
cls.update_segment(SegmentUpdateArgs(**segment_data), segment, document, dataset) - return updated_segment, document @classmethod @@ -2682,7 +2499,6 @@ class DatasetCollectionBindingService: .order_by(DatasetCollectionBinding.created_at) .first() ) - if not dataset_collection_binding: dataset_collection_binding = DatasetCollectionBinding( provider_name=provider_name, @@ -2708,7 +2524,6 @@ class DatasetCollectionBindingService: ) if not dataset_collection_binding: raise ValueError("Dataset collection binding not found") - return dataset_collection_binding @@ -2722,11 +2537,9 @@ class DatasetPermissionService: .filter(DatasetPermission.dataset_id == dataset_id) .all() ) - user_list = [] for user in user_list_query: user_list.append(user.account_id) - return user_list @classmethod @@ -2741,7 +2554,6 @@ class DatasetPermissionService: account_id=user["user_id"], ) permissions.append(permission) - db.session.add_all(permissions) db.session.commit() except Exception as e: @@ -2752,14 +2564,11 @@ class DatasetPermissionService: def check_permission(cls, user, dataset, requested_permission, requested_partial_member_list): if not user.is_dataset_editor: raise NoPermissionError("User does not have permission to edit this dataset.") - if user.is_dataset_operator and dataset.permission != requested_permission: raise NoPermissionError("Dataset operators cannot change the dataset permissions.") - if user.is_dataset_operator and requested_permission == "partial_members": if not requested_partial_member_list: raise ValueError("Partial member list is required when setting to partial members.") - local_member_list = cls.get_dataset_partial_member_list(dataset.id) request_member_list = [user["user_id"] for user in requested_partial_member_list] if set(local_member_list) != set(request_member_list): diff --git a/api/services/enterprise/base.py b/api/services/enterprise/base.py index 3c3f970444..2615e335fc 100644 --- a/api/services/enterprise/base.py +++ 
b/api/services/enterprise/base.py @@ -6,7 +6,6 @@ import requests class EnterpriseRequest: base_url = os.environ.get("ENTERPRISE_API_URL", "ENTERPRISE_API_URL") secret_key = os.environ.get("ENTERPRISE_API_SECRET_KEY", "ENTERPRISE_API_SECRET_KEY") - proxies = { "http": "", "https": "", diff --git a/api/services/enterprise/enterprise_service.py b/api/services/enterprise/enterprise_service.py index 54d45f45ea..9ae8f540ea 100644 --- a/api/services/enterprise/enterprise_service.py +++ b/api/services/enterprise/enterprise_service.py @@ -49,7 +49,6 @@ class EnterpriseService: def is_user_allowed_to_access_webapp(cls, user_id: str, app_code: str): params = {"userId": user_id, "appCode": app_code} data = EnterpriseRequest.send_request("GET", "/webapp/permission", params=params) - return data.get("result", False) @classmethod @@ -70,16 +69,13 @@ class EnterpriseService: data: dict[str, str] = EnterpriseRequest.send_request("POST", "/webapp/access-mode/batch/id", json=body) if not data: raise ValueError("No data found.") - if not isinstance(data["accessModes"], dict): raise ValueError("Invalid data format.") - ret = {} for key, value in data["accessModes"].items(): curr = WebAppSettings() curr.access_mode = value ret[key] = curr - return ret @classmethod @@ -98,17 +94,13 @@ class EnterpriseService: raise ValueError("app_id must be provided.") if access_mode not in ["public", "private", "private_all"]: raise ValueError("access_mode must be either 'public', 'private', or 'private_all'") - data = {"appId": app_id, "accessMode": access_mode} - response = EnterpriseRequest.send_request("POST", "/webapp/access-mode", json=data) - return response.get("result", False) @classmethod def cleanup_webapp(cls, app_id: str): if not app_id: raise ValueError("app_id must be provided.") - body = {"appId": app_id} EnterpriseRequest.send_request("DELETE", "/webapp/clean", json=body) diff --git a/api/services/entities/model_provider_entities.py b/api/services/entities/model_provider_entities.py 
index bc385b2e22..e9d9b45b11 100644 --- a/api/services/entities/model_provider_entities.py +++ b/api/services/entities/model_provider_entities.py @@ -68,13 +68,11 @@ class ProviderResponse(BaseModel): preferred_provider_type: ProviderType custom_configuration: CustomConfigurationResponse system_configuration: SystemConfigurationResponse - # pydantic configs model_config = ConfigDict(protected_namespaces=()) def __init__(self, **data) -> None: super().__init__(**data) - url_prefix = ( dify_config.CONSOLE_API_URL + f"/console/api/workspaces/{self.tenant_id}/model-providers/{self.provider}" ) @@ -82,7 +80,6 @@ class ProviderResponse(BaseModel): self.icon_small = I18nObject( en_US=f"{url_prefix}/icon_small/en_US", zh_Hans=f"{url_prefix}/icon_small/zh_Hans" ) - if self.icon_large is not None: self.icon_large = I18nObject( en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans" @@ -104,7 +101,6 @@ class ProviderWithModelsResponse(BaseModel): def __init__(self, **data) -> None: super().__init__(**data) - url_prefix = ( dify_config.CONSOLE_API_URL + f"/console/api/workspaces/{self.tenant_id}/model-providers/{self.provider}" ) @@ -112,7 +108,6 @@ class ProviderWithModelsResponse(BaseModel): self.icon_small = I18nObject( en_US=f"{url_prefix}/icon_small/en_US", zh_Hans=f"{url_prefix}/icon_small/zh_Hans" ) - if self.icon_large is not None: self.icon_large = I18nObject( en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans" @@ -128,7 +123,6 @@ class SimpleProviderEntityResponse(SimpleProviderEntity): def __init__(self, **data) -> None: super().__init__(**data) - url_prefix = ( dify_config.CONSOLE_API_URL + f"/console/api/workspaces/{self.tenant_id}/model-providers/{self.provider}" ) @@ -136,7 +130,6 @@ class SimpleProviderEntityResponse(SimpleProviderEntity): self.icon_small = I18nObject( en_US=f"{url_prefix}/icon_small/en_US", zh_Hans=f"{url_prefix}/icon_small/zh_Hans" ) - if self.icon_large is not None: self.icon_large = 
I18nObject( en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans" @@ -151,7 +144,6 @@ class DefaultModelResponse(BaseModel): model: str model_type: ModelType provider: SimpleProviderEntityResponse - # pydantic configs model_config = ConfigDict(protected_namespaces=()) diff --git a/api/services/external_knowledge_service.py b/api/services/external_knowledge_service.py index eb50d79494..3428fe3e50 100644 --- a/api/services/external_knowledge_service.py +++ b/api/services/external_knowledge_service.py @@ -35,11 +35,9 @@ class ExternalDatasetService: ) if search: query = query.filter(ExternalKnowledgeApis.name.ilike(f"%{search}%")) - external_knowledge_apis = db.paginate( select=query, page=page, per_page=per_page, max_per_page=100, error_out=False ) - return external_knowledge_apis.items, external_knowledge_apis.total @classmethod @@ -65,7 +63,6 @@ class ExternalDatasetService: description=args.get("description", ""), settings=json.dumps(args.get("settings"), ensure_ascii=False), ) - db.session.add(external_knowledge_api) db.session.commit() return external_knowledge_api @@ -76,10 +73,8 @@ class ExternalDatasetService: raise ValueError("endpoint is required") if "api_key" not in settings or not settings["api_key"]: raise ValueError("api_key is required") - endpoint = f"{settings['endpoint']}/retrieval" api_key = settings["api_key"] - parsed_url = urlparse(endpoint) if not all([parsed_url.scheme, parsed_url.netloc]): if not endpoint.startswith("http://") and not endpoint.startswith("https://"): @@ -115,14 +110,12 @@ class ExternalDatasetService: raise ValueError("api template not found") if args.get("settings") and args.get("settings").get("api_key") == HIDDEN_VALUE: args.get("settings")["api_key"] = external_knowledge_api.settings_dict.get("api_key") - external_knowledge_api.name = args.get("name") external_knowledge_api.description = args.get("description", "") external_knowledge_api.settings = json.dumps(args.get("settings"), 
ensure_ascii=False) external_knowledge_api.updated_by = user_id external_knowledge_api.updated_at = datetime.now(UTC).replace(tzinfo=None) db.session.commit() - return external_knowledge_api @staticmethod @@ -132,7 +125,6 @@ class ExternalDatasetService: ) if external_knowledge_api is None: raise ValueError("api template not found") - db.session.delete(external_knowledge_api) db.session.commit() @@ -178,13 +170,11 @@ class ExternalDatasetService: """ do http request depending on api bundle """ - kwargs = { "url": settings.url, "headers": settings.headers, "follow_redirects": True, } - response: httpx.Response = getattr(ssrf_proxy, settings.request_method)( data=json.dumps(settings.params), files=files, **kwargs ) @@ -200,20 +190,16 @@ class ExternalDatasetService: if authorization.type == "api-key": if authorization.config is None: raise ValueError("authorization config is required") - if authorization.config.api_key is None: raise ValueError("api_key is required") - if not authorization.config.header: authorization.config.header = "Authorization" - if authorization.config.type == "bearer": headers[authorization.config.header] = f"Bearer {authorization.config.api_key}" elif authorization.config.type == "basic": headers[authorization.config.header] = f"Basic {authorization.config.api_key}" elif authorization.config.type == "custom": headers[authorization.config.header] = authorization.config.api_key - return headers @staticmethod @@ -230,10 +216,8 @@ class ExternalDatasetService: .filter_by(id=args.get("external_knowledge_api_id"), tenant_id=tenant_id) .first() ) - if external_knowledge_api is None: raise ValueError("api template not found") - dataset = Dataset( tenant_id=tenant_id, name=args.get("name"), @@ -242,10 +226,8 @@ class ExternalDatasetService: retrieval_model=args.get("external_retrieval_model"), created_by=user_id, ) - db.session.add(dataset) db.session.flush() - external_knowledge_binding = ExternalKnowledgeBindings( tenant_id=tenant_id, 
dataset_id=dataset.id, @@ -254,9 +236,7 @@ class ExternalDatasetService: created_by=user_id, ) db.session.add(external_knowledge_binding) - db.session.commit() - return dataset @staticmethod @@ -272,7 +252,6 @@ class ExternalDatasetService: ) if not external_knowledge_binding: raise ValueError("external knowledge binding not found") - external_knowledge_api = ( db.session.query(ExternalKnowledgeApis) .filter_by(id=external_knowledge_binding.external_knowledge_api_id) @@ -280,7 +259,6 @@ class ExternalDatasetService: ) if not external_knowledge_api: raise ValueError("external api template not found") - settings = json.loads(external_knowledge_api.settings) headers = {"Content-Type": "application/json"} if settings.get("api_key"): @@ -296,7 +274,6 @@ class ExternalDatasetService: "knowledge_id": external_knowledge_binding.external_knowledge_id, "metadata_condition": metadata_condition.model_dump() if metadata_condition else None, } - response = ExternalDatasetService.process_external_api( ExternalKnowledgeApiSetting( url=f"{settings.get('endpoint')}/retrieval", diff --git a/api/services/feature_service.py b/api/services/feature_service.py index 188caf3505..7918e9ed15 100644 --- a/api/services/feature_service.py +++ b/api/services/feature_service.py @@ -41,7 +41,6 @@ class LicenseLimitationModel(BaseModel): def is_available(self, required: int = 1) -> bool: """ Determine whether the requested amount can be allocated. 
- Returns True if: - this limit is not active, or - the limit is zero (unlimited), or @@ -49,7 +48,6 @@ class LicenseLimitationModel(BaseModel): """ if not self.enabled or self.limit == 0: return True - return (self.limit - self.size) >= required @@ -102,7 +100,6 @@ class PluginInstallationPermissionModel(BaseModel): # official_and_specific_partners: allow official and specific partner plugins # all: allow installation of all plugins plugin_installation_scope: PluginInstallationScope = PluginInstallationScope.ALL - # If True, restrict plugin installation to the marketplace only # Equivalent to ForceEnablePluginVerification restrict_to_marketplace_only: bool = False @@ -123,7 +120,6 @@ class FeatureModel(BaseModel): dataset_operator_enabled: bool = False webapp_copyright_enabled: bool = False workspace_members: LicenseLimitationModel = LicenseLimitationModel(enabled=False, size=0, limit=0) - # pydantic configs model_config = ConfigDict(protected_namespaces=()) @@ -155,16 +151,12 @@ class FeatureService: @classmethod def get_features(cls, tenant_id: str) -> FeatureModel: features = FeatureModel() - cls._fulfill_params_from_env(features) - if dify_config.BILLING_ENABLED and tenant_id: cls._fulfill_params_from_billing_api(features, tenant_id) - if dify_config.ENTERPRISE_ENABLED: features.webapp_copyright_enabled = True cls._fulfill_params_from_workspace_info(features, tenant_id) - return features @classmethod @@ -180,17 +172,13 @@ class FeatureService: @classmethod def get_system_features(cls) -> SystemFeatureModel: system_features = SystemFeatureModel() - cls._fulfill_system_params_from_env(system_features) - if dify_config.ENTERPRISE_ENABLED: system_features.branding.enabled = True system_features.webapp_auth.enabled = True cls._fulfill_params_from_enterprise(system_features) - if dify_config.MARKETPLACE_ENABLED: system_features.enable_marketplace = True - return system_features @classmethod @@ -220,75 +208,56 @@ class FeatureService: @classmethod def 
_fulfill_params_from_billing_api(cls, features: FeatureModel, tenant_id: str): billing_info = BillingService.get_info(tenant_id) - features.billing.enabled = billing_info["enabled"] features.billing.subscription.plan = billing_info["subscription"]["plan"] features.billing.subscription.interval = billing_info["subscription"]["interval"] features.education.activated = billing_info["subscription"].get("education", False) - if features.billing.subscription.plan != "sandbox": features.webapp_copyright_enabled = True - if "members" in billing_info: features.members.size = billing_info["members"]["size"] features.members.limit = billing_info["members"]["limit"] - if "apps" in billing_info: features.apps.size = billing_info["apps"]["size"] features.apps.limit = billing_info["apps"]["limit"] - if "vector_space" in billing_info: features.vector_space.size = billing_info["vector_space"]["size"] features.vector_space.limit = billing_info["vector_space"]["limit"] - if "documents_upload_quota" in billing_info: features.documents_upload_quota.size = billing_info["documents_upload_quota"]["size"] features.documents_upload_quota.limit = billing_info["documents_upload_quota"]["limit"] - if "annotation_quota_limit" in billing_info: features.annotation_quota_limit.size = billing_info["annotation_quota_limit"]["size"] features.annotation_quota_limit.limit = billing_info["annotation_quota_limit"]["limit"] - if "docs_processing" in billing_info: features.docs_processing = billing_info["docs_processing"] - if "can_replace_logo" in billing_info: features.can_replace_logo = billing_info["can_replace_logo"] - if "model_load_balancing_enabled" in billing_info: features.model_load_balancing_enabled = billing_info["model_load_balancing_enabled"] - if "knowledge_rate_limit" in billing_info: features.knowledge_rate_limit = billing_info["knowledge_rate_limit"]["limit"] @classmethod def _fulfill_params_from_enterprise(cls, features: SystemFeatureModel): enterprise_info = 
EnterpriseService.get_info() - if "SSOEnforcedForSignin" in enterprise_info: features.sso_enforced_for_signin = enterprise_info["SSOEnforcedForSignin"] - if "SSOEnforcedForSigninProtocol" in enterprise_info: features.sso_enforced_for_signin_protocol = enterprise_info["SSOEnforcedForSigninProtocol"] - if "EnableEmailCodeLogin" in enterprise_info: features.enable_email_code_login = enterprise_info["EnableEmailCodeLogin"] - if "EnableEmailPasswordLogin" in enterprise_info: features.enable_email_password_login = enterprise_info["EnableEmailPasswordLogin"] - if "IsAllowRegister" in enterprise_info: features.is_allow_register = enterprise_info["IsAllowRegister"] - if "IsAllowCreateWorkspace" in enterprise_info: features.is_allow_create_workspace = enterprise_info["IsAllowCreateWorkspace"] - if "Branding" in enterprise_info: features.branding.application_title = enterprise_info["Branding"].get("applicationTitle", "") features.branding.login_page_logo = enterprise_info["Branding"].get("loginPageLogo", "") features.branding.workspace_logo = enterprise_info["Branding"].get("workspaceLogo", "") features.branding.favicon = enterprise_info["Branding"].get("favicon", "") - if "WebAppAuth" in enterprise_info: features.webapp_auth.allow_sso = enterprise_info["WebAppAuth"].get("allowSso", False) features.webapp_auth.allow_email_code_login = enterprise_info["WebAppAuth"].get( @@ -298,21 +267,16 @@ class FeatureService: "allowEmailPasswordLogin", False ) features.webapp_auth.sso_config.protocol = enterprise_info.get("SSOEnforcedForWebProtocol", "") - if "License" in enterprise_info: license_info = enterprise_info["License"] - if "status" in license_info: features.license.status = LicenseStatus(license_info.get("status", LicenseStatus.INACTIVE)) - if "expiredAt" in license_info: features.license.expired_at = license_info["expiredAt"] - if "workspaces" in license_info: features.license.workspaces.enabled = license_info["workspaces"]["enabled"] features.license.workspaces.limit = 
license_info["workspaces"]["limit"] features.license.workspaces.size = license_info["workspaces"]["used"] - if "PluginInstallationPermission" in enterprise_info: plugin_installation_info = enterprise_info["PluginInstallationPermission"] features.plugin_installation_permission.plugin_installation_scope = plugin_installation_info[ diff --git a/api/services/file_service.py b/api/services/file_service.py index 2d68f30c5a..b822a53c70 100644 --- a/api/services/file_service.py +++ b/api/services/file_service.py @@ -40,38 +40,28 @@ class FileService: ) -> UploadFile: # get file extension extension = os.path.splitext(filename)[1].lstrip(".").lower() - # check if filename contains invalid characters if any(c in filename for c in ["/", "\\", ":", "*", "?", '"', "<", ">", "|"]): raise ValueError("Filename contains invalid characters") - if len(filename) > 200: filename = filename.split(".")[0][:200] + "." + extension - if source == "datasets" and extension not in DOCUMENT_EXTENSIONS: raise UnsupportedFileTypeError() - # get file size file_size = len(content) - # check if the file size is exceeded if not FileService.is_file_size_within_limit(extension=extension, file_size=file_size): raise FileTooLargeError - # generate file key file_uuid = str(uuid.uuid4()) - if isinstance(user, Account): current_tenant_id = user.current_tenant_id else: # end_user current_tenant_id = user.tenant_id - file_key = "upload_files/" + (current_tenant_id or "") + "/" + file_uuid + "." 
+ extension - # save file to storage storage.save(file_key, content) - # save file to db upload_file = UploadFile( tenant_id=current_tenant_id or "", @@ -88,15 +78,12 @@ class FileService: hash=hashlib.sha3_256(content).hexdigest(), source_url=source_url, ) - db.session.add(upload_file) db.session.commit() - if not upload_file.source_url: upload_file.source_url = file_helpers.get_signed_file_url(upload_file_id=upload_file.id) db.session.add(upload_file) db.session.commit() - return upload_file @staticmethod @@ -109,7 +96,6 @@ class FileService: file_size_limit = dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT * 1024 * 1024 else: file_size_limit = dify_config.UPLOAD_FILE_SIZE_LIMIT * 1024 * 1024 - return file_size <= file_size_limit @staticmethod @@ -119,10 +105,8 @@ class FileService: # user uuid as file name file_uuid = str(uuid.uuid4()) file_key = "upload_files/" + current_user.current_tenant_id + "/" + file_uuid + ".txt" - # save file to storage storage.save(file_key, text.encode("utf-8")) - # save file to db upload_file = UploadFile( tenant_id=current_user.current_tenant_id, @@ -139,27 +123,21 @@ class FileService: used_by=current_user.id, used_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None), ) - db.session.add(upload_file) db.session.commit() - return upload_file @staticmethod def get_file_preview(file_id: str): upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first() - if not upload_file: raise NotFound("File not found") - # extract text from file extension = upload_file.extension if extension.lower() not in DOCUMENT_EXTENSIONS: raise UnsupportedFileTypeError() - text = ExtractProcessor.load_from_upload_file(upload_file, return_text=True) text = text[0:PREVIEW_WORDS_LIMIT] if text else "" - return text @staticmethod @@ -169,19 +147,14 @@ class FileService: ) if not result: raise NotFound("File not found or signature is invalid") - upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first() - if not 
upload_file: raise NotFound("File not found or signature is invalid") - # extract text from file extension = upload_file.extension if extension.lower() not in IMAGE_EXTENSIONS: raise UnsupportedFileTypeError() - generator = storage.load(upload_file.key, stream=True) - return generator, upload_file.mime_type @staticmethod @@ -189,28 +162,20 @@ class FileService: result = file_helpers.verify_file_signature(upload_file_id=file_id, timestamp=timestamp, nonce=nonce, sign=sign) if not result: raise NotFound("File not found or signature is invalid") - upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first() - if not upload_file: raise NotFound("File not found or signature is invalid") - generator = storage.load(upload_file.key, stream=True) - return generator, upload_file @staticmethod def get_public_image_preview(file_id: str): upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first() - if not upload_file: raise NotFound("File not found or signature is invalid") - # extract text from file extension = upload_file.extension if extension.lower() not in IMAGE_EXTENSIONS: raise UnsupportedFileTypeError() - generator = storage.load(upload_file.key) - return generator, upload_file.mime_type diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py index 519d5abca5..c1878d9eff 100644 --- a/api/services/hit_testing_service.py +++ b/api/services/hit_testing_service.py @@ -33,7 +33,6 @@ class HitTestingService: limit: int = 10, ) -> dict: start = time.perf_counter() - # get retrieval model , if the model is not setting , using default if not retrieval_model: retrieval_model = dataset.retrieval_model or default_retrieval_model @@ -41,11 +40,9 @@ class HitTestingService: metadata_filtering_conditions = retrieval_model.get("metadata_filtering_conditions", {}) if metadata_filtering_conditions: dataset_retrieval = DatasetRetrieval() - from core.app.app_config.entities import MetadataFilteringCondition 
metadata_filtering_conditions = MetadataFilteringCondition(**metadata_filtering_conditions) - metadata_filter_document_ids, metadata_condition = dataset_retrieval.get_metadata_filter_condition( dataset_ids=[dataset.id], query=query, @@ -75,17 +72,13 @@ class HitTestingService: weights=retrieval_model.get("weights", None), document_ids_filter=document_ids_filter, ) - end = time.perf_counter() logging.debug(f"Hit testing retrieve in {end - start:0.4f} seconds") - dataset_query = DatasetQuery( dataset_id=dataset.id, content=query, source="hit_testing", created_by_role="account", created_by=account.id ) - db.session.add(dataset_query) db.session.commit() - return cls.compact_retrieve_response(query, all_documents) # type: ignore @classmethod @@ -102,32 +95,25 @@ class HitTestingService: "query": {"content": query}, "records": [], } - start = time.perf_counter() - all_documents = RetrievalService.external_retrieve( dataset_id=dataset.id, query=cls.escape_query_for_search(query), external_retrieval_model=external_retrieval_model, metadata_filtering_conditions=metadata_filtering_conditions, ) - end = time.perf_counter() logging.debug(f"External knowledge hit testing retrieve in {end - start:0.4f} seconds") - dataset_query = DatasetQuery( dataset_id=dataset.id, content=query, source="hit_testing", created_by_role="account", created_by=account.id ) - db.session.add(dataset_query) db.session.commit() - return dict(cls.compact_external_retrieve_response(dataset, query, all_documents)) @classmethod def compact_retrieve_response(cls, query: str, documents: list[Document]) -> dict[Any, Any]: records = RetrievalService.format_retrieval_documents(documents) - return { "query": { "content": query, @@ -156,7 +142,6 @@ class HitTestingService: @classmethod def hit_testing_args_check(cls, args): query = args["query"] - if not query or len(query) > 250: raise ValueError("Query is required and cannot exceed 250 characters") diff --git a/api/services/message_service.py 
b/api/services/message_service.py index 51b070ece7..4c81736be5 100644 --- a/api/services/message_service.py +++ b/api/services/message_service.py @@ -37,26 +37,20 @@ class MessageService: ) -> InfiniteScrollPagination: if not user: return InfiniteScrollPagination(data=[], limit=limit, has_more=False) - if not conversation_id: return InfiniteScrollPagination(data=[], limit=limit, has_more=False) - conversation = ConversationService.get_conversation( app_model=app_model, user=user, conversation_id=conversation_id ) - fetch_limit = limit + 1 - if first_id: first_message = ( db.session.query(Message) .filter(Message.conversation_id == conversation.id, Message.id == first_id) .first() ) - if not first_message: raise FirstMessageNotExistsError() - history_messages = ( db.session.query(Message) .filter( @@ -76,15 +70,12 @@ class MessageService: .limit(fetch_limit) .all() ) - has_more = False if len(history_messages) > limit: has_more = True history_messages = history_messages[:-1] - if order == "asc": history_messages = list(reversed(history_messages)) - return InfiniteScrollPagination(data=history_messages, limit=limit, has_more=has_more) @classmethod @@ -99,27 +90,19 @@ class MessageService: ) -> InfiniteScrollPagination: if not user: return InfiniteScrollPagination(data=[], limit=limit, has_more=False) - base_query = db.session.query(Message) - fetch_limit = limit + 1 - if conversation_id is not None: conversation = ConversationService.get_conversation( app_model=app_model, user=user, conversation_id=conversation_id ) - base_query = base_query.filter(Message.conversation_id == conversation.id) - if include_ids is not None: base_query = base_query.filter(Message.id.in_(include_ids)) - if last_id: last_message = base_query.filter(Message.id == last_id).first() - if not last_message: raise LastMessageNotExistsError() - history_messages = ( base_query.filter(Message.created_at < last_message.created_at, Message.id != last_message.id) .order_by(Message.created_at.desc()) @@ 
-128,12 +111,10 @@ class MessageService: ) else: history_messages = base_query.order_by(Message.created_at.desc()).limit(fetch_limit).all() - has_more = False if len(history_messages) > limit: has_more = True history_messages = history_messages[:-1] - return InfiniteScrollPagination(data=history_messages, limit=limit, has_more=has_more) @classmethod @@ -148,11 +129,8 @@ class MessageService: ): if not user: raise ValueError("user cannot be None") - message = cls.get_message(app_model=app_model, user=user, message_id=message_id) - feedback = message.user_feedback if isinstance(user, EndUser) else message.admin_feedback - if not rating and feedback: db.session.delete(feedback) elif rating and feedback: @@ -172,9 +150,7 @@ class MessageService: from_account_id=(user.id if isinstance(user, Account) else None), ) db.session.add(feedback) - db.session.commit() - return feedback @classmethod @@ -189,7 +165,6 @@ class MessageService: .offset(offset) .all() ) - return [record.to_dict() for record in feedbacks] @classmethod @@ -205,10 +180,8 @@ class MessageService: ) .first() ) - if not message: raise MessageNotExistsError() - return message @classmethod @@ -217,30 +190,22 @@ class MessageService: ) -> list[Message]: if not user: raise ValueError("user cannot be None") - message = cls.get_message(app_model=app_model, user=user, message_id=message_id) - conversation = ConversationService.get_conversation( app_model=app_model, conversation_id=message.conversation_id, user=user ) - model_manager = ModelManager() - if app_model.mode == AppMode.ADVANCED_CHAT.value: workflow_service = WorkflowService() if invoke_from == InvokeFrom.DEBUGGER: workflow = workflow_service.get_draft_workflow(app_model=app_model) else: workflow = workflow_service.get_published_workflow(app_model=app_model) - if workflow is None: return [] - app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow) - if not 
app_config.additional_features.suggested_questions_after_answer: raise SuggestedQuestionsAfterAnswerDisabledError() - model_instance = model_manager.get_default_model_instance( tenant_id=app_model.tenant_id, model_type=ModelType.LLM ) @@ -259,35 +224,28 @@ class MessageService: id=conversation.app_model_config_id, app_id=app_model.id, ) - app_model_config = app_model_config.from_model_config_dict(conversation_override_model_configs) if not app_model_config: raise ValueError("did not find app model config") - suggested_questions_after_answer = app_model_config.suggested_questions_after_answer_dict if suggested_questions_after_answer.get("enabled", False) is False: raise SuggestedQuestionsAfterAnswerDisabledError() - model_instance = model_manager.get_model_instance( tenant_id=app_model.tenant_id, provider=app_model_config.model_dict["provider"], model_type=ModelType.LLM, model=app_model_config.model_dict["name"], ) - # get memory of conversation (read-only) memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance) - histories = memory.get_history_prompt_text( max_token_limit=3000, message_limit=3, ) - with measure_time() as timer: questions: list[Message] = LLMGenerator.generate_suggested_questions_after_answer( tenant_id=app_model.tenant_id, histories=histories ) - # get tracing instance trace_manager = TraceQueueManager(app_id=app_model.id) trace_manager.add_trace_task( @@ -295,5 +253,4 @@ class MessageService: TraceTaskName.SUGGESTED_QUESTION_TRACE, message_id=message_id, suggested_question=questions, timer=timer ) ) - return questions diff --git a/api/services/metadata_service.py b/api/services/metadata_service.py index cfcb121153..2cf38965cb 100644 --- a/api/services/metadata_service.py +++ b/api/services/metadata_service.py @@ -22,7 +22,6 @@ class MetadataService: # check if metadata name is too long if len(metadata_args.name) > 255: raise ValueError("Metadata name cannot exceed 255 characters.") - # check if metadata name already 
exists if ( db.session.query(DatasetMetadata) @@ -49,7 +48,6 @@ class MetadataService: # check if metadata name is too long if len(name) > 255: raise ValueError("Metadata name cannot exceed 255 characters.") - lock_key = f"dataset_metadata_lock_{dataset_id}" # check if metadata name already exists if ( @@ -70,7 +68,6 @@ class MetadataService: metadata.name = name metadata.updated_by = current_user.id metadata.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) - # update related documents dataset_metadata_bindings = ( db.session.query(DatasetMetadataBinding).filter_by(metadata_id=metadata_id).all() @@ -100,7 +97,6 @@ class MetadataService: if metadata is None: raise ValueError("Metadata not found.") db.session.delete(metadata) - # deal related documents dataset_metadata_bindings = ( db.session.query(DatasetMetadataBinding).filter_by(metadata_id=metadata_id).all() diff --git a/api/services/model_load_balancing_service.py b/api/services/model_load_balancing_service.py index 26311a6377..3e026102d1 100644 --- a/api/services/model_load_balancing_service.py +++ b/api/services/model_load_balancing_service.py @@ -29,7 +29,6 @@ class ModelLoadBalancingService: def enable_model_load_balancing(self, tenant_id: str, provider: str, model: str, model_type: str) -> None: """ enable model load balancing. 
- :param tenant_id: workspace id :param provider: provider name :param model: model name @@ -38,19 +37,16 @@ class ModelLoadBalancingService: """ # Get all provider configurations of the current workspace provider_configurations = self.provider_manager.get_configurations(tenant_id) - # Get provider configuration provider_configuration = provider_configurations.get(provider) if not provider_configuration: raise ValueError(f"Provider {provider} does not exist.") - # Enable model load balancing provider_configuration.enable_model_load_balancing(model=model, model_type=ModelType.value_of(model_type)) def disable_model_load_balancing(self, tenant_id: str, provider: str, model: str, model_type: str) -> None: """ disable model load balancing. - :param tenant_id: workspace id :param provider: provider name :param model: model name @@ -59,12 +55,10 @@ class ModelLoadBalancingService: """ # Get all provider configurations of the current workspace provider_configurations = self.provider_manager.get_configurations(tenant_id) - # Get provider configuration provider_configuration = provider_configurations.get(provider) if not provider_configuration: raise ValueError(f"Provider {provider} does not exist.") - # disable model load balancing provider_configuration.disable_model_load_balancing(model=model, model_type=ModelType.value_of(model_type)) @@ -81,25 +75,20 @@ class ModelLoadBalancingService: """ # Get all provider configurations of the current workspace provider_configurations = self.provider_manager.get_configurations(tenant_id) - # Get provider configuration provider_configuration = provider_configurations.get(provider) if not provider_configuration: raise ValueError(f"Provider {provider} does not exist.") - # Convert model type to ModelType model_type_enum = ModelType.value_of(model_type) - # Get provider model setting provider_model_setting = provider_configuration.get_provider_model_setting( model_type=model_type_enum, model=model, ) - is_load_balancing_enabled = False 
if provider_model_setting and provider_model_setting.load_balancing_enabled: is_load_balancing_enabled = True - # Get load balancing configurations load_balancing_configs = ( db.session.query(LoadBalancingModelConfig) @@ -112,7 +101,6 @@ class ModelLoadBalancingService: .order_by(LoadBalancingModelConfig.created_at) .all() ) - if provider_configuration.custom_configuration.provider: # check if the inherit configuration exists, # inherit is represented for the provider or model custom credentials @@ -121,11 +109,9 @@ class ModelLoadBalancingService: if load_balancing_config.name == "__inherit__": inherit_config_exists = True break - if not inherit_config_exists: # Initialize the inherit configuration inherit_config = self._init_inherit_config(tenant_id, provider, model, model_type_enum) - # prepend the inherit configuration load_balancing_configs.insert(0, inherit_config) else: @@ -134,13 +120,10 @@ class ModelLoadBalancingService: if load_balancing_config.name == "__inherit__": inherit_config = load_balancing_configs.pop(i) load_balancing_configs.insert(0, inherit_config) - # Get credential form schemas from model credential schema or provider credential schema credential_schemas = self._get_credential_schema(provider_configuration) - # Get decoding rsa key and cipher for decrypting credentials decoding_rsa_key, decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id) - # fetch status and ttl for each config datas = [] for load_balancing_config in load_balancing_configs: @@ -151,7 +134,6 @@ class ModelLoadBalancingService: model_type=model_type_enum, config_id=load_balancing_config.id, ) - try: if load_balancing_config.encrypted_config: credentials = json.loads(load_balancing_config.encrypted_config) @@ -159,12 +141,10 @@ class ModelLoadBalancingService: credentials = {} except JSONDecodeError: credentials = {} - # Get provider credential secret variables credential_secret_variables = provider_configuration.extract_secret_variables( 
credential_schemas.credential_form_schemas ) - # decrypt credentials for variable in credential_secret_variables: if variable in credentials: @@ -174,12 +154,10 @@ class ModelLoadBalancingService: ) except ValueError: pass - # Obfuscate credentials credentials = provider_configuration.obfuscated_credentials( credentials=credentials, credential_form_schemas=credential_schemas.credential_form_schemas ) - datas.append( { "id": load_balancing_config.id, @@ -190,7 +168,6 @@ class ModelLoadBalancingService: "ttl": ttl, } ) - return is_load_balancing_enabled, datas def get_load_balancing_config( @@ -207,15 +184,12 @@ class ModelLoadBalancingService: """ # Get all provider configurations of the current workspace provider_configurations = self.provider_manager.get_configurations(tenant_id) - # Get provider configuration provider_configuration = provider_configurations.get(provider) if not provider_configuration: raise ValueError(f"Provider {provider} does not exist.") - # Convert model type to ModelType model_type_enum = ModelType.value_of(model_type) - # Get load balancing configurations load_balancing_model_config = ( db.session.query(LoadBalancingModelConfig) @@ -228,10 +202,8 @@ class ModelLoadBalancingService: ) .first() ) - if not load_balancing_model_config: return None - try: if load_balancing_model_config.encrypted_config: credentials = json.loads(load_balancing_model_config.encrypted_config) @@ -239,15 +211,12 @@ class ModelLoadBalancingService: credentials = {} except JSONDecodeError: credentials = {} - # Get credential form schemas from model credential schema or provider credential schema credential_schemas = self._get_credential_schema(provider_configuration) - # Obfuscate credentials credentials = provider_configuration.obfuscated_credentials( credentials=credentials, credential_form_schemas=credential_schemas.credential_form_schemas ) - return { "id": load_balancing_model_config.id, "name": load_balancing_model_config.name, @@ -276,7 +245,6 @@ class 
ModelLoadBalancingService: ) db.session.add(inherit_config) db.session.commit() - return inherit_config def update_load_balancing_configs( @@ -293,18 +261,14 @@ class ModelLoadBalancingService: """ # Get all provider configurations of the current workspace provider_configurations = self.provider_manager.get_configurations(tenant_id) - # Get provider configuration provider_configuration = provider_configurations.get(provider) if not provider_configuration: raise ValueError(f"Provider {provider} does not exist.") - # Convert model type to ModelType model_type_enum = ModelType.value_of(model_type) - if not isinstance(configs, list): raise ValueError("Invalid load balancing configs") - current_load_balancing_configs = ( db.session.query(LoadBalancingModelConfig) .filter( @@ -315,46 +279,34 @@ class ModelLoadBalancingService: ) .all() ) - # id as key, config as value current_load_balancing_configs_dict = {config.id: config for config in current_load_balancing_configs} updated_config_ids = set() - for config in configs: if not isinstance(config, dict): raise ValueError("Invalid load balancing config") - config_id = config.get("id") name = config.get("name") credentials = config.get("credentials") enabled = config.get("enabled") - if not name: raise ValueError("Invalid load balancing config name") - if enabled is None: raise ValueError("Invalid load balancing config enabled") - # is config exists if config_id: config_id = str(config_id) - if config_id not in current_load_balancing_configs_dict: raise ValueError("Invalid load balancing config id: {}".format(config_id)) - updated_config_ids.add(config_id) - load_balancing_config = current_load_balancing_configs_dict[config_id] - # check duplicate name for current_load_balancing_config in current_load_balancing_configs: if current_load_balancing_config.id != config_id and current_load_balancing_config.name == name: raise ValueError("Load balancing config name {} already exists".format(name)) - if credentials: if not 
isinstance(credentials, dict): raise ValueError("Invalid load balancing config credentials") - # validate custom provider config credentials = self._custom_credentials_validate( tenant_id=tenant_id, @@ -365,32 +317,25 @@ class ModelLoadBalancingService: load_balancing_model_config=load_balancing_config, validate=False, ) - # update load balancing config load_balancing_config.encrypted_config = json.dumps(credentials) - load_balancing_config.name = name load_balancing_config.enabled = enabled load_balancing_config.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) db.session.commit() - self._clear_credentials_cache(tenant_id, config_id) else: # create load balancing config if name == "__inherit__": raise ValueError("Invalid load balancing config name") - # check duplicate name for current_load_balancing_config in current_load_balancing_configs: if current_load_balancing_config.name == name: raise ValueError("Load balancing config name {} already exists".format(name)) - if not credentials: raise ValueError("Invalid load balancing config credentials") - if not isinstance(credentials, dict): raise ValueError("Invalid load balancing config credentials") - # validate custom provider config credentials = self._custom_credentials_validate( tenant_id=tenant_id, @@ -400,7 +345,6 @@ class ModelLoadBalancingService: credentials=credentials, validate=False, ) - # create load balancing config load_balancing_model_config = LoadBalancingModelConfig( tenant_id=tenant_id, @@ -410,16 +354,13 @@ class ModelLoadBalancingService: name=name, encrypted_config=json.dumps(credentials), ) - db.session.add(load_balancing_model_config) db.session.commit() - # get deleted config ids deleted_config_ids = set(current_load_balancing_configs_dict.keys()) - updated_config_ids for config_id in deleted_config_ids: db.session.delete(current_load_balancing_configs_dict[config_id]) db.session.commit() - self._clear_credentials_cache(tenant_id, config_id) def 
validate_load_balancing_credentials( @@ -443,15 +384,12 @@ class ModelLoadBalancingService: """ # Get all provider configurations of the current workspace provider_configurations = self.provider_manager.get_configurations(tenant_id) - # Get provider configuration provider_configuration = provider_configurations.get(provider) if not provider_configuration: raise ValueError(f"Provider {provider} does not exist.") - # Convert model type to ModelType model_type_enum = ModelType.value_of(model_type) - load_balancing_model_config = None if config_id: # Get load balancing config @@ -466,10 +404,8 @@ class ModelLoadBalancingService: ) .first() ) - if not load_balancing_model_config: raise ValueError(f"Load balancing config {config_id} does not exist.") - # Validate custom provider config self._custom_credentials_validate( tenant_id=tenant_id, @@ -503,12 +439,10 @@ class ModelLoadBalancingService: """ # Get credential form schemas from model credential schema or provider credential schema credential_schemas = self._get_credential_schema(provider_configuration) - # Get provider credential secret variables provider_credential_secret_variables = provider_configuration.extract_secret_variables( credential_schemas.credential_form_schemas ) - if load_balancing_model_config: try: # fix origin data @@ -518,14 +452,12 @@ class ModelLoadBalancingService: original_credentials = {} except JSONDecodeError: original_credentials = {} - # encrypt credentials for key, value in credentials.items(): if key in provider_credential_secret_variables: # if send [__HIDDEN__] in secret input, it will be same as original value if value == HIDDEN_VALUE and key in original_credentials: credentials[key] = encrypter.decrypt_token(tenant_id, original_credentials[key]) - if validate: model_provider_factory = ModelProviderFactory(tenant_id) if isinstance(credential_schemas, ModelCredentialSchema): @@ -539,11 +471,9 @@ class ModelLoadBalancingService: credentials = 
model_provider_factory.provider_credentials_validate( provider=provider_configuration.provider.provider, credentials=credentials ) - for key, value in credentials.items(): if key in provider_credential_secret_variables: credentials[key] = encrypter.encrypt_token(tenant_id, value) - return credentials def _get_credential_schema( @@ -567,5 +497,4 @@ class ModelLoadBalancingService: provider_model_credentials_cache = ProviderCredentialsCache( tenant_id=tenant_id, identity_id=config_id, cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL ) - provider_model_credentials_cache.delete() diff --git a/api/services/model_provider_service.py b/api/services/model_provider_service.py index 0a0a5619e1..e62cd1b756 100644 --- a/api/services/model_provider_service.py +++ b/api/services/model_provider_service.py @@ -31,21 +31,18 @@ class ModelProviderService: def get_provider_list(self, tenant_id: str, model_type: Optional[str] = None) -> list[ProviderResponse]: """ get provider list. - :param tenant_id: workspace id :param model_type: model type :return: """ # Get all provider configurations of the current workspace provider_configurations = self.provider_manager.get_configurations(tenant_id) - provider_responses = [] for provider_configuration in provider_configurations.values(): if model_type: model_type_entity = ModelType.value_of(model_type) if model_type_entity not in provider_configuration.provider.supported_model_types: continue - provider_response = ProviderResponse( tenant_id=tenant_id, provider=provider_configuration.provider.provider, @@ -71,9 +68,7 @@ class ModelProviderService: quota_configurations=provider_configuration.system_configuration.quota_configurations, ), ) - provider_responses.append(provider_response) - return provider_responses def get_models_by_provider(self, tenant_id: str, provider: str) -> list[ModelWithProviderEntityResponse]: @@ -81,14 +76,12 @@ class ModelProviderService: get provider models. 
For the model provider page, only supports passing in a single provider to query the list of supported models. - :param tenant_id: :param provider: :return: """ # Get all provider configurations of the current workspace provider_configurations = self.provider_manager.get_configurations(tenant_id) - # Get provider available models return [ ModelWithProviderEntityResponse(tenant_id=tenant_id, model=model) @@ -103,31 +96,26 @@ class ModelProviderService: provider_configuration = provider_configurations.get(provider) if not provider_configuration: raise ValueError(f"Provider {provider} does not exist.") - return provider_configuration.get_custom_credentials(obfuscated=True) def provider_credentials_validate(self, tenant_id: str, provider: str, credentials: dict) -> None: """ validate provider credentials. - :param tenant_id: :param provider: :param credentials: """ # Get all provider configurations of the current workspace provider_configurations = self.provider_manager.get_configurations(tenant_id) - # Get provider configuration provider_configuration = provider_configurations.get(provider) if not provider_configuration: raise ValueError(f"Provider {provider} does not exist.") - provider_configuration.custom_credentials_validate(credentials) def save_provider_credentials(self, tenant_id: str, provider: str, credentials: dict) -> None: """ save custom provider config. - :param tenant_id: workspace id :param provider: provider name :param credentials: provider credentials @@ -135,38 +123,32 @@ class ModelProviderService: """ # Get all provider configurations of the current workspace provider_configurations = self.provider_manager.get_configurations(tenant_id) - # Get provider configuration provider_configuration = provider_configurations.get(provider) if not provider_configuration: raise ValueError(f"Provider {provider} does not exist.") - # Add or update custom provider credentials. 
provider_configuration.add_or_update_custom_credentials(credentials) def remove_provider_credentials(self, tenant_id: str, provider: str) -> None: """ remove custom provider config. - :param tenant_id: workspace id :param provider: provider name :return: """ # Get all provider configurations of the current workspace provider_configurations = self.provider_manager.get_configurations(tenant_id) - # Get provider configuration provider_configuration = provider_configurations.get(provider) if not provider_configuration: raise ValueError(f"Provider {provider} does not exist.") - # Remove custom provider credentials. provider_configuration.delete_custom_credentials() def get_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str) -> Optional[dict]: """ get model credentials. - :param tenant_id: workspace id :param provider: provider name :param model_type: model type @@ -175,12 +157,10 @@ class ModelProviderService: """ # Get all provider configurations of the current workspace provider_configurations = self.provider_manager.get_configurations(tenant_id) - # Get provider configuration provider_configuration = provider_configurations.get(provider) if not provider_configuration: raise ValueError(f"Provider {provider} does not exist.") - # Get model custom credentials from ProviderModel if exists return provider_configuration.get_custom_model_credentials( model_type=ModelType.value_of(model_type), model=model, obfuscated=True @@ -191,7 +171,6 @@ class ModelProviderService: ) -> None: """ validate model credentials. 
- :param tenant_id: workspace id :param provider: provider name :param model_type: model type @@ -201,12 +180,10 @@ class ModelProviderService: """ # Get all provider configurations of the current workspace provider_configurations = self.provider_manager.get_configurations(tenant_id) - # Get provider configuration provider_configuration = provider_configurations.get(provider) if not provider_configuration: raise ValueError(f"Provider {provider} does not exist.") - # Validate model credentials provider_configuration.custom_model_credentials_validate( model_type=ModelType.value_of(model_type), model=model, credentials=credentials @@ -217,7 +194,6 @@ class ModelProviderService: ) -> None: """ save model credentials. - :param tenant_id: workspace id :param provider: provider name :param model_type: model type @@ -227,12 +203,10 @@ class ModelProviderService: """ # Get all provider configurations of the current workspace provider_configurations = self.provider_manager.get_configurations(tenant_id) - # Get provider configuration provider_configuration = provider_configurations.get(provider) if not provider_configuration: raise ValueError(f"Provider {provider} does not exist.") - # Add or update custom model credentials provider_configuration.add_or_update_custom_model_credentials( model_type=ModelType.value_of(model_type), model=model, credentials=credentials @@ -241,7 +215,6 @@ class ModelProviderService: def remove_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str) -> None: """ remove model credentials. 
- :param tenant_id: workspace id :param provider: provider name :param model_type: model type @@ -250,51 +223,40 @@ class ModelProviderService: """ # Get all provider configurations of the current workspace provider_configurations = self.provider_manager.get_configurations(tenant_id) - # Get provider configuration provider_configuration = provider_configurations.get(provider) if not provider_configuration: raise ValueError(f"Provider {provider} does not exist.") - # Remove custom model credentials provider_configuration.delete_custom_model_credentials(model_type=ModelType.value_of(model_type), model=model) def get_models_by_model_type(self, tenant_id: str, model_type: str) -> list[ProviderWithModelsResponse]: """ get models by model type. - :param tenant_id: workspace id :param model_type: model type :return: """ # Get all provider configurations of the current workspace provider_configurations = self.provider_manager.get_configurations(tenant_id) - # Get provider available models models = provider_configurations.get_models(model_type=ModelType.value_of(model_type)) - # Group models by provider provider_models: dict[str, list[ModelWithProviderEntity]] = {} for model in models: if model.provider.provider not in provider_models: provider_models[model.provider.provider] = [] - if model.deprecated: continue - if model.status != ModelStatus.ACTIVE: continue - provider_models[model.provider.provider].append(model) - # convert to ProviderWithModelsResponse list providers_with_models: list[ProviderWithModelsResponse] = [] for provider, models in provider_models.items(): if not models: continue - first_model = models[0] - providers_with_models.append( ProviderWithModelsResponse( tenant_id=tenant_id, @@ -318,14 +280,12 @@ class ModelProviderService: ], ) ) - return providers_with_models def get_model_parameter_rules(self, tenant_id: str, provider: str, model: str) -> list[ParameterRule]: """ get model parameter rules. Only supports LLM. 
- :param tenant_id: workspace id :param provider: provider name :param model: model name @@ -333,34 +293,27 @@ class ModelProviderService: """ # Get all provider configurations of the current workspace provider_configurations = self.provider_manager.get_configurations(tenant_id) - # Get provider configuration provider_configuration = provider_configurations.get(provider) if not provider_configuration: raise ValueError(f"Provider {provider} does not exist.") - # fetch credentials credentials = provider_configuration.get_current_credentials(model_type=ModelType.LLM, model=model) - if not credentials: return [] - model_schema = provider_configuration.get_model_schema( model_type=ModelType.LLM, model=model, credentials=credentials ) - return model_schema.parameter_rules if model_schema else [] def get_default_model_of_model_type(self, tenant_id: str, model_type: str) -> Optional[DefaultModelResponse]: """ get default model of model type. - :param tenant_id: workspace id :param model_type: model type :return: """ model_type_enum = ModelType.value_of(model_type) - try: result = self.provider_manager.get_default_model(tenant_id=tenant_id, model_type=model_type_enum) return ( @@ -386,7 +339,6 @@ class ModelProviderService: def update_default_model_of_model_type(self, tenant_id: str, model_type: str, provider: str, model: str) -> None: """ update default model of model type. - :param tenant_id: workspace id :param model_type: model type :param provider: provider name @@ -403,7 +355,6 @@ class ModelProviderService: ) -> tuple[Optional[bytes], Optional[str]]: """ get model provider icon. 
- :param tenant_id: workspace id :param provider: provider name :param icon_type: icon type (icon_small or icon_large) @@ -412,13 +363,11 @@ class ModelProviderService: """ model_provider_factory = ModelProviderFactory(tenant_id) byte_data, mime_type = model_provider_factory.get_provider_icon(provider, icon_type, lang) - return byte_data, mime_type def switch_preferred_provider(self, tenant_id: str, provider: str, preferred_provider_type: str) -> None: """ switch preferred provider. - :param tenant_id: workspace id :param provider: provider name :param preferred_provider_type: preferred provider type @@ -426,22 +375,18 @@ class ModelProviderService: """ # Get all provider configurations of the current workspace provider_configurations = self.provider_manager.get_configurations(tenant_id) - # Convert preferred_provider_type to ProviderType preferred_provider_type_enum = ProviderType.value_of(preferred_provider_type) - # Get provider configuration provider_configuration = provider_configurations.get(provider) if not provider_configuration: raise ValueError(f"Provider {provider} does not exist.") - # Switch preferred provider type provider_configuration.switch_preferred_provider_type(preferred_provider_type_enum) def enable_model(self, tenant_id: str, provider: str, model: str, model_type: str) -> None: """ enable model. 
- :param tenant_id: workspace id :param provider: provider name :param model: model name @@ -450,19 +395,16 @@ class ModelProviderService: """ # Get all provider configurations of the current workspace provider_configurations = self.provider_manager.get_configurations(tenant_id) - # Get provider configuration provider_configuration = provider_configurations.get(provider) if not provider_configuration: raise ValueError(f"Provider {provider} does not exist.") - # Enable model provider_configuration.enable_model(model=model, model_type=ModelType.value_of(model_type)) def disable_model(self, tenant_id: str, provider: str, model: str, model_type: str) -> None: """ disable model. - :param tenant_id: workspace id :param provider: provider name :param model: model name @@ -471,11 +413,9 @@ class ModelProviderService: """ # Get all provider configurations of the current workspace provider_configurations = self.provider_manager.get_configurations(tenant_id) - # Get provider configuration provider_configuration = provider_configurations.get(provider) if not provider_configuration: raise ValueError(f"Provider {provider} does not exist.") - # Enable model provider_configuration.disable_model(model=model, model_type=ModelType.value_of(model_type)) diff --git a/api/services/operation_service.py b/api/services/operation_service.py index 8c8b64bcd5..005c762033 100644 --- a/api/services/operation_service.py +++ b/api/services/operation_service.py @@ -10,10 +10,8 @@ class OperationService: @classmethod def _send_request(cls, method, endpoint, json=None, params=None): headers = {"Content-Type": "application/json", "Billing-Api-Secret-Key": cls.secret_key} - url = f"{cls.base_url}{endpoint}" response = requests.request(method, url, json=json, params=params, headers=headers) - return response.json() @classmethod diff --git a/api/services/ops_service.py b/api/services/ops_service.py index c88accb9a5..7bd2b6f3f0 100644 --- a/api/services/ops_service.py +++ b/api/services/ops_service.py @@ 
-20,10 +20,8 @@ class OpsService: .filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) .first() ) - if not trace_config_data: return None - # decrypt_token and obfuscated_token app = db.session.query(App).filter(App.id == app_id).first() if not app: @@ -33,7 +31,6 @@ class OpsService: tenant_id, tracing_provider, trace_config_data.tracing_config ) new_decrypt_tracing_config = OpsTraceManager.obfuscated_decrypt_token(tracing_provider, decrypt_tracing_config) - if tracing_provider == "arize" and ( "project_url" not in decrypt_tracing_config or not decrypt_tracing_config.get("project_url") ): @@ -42,7 +39,6 @@ class OpsService: new_decrypt_tracing_config.update({"project_url": project_url}) except Exception: new_decrypt_tracing_config.update({"project_url": "https://app.arize.com/"}) - if tracing_provider == "phoenix" and ( "project_url" not in decrypt_tracing_config or not decrypt_tracing_config.get("project_url") ): @@ -51,7 +47,6 @@ class OpsService: new_decrypt_tracing_config.update({"project_url": project_url}) except Exception: new_decrypt_tracing_config.update({"project_url": "https://app.phoenix.arize.com/projects/"}) - if tracing_provider == "langfuse" and ( "project_key" not in decrypt_tracing_config or not decrypt_tracing_config.get("project_key") ): @@ -68,7 +63,6 @@ class OpsService: new_decrypt_tracing_config.update( {"project_url": "{host}/".format(host=decrypt_tracing_config.get("host"))} ) - if tracing_provider == "langsmith" and ( "project_url" not in decrypt_tracing_config or not decrypt_tracing_config.get("project_url") ): @@ -77,7 +71,6 @@ class OpsService: new_decrypt_tracing_config.update({"project_url": project_url}) except Exception: new_decrypt_tracing_config.update({"project_url": "https://smith.langchain.com/"}) - if tracing_provider == "opik" and ( "project_url" not in decrypt_tracing_config or not decrypt_tracing_config.get("project_url") ): @@ -110,20 +103,16 @@ class OpsService: 
provider_config_map[tracing_provider] except KeyError: return {"error": f"Invalid tracing provider: {tracing_provider}"} - provider_config: dict[str, Any] = provider_config_map[tracing_provider] config_class: type[BaseTracingConfig] = provider_config["config_class"] other_keys: list[str] = provider_config["other_keys"] - default_config_instance: BaseTracingConfig = config_class(**tracing_config) for key in other_keys: if key in tracing_config and tracing_config[key] == "": tracing_config[key] = getattr(default_config_instance, key, None) - # api check if not OpsTraceManager.check_trace_config_is_effective(tracing_config, tracing_provider): return {"error": "Invalid Credentials"} - # get project url if tracing_provider in ("arize", "phoenix"): project_url = OpsTraceManager.get_trace_config_project_url(tracing_config, tracing_provider) @@ -134,17 +123,14 @@ class OpsService: project_url = OpsTraceManager.get_trace_config_project_url(tracing_config, tracing_provider) else: project_url = None - # check if trace config already exists trace_config_data: Optional[TraceAppConfig] = ( db.session.query(TraceAppConfig) .filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) .first() ) - if trace_config_data: return None - # get tenant id app = db.session.query(App).filter(App.id == app_id).first() if not app: @@ -160,7 +146,6 @@ class OpsService: ) db.session.add(trace_config_data) db.session.commit() - return {"result": "success"} @classmethod @@ -176,17 +161,14 @@ class OpsService: provider_config_map[tracing_provider] except KeyError: raise ValueError(f"Invalid tracing provider: {tracing_provider}") - # check if trace config already exists current_trace_config = ( db.session.query(TraceAppConfig) .filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) .first() ) - if not current_trace_config: return None - # get tenant id app = db.session.query(App).filter(App.id == app_id).first() if not app: @@ 
-195,16 +177,13 @@ class OpsService: tracing_config = OpsTraceManager.encrypt_tracing_config( tenant_id, tracing_provider, tracing_config, current_trace_config.tracing_config ) - # api check # decrypt_token decrypt_tracing_config = OpsTraceManager.decrypt_tracing_config(tenant_id, tracing_provider, tracing_config) if not OpsTraceManager.check_trace_config_is_effective(decrypt_tracing_config, tracing_provider): raise ValueError("Invalid Credentials") - current_trace_config.tracing_config = tracing_config db.session.commit() - return current_trace_config.to_dict() @classmethod @@ -220,11 +199,8 @@ class OpsService: .filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) .first() ) - if not trace_config: return None - db.session.delete(trace_config) db.session.commit() - return True diff --git a/api/services/plugin/data_migration.py b/api/services/plugin/data_migration.py index 5324036414..54716e66e1 100644 --- a/api/services/plugin/data_migration.py +++ b/api/services/plugin/data_migration.py @@ -28,9 +28,7 @@ class PluginDataMigration: def migrate_datasets(cls) -> None: table_name = "datasets" provider_column_name = "embedding_model_provider" - click.echo(click.style(f"Migrating [{table_name}] data for plugin", fg="white")) - processed_count = 0 failed_ids = [] while True: @@ -39,17 +37,14 @@ where {provider_column_name} not like '%/%' and {provider_column_name} is not nu limit 1000""" with db.engine.begin() as conn: rs = conn.execute(db.text(sql)) - current_iter_count = 0 for i in rs: record_id = str(i.id) provider_name = str(i.provider_name) retrieval_model = i.retrieval_model print(type(retrieval_model)) - if record_id in failed_ids: continue - retrieval_model_changed = False if retrieval_model: if ( @@ -71,14 +66,12 @@ limit 1000""" retrieval_model["reranking_model"]["reranking_provider_name"] ).to_string() retrieval_model_changed = True - click.echo( click.style( f"[{processed_count}] Migrating [{table_name}] {record_id} 
({provider_name})", fg="white", ) ) - try: # update provider name append with "langgenius/{provider_name}/{provider_name}" params = {"record_id": record_id} @@ -86,9 +79,7 @@ limit 1000""" if retrieval_model and retrieval_model_changed: update_retrieval_model_sql = ", retrieval_model = :retrieval_model" params["retrieval_model"] = json.dumps(retrieval_model) - params["provider_name"] = ModelProviderID(provider_name).to_string() - sql = f"""update {table_name} set {provider_column_name} = :provider_name @@ -113,13 +104,10 @@ limit 1000""" f"[{processed_count}] Failed to migrate [{table_name}] {record_id} ({provider_name})" ) continue - current_iter_count += 1 processed_count += 1 - if not current_iter_count: break - click.echo( click.style(f"Migrate [{table_name}] data for plugin completed, total: {processed_count}", fg="green") ) @@ -129,11 +117,9 @@ limit 1000""" cls, table_name: str, provider_column_name: str, provider_cls: type[GenericProviderID] ) -> None: click.echo(click.style(f"Migrating [{table_name}] data for plugin", fg="white")) - processed_count = 0 failed_ids = [] last_id = "00000000-0000-0000-0000-000000000000" - while True: sql = f""" SELECT id, {provider_column_name} AS provider_name @@ -146,30 +132,24 @@ limit 1000""" LIMIT 5000 """ params = {"last_id": last_id or ""} - with db.engine.begin() as conn: rs = conn.execute(db.text(sql), params) - current_iter_count = 0 batch_updates = [] - for i in rs: current_iter_count += 1 processed_count += 1 record_id = str(i.id) last_id = record_id provider_name = str(i.provider_name) - if record_id in failed_ids: continue - click.echo( click.style( f"[{processed_count}] Migrating [{table_name}] {record_id} ({provider_name})", fg="white", ) ) - try: # update jina to langgenius/jina_tool/jina etc. 
updated_value = provider_cls(provider_name).to_string() @@ -186,7 +166,6 @@ limit 1000""" f"[{processed_count}] Failed to migrate [{table_name}] {record_id} ({provider_name})" ) continue - if batch_updates: update_sql = f""" UPDATE {table_name} @@ -200,10 +179,8 @@ limit 1000""" fg="green", ) ) - if not current_iter_count: break - click.echo( click.style(f"Migrate [{table_name}] data for plugin completed, total: {processed_count}", fg="green") ) diff --git a/api/services/plugin/dependencies_analysis.py b/api/services/plugin/dependencies_analysis.py index 830d3a4769..d32f09628e 100644 --- a/api/services/plugin/dependencies_analysis.py +++ b/api/services/plugin/dependencies_analysis.py @@ -9,7 +9,6 @@ class DependenciesAnalysisService: def analyze_tool_dependency(cls, tool_id: str) -> str: """ Analyze the dependency of a tool. - Convert the tool id to the plugin_id """ try: @@ -21,7 +20,6 @@ class DependenciesAnalysisService: def analyze_model_provider_dependency(cls, model_provider_id: str) -> str: """ Analyze the dependency of a model provider. 
- Convert the model provider id to the plugin_id """ try: @@ -37,13 +35,10 @@ class DependenciesAnalysisService: required_plugin_unique_identifiers = [] for dependency in dependencies: required_plugin_unique_identifiers.append(dependency.value.plugin_unique_identifier) - manager = PluginInstaller() - # get leaked dependencies missing_plugins = manager.fetch_missing_dependencies(tenant_id, required_plugin_unique_identifiers) missing_plugin_unique_identifiers = {plugin.plugin_unique_identifier: plugin for plugin in missing_plugins} - leaked_dependencies = [] for dependency in dependencies: unique_identifier = dependency.value.plugin_unique_identifier @@ -55,7 +50,6 @@ class DependenciesAnalysisService: current_identifier=missing_plugin_unique_identifiers[unique_identifier].current_identifier, ) ) - return leaked_dependencies @classmethod @@ -103,7 +97,6 @@ class DependenciesAnalysisService: ) else: raise ValueError(f"Unknown plugin source: {plugin.source}") - return result @classmethod diff --git a/api/services/plugin/oauth_service.py b/api/services/plugin/oauth_service.py index b84dd0afc5..3fab56e4dc 100644 --- a/api/services/plugin/oauth_service.py +++ b/api/services/plugin/oauth_service.py @@ -14,11 +14,9 @@ class OAuthProxyService(BasePluginClient): def create_proxy_context(user_id: str, tenant_id: str, plugin_id: str, provider: str): """ Create a proxy context for an OAuth 2.0 authorization request. - This parameter is a crucial security measure to prevent Cross-Site Request Forgery (CSRF) attacks. It works by generating a unique nonce and storing it in a distributed cache (Redis) along with the user's session context. - The returned nonce should be included as the 'proxy_context' parameter in the authorization URL. 
Upon callback, the `use_proxy_context` method is used to verify the state, ensuring the request's integrity and authenticity, diff --git a/api/services/plugin/plugin_migration.py b/api/services/plugin/plugin_migration.py index dbaaa7160e..955439471c 100644 --- a/api/services/plugin/plugin_migration.py +++ b/api/services/plugin/plugin_migration.py @@ -26,7 +26,6 @@ from models.tools import BuiltinToolProvider from models.workflow import Workflow logger = logging.getLogger(__name__) - excluded_providers = ["time", "audio", "code", "webscraper"] @@ -42,16 +41,12 @@ class PluginMigration: ended_at = datetime.datetime.now() started_at = datetime.datetime(2023, 4, 3, 8, 59, 24) current_time = started_at - with Session(db.engine) as session: total_tenant_count = session.query(Tenant.id).count() - click.echo(click.style(f"Total tenant count: {total_tenant_count}", fg="white")) - handled_tenant_count = 0 file_lock = Lock() counter_lock = Lock() - thread_pool = ThreadPoolExecutor(max_workers=workers) def process_tenant(flask_app: Flask, tenant_id: str) -> None: @@ -63,7 +58,6 @@ class PluginMigration: with file_lock: with open(filepath, "a") as f: f.write(json.dumps({"tenant_id": tenant_id, "plugins": plugins}) + "\n") - # Use lock when updating counter with counter_lock: nonlocal handled_tenant_count @@ -81,7 +75,6 @@ class PluginMigration: logger.exception(f"Failed to process tenant {tenant_id}") futures = [] - while current_time < ended_at: click.echo(click.style(f"Current time: {current_time}, Started at: {datetime.datetime.now()}", fg="white")) # Initial interval of 1 day, will be dynamically adjusted based on tenant count @@ -97,7 +90,6 @@ class PluginMigration: datetime.timedelta(hours=3), datetime.timedelta(hours=1), ] - for test_interval in test_intervals: tenant_count = ( session.query(Tenant.id) @@ -110,7 +102,6 @@ class PluginMigration: else: # If all intervals have too many tenants, use minimum interval interval = datetime.timedelta(hours=1) - # Adjust interval 
to target ~100 tenants per batch if tenant_count > 0: # Scale interval based on ratio to target count @@ -121,15 +112,12 @@ class PluginMigration: interval * (100 / tenant_count), # Scale to target 100 ), ) - batch_end = min(current_time + interval, ended_at) - rs = ( session.query(Tenant.id) .filter(Tenant.created_at.between(current_time, batch_end)) .order_by(Tenant.created_at) ) - tenants = [] for row in rs: tenant_id = str(row.id) @@ -138,7 +126,6 @@ class PluginMigration: except Exception: logger.exception(f"Failed to process tenant {tenant_id}") continue - futures.append( thread_pool.submit( process_tenant, @@ -146,9 +133,7 @@ class PluginMigration: tenant_id, ) ) - current_time = batch_end - # wait for all threads to finish for future in futures: future.result() @@ -162,14 +147,12 @@ class PluginMigration: models = cls.extract_model_tables(tenant_id) workflows = cls.extract_workflow_tables(tenant_id) apps = cls.extract_app_tables(tenant_id) - return list({*tools, *models, *workflows, *apps}) @classmethod def extract_model_tables(cls, tenant_id: str) -> Sequence[str]: """ Extract model tables. - """ models: list[str] = [] table_pairs = [ @@ -181,13 +164,10 @@ class PluginMigration: ("provider_model_settings", "provider_name"), ("load_balancing_model_configs", "provider_name"), ] - for table, column in table_pairs: models.extend(cls.extract_model_table(tenant_id, table, column)) - # duplicate models models = list(set(models)) - return models @classmethod @@ -203,7 +183,6 @@ class PluginMigration: for row in rs: provider_name = str(row[0]) result.append(ModelProviderID(provider_name).plugin_id) - return result @classmethod @@ -216,7 +195,6 @@ class PluginMigration: result = [] for row in rs: result.append(ToolProviderID(row.provider).plugin_id) - return result @classmethod @@ -224,7 +202,6 @@ class PluginMigration: """ Extract workflow tables, only ToolNode is required. 
""" - with Session(db.engine) as session: rs = session.query(Workflow).filter(Workflow.tenant_id == tenant_id).all() result = [] @@ -232,7 +209,6 @@ class PluginMigration: graph = row.graph_dict # get nodes nodes = graph.get("nodes", []) - for node in nodes: data = node.get("data", {}) if data.get("type") == "tool": @@ -240,7 +216,6 @@ class PluginMigration: provider_type = data.get("provider_type") if provider_name not in excluded_providers and provider_type == ToolProviderType.BUILT_IN.value: result.append(ToolProviderID(provider_name).plugin_id) - return result @classmethod @@ -252,11 +227,9 @@ class PluginMigration: apps = session.query(App).filter(App.tenant_id == tenant_id).all() if not apps: return [] - agent_app_model_config_ids = [ app.app_model_config_id for app in apps if app.is_agent or app.mode == AppMode.AGENT_CHAT.value ] - rs = session.query(AppModelConfig).filter(AppModelConfig.id.in_(agent_app_model_config_ids)).all() result = [] for row in rs: @@ -271,11 +244,9 @@ class PluginMigration: and tool_entity.provider_id not in excluded_providers ): result.append(ToolProviderID(tool_entity.provider_id).plugin_id) - except Exception: logger.exception(f"Failed to process tool {tool}") continue - return result @classmethod @@ -286,7 +257,6 @@ class PluginMigration: plugin_manifest = marketplace.batch_fetch_plugin_manifests([plugin_id]) if not plugin_manifest: return None - return plugin_manifest[0].latest_package_identifier @classmethod @@ -323,7 +293,6 @@ class PluginMigration: with ThreadPoolExecutor(max_workers=10) as executor: list(tqdm.tqdm(executor.map(fetch_plugin, plugin_ids), total=len(plugin_ids))) - return {"plugins": plugins, "plugin_not_exist": plugin_not_exist} @classmethod @@ -332,17 +301,13 @@ class PluginMigration: Install plugins. 
""" manager = PluginInstaller() - plugins = cls.extract_unique_plugins(extracted_plugins) not_installed = [] plugin_install_failed = [] - # use a fake tenant id to install all the plugins fake_tenant_id = uuid4().hex logger.info(f"Installing {len(plugins['plugins'])} plugin instances for fake tenant {fake_tenant_id}") - thread_pool = ThreadPoolExecutor(max_workers=workers) - response = cls.handle_plugin_instance_install(fake_tenant_id, plugins["plugins"]) if response.get("failed"): plugin_install_failed.extend(response.get("failed", [])) @@ -389,27 +354,20 @@ class PluginMigration: unique_identifier = plugins.get(plugin_id) if unique_identifier: current_not_installed["plugin_not_exist"].append(plugin_id) - if current_not_installed["plugin_not_exist"]: not_installed.append(current_not_installed) - thread_pool.submit(install, tenant_id, plugin_ids) - thread_pool.shutdown(wait=True) - logger.info("Uninstall plugins") - # get installation try: installation = manager.list_plugins(fake_tenant_id) while installation: for plugin in installation: manager.uninstall(fake_tenant_id, plugin.installation_id) - installation = manager.list_plugins(fake_tenant_id) except Exception: logger.exception(f"Failed to get installation for tenant {fake_tenant_id}") - Path(output_file).write_text( json.dumps( { @@ -427,38 +385,30 @@ class PluginMigration: Install plugins for a tenant. 
""" manager = PluginInstaller() - # download all the plugins and upload thread_pool = ThreadPoolExecutor(max_workers=10) futures = [] - for plugin_id, plugin_identifier in plugin_identifiers_map.items(): def download_and_upload(tenant_id, plugin_id, plugin_identifier): plugin_package = marketplace.download_plugin_pkg(plugin_identifier) if not plugin_package: raise Exception(f"Failed to download plugin {plugin_identifier}") - # upload manager.upload_pkg(tenant_id, plugin_package, verify_signature=True) futures.append(thread_pool.submit(download_and_upload, tenant_id, plugin_id, plugin_identifier)) - # Wait for all downloads to complete for future in futures: future.result() # This will raise any exceptions that occurred - thread_pool.shutdown(wait=True) success = [] failed = [] - reverse_map = {v: k for k, v in plugin_identifiers_map.items()} - # at most 8 plugins one batch for i in range(0, len(plugin_identifiers_map), 8): batch_plugin_ids = list(plugin_identifiers_map.keys())[i : i + 8] batch_plugin_identifiers = [plugin_identifiers_map[plugin_id] for plugin_id in batch_plugin_ids] - try: response = manager.install_from_identifiers( tenant_id=tenant_id, @@ -475,11 +425,9 @@ class PluginMigration: # add to failed failed.extend(batch_plugin_identifiers) continue - if response.all_installed: success.extend(batch_plugin_identifiers) continue - task_id = response.task_id done = False while not done: @@ -493,9 +441,7 @@ class PluginMigration: logger.error( f"Failed to install plugin {plugin.plugin_unique_identifier}, error: {plugin.message}" ) - done = True else: time.sleep(1) - return {"success": success, "failed": failed} diff --git a/api/services/plugin/plugin_parameter_service.py b/api/services/plugin/plugin_parameter_service.py index 393213c0e2..ce0e6e5391 100644 --- a/api/services/plugin/plugin_parameter_service.py +++ b/api/services/plugin/plugin_parameter_service.py @@ -24,7 +24,6 @@ class PluginParameterService: ) -> Sequence[PluginParameterOption]: """ Get 
dynamic select options for a plugin parameter. - Args: tenant_id: The tenant ID. plugin_id: The plugin ID. @@ -33,7 +32,6 @@ class PluginParameterService: parameter: The parameter name. """ credentials: Mapping[str, Any] = {} - match provider_type: case "tool": provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) @@ -44,7 +42,6 @@ class PluginParameterService: provider_type=provider_controller.provider_type.value, provider_identity=provider_controller.entity.identity.name, ) - # check if credentials are required if not provider_controller.need_credentials: credentials = {} @@ -59,14 +56,11 @@ class PluginParameterService: ) .first() ) - if db_record is None: raise ValueError(f"Builtin provider {provider} not found when fetching credentials") - credentials = tool_configuration.decrypt(db_record.credentials) case _: raise ValueError(f"Invalid provider type: {provider_type}") - return ( DynamicSelectClient() .fetch_dynamic_select_options(tenant_id, user_id, plugin_id, provider, action, credentials, parameter) diff --git a/api/services/plugin/plugin_permission_service.py b/api/services/plugin/plugin_permission_service.py index 275e496037..e5de01a2ef 100644 --- a/api/services/plugin/plugin_permission_service.py +++ b/api/services/plugin/plugin_permission_service.py @@ -24,11 +24,9 @@ class PluginPermissionService: permission = TenantPluginPermission( tenant_id=tenant_id, install_permission=install_permission, debug_permission=debug_permission ) - session.add(permission) else: permission.install_permission = install_permission permission.debug_permission = debug_permission - session.commit() return True diff --git a/api/services/plugin/plugin_service.py b/api/services/plugin/plugin_service.py index d7fb4a7c1b..d6c1a33364 100644 --- a/api/services/plugin/plugin_service.py +++ b/api/services/plugin/plugin_service.py @@ -48,10 +48,8 @@ class PluginService: Fetch the latest plugin version """ result: dict[str, Optional[PluginService.LatestPluginCache]] 
= {} - try: cache_not_exists = [] - # Try to get from Redis first for plugin_id in plugin_ids: cached_data = redis_client.get(f"{PluginService.REDIS_KEY_PREFIX}{plugin_id}") @@ -59,35 +57,28 @@ class PluginService: result[plugin_id] = PluginService.LatestPluginCache.model_validate_json(cached_data) else: cache_not_exists.append(plugin_id) - if cache_not_exists: manifests = { manifest.plugin_id: manifest for manifest in marketplace.batch_fetch_plugin_manifests(cache_not_exists) } - for plugin_id, manifest in manifests.items(): latest_plugin = PluginService.LatestPluginCache( plugin_id=plugin_id, version=manifest.latest_version, unique_identifier=manifest.latest_package_identifier, ) - # Store in Redis redis_client.setex( f"{PluginService.REDIS_KEY_PREFIX}{plugin_id}", PluginService.REDIS_TTL, latest_plugin.model_dump_json(), ) - result[plugin_id] = latest_plugin - # pop plugin_id from cache_not_exists cache_not_exists.remove(plugin_id) - for plugin_id in cache_not_exists: result[plugin_id] = None - return result except Exception: logger.exception("failed to fetch latest plugin version") @@ -108,7 +99,6 @@ class PluginService: Check the plugin installation scope """ features = FeatureService.get_system_features() - match features.plugin_installation_permission.plugin_installation_scope: case PluginInstallationScope.OFFICIAL_ONLY: if ( @@ -244,15 +234,11 @@ class PluginService: """ if not dify_config.MARKETPLACE_ENABLED: raise ValueError("marketplace is not enabled") - if original_plugin_unique_identifier == new_plugin_unique_identifier: raise ValueError("you should not upgrade plugin with the same plugin") - # check if plugin pkg is already downloaded manager = PluginInstaller() - features = FeatureService.get_system_features() - try: manager.fetch_plugin_manifest(tenant_id, new_plugin_unique_identifier) # already downloaded, skip, and record install event @@ -265,10 +251,8 @@ class PluginService: pkg, 
verify_signature=features.plugin_installation_permission.restrict_to_marketplace_only, ) - # check if the plugin is available to install PluginService._check_plugin_installation_scope(response.verification) - return manager.upgrade_plugin( tenant_id, original_plugin_unique_identifier, @@ -309,7 +293,6 @@ class PluginService: def upload_pkg(tenant_id: str, pkg: bytes, verify_signature: bool = False) -> PluginDecodeResponse: """ Upload plugin package files - returns: plugin_unique_identifier """ PluginService._check_marketplace_only_permission() @@ -335,7 +318,6 @@ class PluginService: f"https://github.com/{repo}/releases/download/{version}/{package}", dify_config.PLUGIN_MAX_PACKAGE_SIZE ) features = FeatureService.get_system_features() - manager = PluginInstaller() response = manager.upload_pkg( tenant_id, @@ -358,9 +340,7 @@ class PluginService: @staticmethod def install_from_local_pkg(tenant_id: str, plugin_unique_identifiers: Sequence[str]): PluginService._check_marketplace_only_permission() - manager = PluginInstaller() - return manager.install_from_identifiers( tenant_id, plugin_unique_identifiers, @@ -375,7 +355,6 @@ class PluginService: returns plugin_unique_identifier """ PluginService._check_marketplace_only_permission() - manager = PluginInstaller() return manager.install_from_identifiers( tenant_id, @@ -397,9 +376,7 @@ class PluginService: """ if not dify_config.MARKETPLACE_ENABLED: raise ValueError("marketplace is not enabled") - features = FeatureService.get_system_features() - manager = PluginInstaller() try: declaration = manager.fetch_plugin_manifest(tenant_id, plugin_unique_identifier) @@ -413,7 +390,6 @@ class PluginService: # check if the plugin is available to install PluginService._check_plugin_installation_scope(response.verification) declaration = response.manifest - return declaration @staticmethod @@ -424,11 +400,8 @@ class PluginService: """ if not dify_config.MARKETPLACE_ENABLED: raise ValueError("marketplace is not enabled") - manager = 
PluginInstaller() - features = FeatureService.get_system_features() - # check if already downloaded for plugin_unique_identifier in plugin_unique_identifiers: try: @@ -447,7 +420,6 @@ class PluginService: ) # check if the plugin is available to install PluginService._check_plugin_installation_scope(response.verification) - return manager.install_from_identifiers( tenant_id, plugin_unique_identifiers, diff --git a/api/services/recommend_app/buildin/buildin_retrieval.py b/api/services/recommend_app/buildin/buildin_retrieval.py index 523aebeed5..23e888e608 100644 --- a/api/services/recommend_app/buildin/buildin_retrieval.py +++ b/api/services/recommend_app/buildin/buildin_retrieval.py @@ -35,12 +35,10 @@ class BuildInRecommendAppRetrieval(RecommendAppRetrievalBase): """ if cls.builtin_data: return cls.builtin_data - root_path = current_app.root_path cls.builtin_data = json.loads( Path(path.join(root_path, "constants", "recommended_apps.json")).read_text(encoding="utf-8") ) - return cls.builtin_data or {} @classmethod diff --git a/api/services/recommend_app/database/database_retrieval.py b/api/services/recommend_app/database/database_retrieval.py index 3295516cce..ab433e3d33 100644 --- a/api/services/recommend_app/database/database_retrieval.py +++ b/api/services/recommend_app/database/database_retrieval.py @@ -36,25 +36,21 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase): .filter(RecommendedApp.is_listed == True, RecommendedApp.language == language) .all() ) - if len(recommended_apps) == 0: recommended_apps = ( db.session.query(RecommendedApp) .filter(RecommendedApp.is_listed == True, RecommendedApp.language == languages[0]) .all() ) - categories = set() recommended_apps_result = [] for recommended_app in recommended_apps: app = recommended_app.app if not app or not app.is_public: continue - site = app.site if not site: continue - recommended_app_result = { "id": recommended_app.id, "app": recommended_app.app, @@ -68,9 +64,7 @@ class 
DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase): "is_listed": recommended_app.is_listed, } recommended_apps_result.append(recommended_app_result) - categories.add(recommended_app.category) - return {"recommended_apps": recommended_apps_result, "categories": sorted(categories)} @classmethod @@ -86,15 +80,12 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase): .filter(RecommendedApp.is_listed == True, RecommendedApp.app_id == app_id) .first() ) - if not recommended_app: return None - # get app detail app_model = db.session.query(App).filter(App.id == app_id).first() if not app_model or not app_model.is_public: return None - return { "id": app_model.id, "name": app_model.name, diff --git a/api/services/recommend_app/remote/remote_retrieval.py b/api/services/recommend_app/remote/remote_retrieval.py index 80e1aefc01..4001982e9f 100644 --- a/api/services/recommend_app/remote/remote_retrieval.py +++ b/api/services/recommend_app/remote/remote_retrieval.py @@ -62,10 +62,7 @@ class RemoteRecommendAppRetrieval(RecommendAppRetrievalBase): response = requests.get(url, timeout=(3, 10)) if response.status_code != 200: raise ValueError(f"fetch recommended apps failed, status code: {response.status_code}") - result: dict = response.json() - if "categories" in result: result["categories"] = sorted(result["categories"]) - return result diff --git a/api/services/recommended_app_service.py b/api/services/recommended_app_service.py index 54c5845515..514a7a5bcb 100644 --- a/api/services/recommended_app_service.py +++ b/api/services/recommended_app_service.py @@ -21,7 +21,6 @@ class RecommendedAppService: "en-US" ) ) - return result @classmethod diff --git a/api/services/saved_message_service.py b/api/services/saved_message_service.py index 4cb8700117..82f390ef6a 100644 --- a/api/services/saved_message_service.py +++ b/api/services/saved_message_service.py @@ -26,7 +26,6 @@ class SavedMessageService: .all() ) message_ids = [sm.message_id for sm in saved_messages] - 
return MessageService.pagination_by_last_id( app_model=app_model, user=user, last_id=last_id, limit=limit, include_ids=message_ids ) @@ -45,19 +44,15 @@ class SavedMessageService: ) .first() ) - if saved_message: return - message = MessageService.get_message(app_model=app_model, user=user, message_id=message_id) - saved_message = SavedMessage( app_id=app_model.id, message_id=message.id, created_by_role="account" if isinstance(user, Account) else "end_user", created_by=user.id, ) - db.session.add(saved_message) db.session.commit() @@ -75,9 +70,7 @@ class SavedMessageService: ) .first() ) - if not saved_message: return - db.session.delete(saved_message) db.session.commit() diff --git a/api/services/tag_service.py b/api/services/tag_service.py index 74c6150b44..3562294105 100644 --- a/api/services/tag_service.py +++ b/api/services/tag_service.py @@ -70,7 +70,6 @@ class TagService: ) .all() ) - return tags or [] @staticmethod diff --git a/api/services/tools/api_tools_manage_service.py b/api/services/tools/api_tools_manage_service.py index 6f848d49c4..1a2c11de2b 100644 --- a/api/services/tools/api_tools_manage_service.py +++ b/api/services/tools/api_tools_manage_service.py @@ -39,7 +39,6 @@ class ApiToolManageService: tool_bundles, schema_type = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema, warning=warnings) except Exception as e: raise ValueError(f"invalid schema: {str(e)}") - credentials_schema = [ ProviderConfig( name="auth_type", @@ -68,7 +67,6 @@ class ApiToolManageService: default="", ), ] - return cast( Mapping, jsonable_encoder( @@ -87,7 +85,6 @@ class ApiToolManageService: def convert_schema_to_tool_bundles(schema: str, extra_info: dict | None = None) -> tuple[list[ApiToolBundle], str]: """ convert schema to tool bundles - :return: the list of tool bundles, description """ try: @@ -113,9 +110,7 @@ class ApiToolManageService: """ if schema_type not in [member.value for member in ApiProviderSchemaType]: raise ValueError(f"invalid schema type 
{schema}") - provider_name = provider_name.strip() - # check if the provider exists provider = ( db.session.query(ApiToolProvider) @@ -125,18 +120,14 @@ class ApiToolManageService: ) .first() ) - if provider is not None: raise ValueError(f"provider {provider_name} already exists") - # parse openapi to tool bundle extra_info: dict[str, str] = {} # extra info like description will be set here tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info) - if len(tool_bundles) > 100: raise ValueError("the number of apis should be less than 100") - # create db provider db_provider = ApiToolProvider( tenant_id=tenant_id, @@ -151,18 +142,14 @@ class ApiToolManageService: privacy_policy=privacy_policy, custom_disclaimer=custom_disclaimer, ) - if "auth_type" not in credentials: raise ValueError("auth_type is required") - # get auth type, none or api key auth_type = ApiProviderAuthType.value_of(credentials["auth_type"]) - # create provider entity provider_controller = ApiToolProviderController.from_db(db_provider, auth_type) # load tools into provider entity provider_controller.load_bundled_tools(tool_bundles) - # encrypt credentials tool_configuration = ProviderConfigEncrypter( tenant_id=tenant_id, @@ -170,16 +157,12 @@ class ApiToolManageService: provider_type=provider_controller.provider_type.value, provider_identity=provider_controller.entity.identity.name, ) - encrypted_credentials = tool_configuration.encrypt(credentials) db_provider.credentials_str = json.dumps(encrypted_credentials) - db.session.add(db_provider) db.session.commit() - # update labels ToolLabelManager.update_tool_labels(provider_controller, labels) - return {"result": "success"} @staticmethod @@ -192,19 +175,16 @@ class ApiToolManageService: " Chrome/120.0.0.0 Safari/537.36 Edg/120.0.0.0", "Accept": "*/*", } - try: response = get(url, headers=headers, timeout=10) if response.status_code != 200: raise ValueError(f"Got status code {response.status_code}") schema = 
response.text - # try to parse schema, avoid SSRF attack ApiToolManageService.parser_api_schema(schema) except Exception: logger.exception("parse api schema error") raise ValueError("invalid schema, please check the url you provided") - return {"schema": schema} @staticmethod @@ -220,13 +200,10 @@ class ApiToolManageService: ) .first() ) - if provider is None: raise ValueError(f"you have not added provider {provider_name}") - controller = ToolTransformService.api_provider_to_controller(db_provider=provider) labels = ToolLabelManager.get_tool_labels(controller) - return [ ToolTransformService.convert_tool_entity_to_api_entity( tool_bundle, @@ -255,9 +232,7 @@ class ApiToolManageService: """ if schema_type not in [member.value for member in ApiProviderSchemaType]: raise ValueError(f"invalid schema type {schema}") - provider_name = provider_name.strip() - # check if the provider exists provider = ( db.session.query(ApiToolProvider) @@ -267,14 +242,12 @@ class ApiToolManageService: ) .first() ) - if provider is None: raise ValueError(f"api provider {provider_name} does not exists") # parse openapi to tool bundle extra_info: dict[str, str] = {} # extra info like description will be set here tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info) - # update db provider provider.name = provider_name provider.icon = json.dumps(icon) @@ -284,18 +257,14 @@ class ApiToolManageService: provider.tools_str = json.dumps(jsonable_encoder(tool_bundles)) provider.privacy_policy = privacy_policy provider.custom_disclaimer = custom_disclaimer - if "auth_type" not in credentials: raise ValueError("auth_type is required") - # get auth type, none or api key auth_type = ApiProviderAuthType.value_of(credentials["auth_type"]) - # create provider entity provider_controller = ApiToolProviderController.from_db(provider, auth_type) # load tools into provider entity provider_controller.load_bundled_tools(tool_bundles) - # get original credentials if 
exists tool_configuration = ProviderConfigEncrypter( tenant_id=tenant_id, @@ -303,26 +272,20 @@ class ApiToolManageService: provider_type=provider_controller.provider_type.value, provider_identity=provider_controller.entity.identity.name, ) - original_credentials = tool_configuration.decrypt(provider.credentials) masked_credentials = tool_configuration.mask_tool_credentials(original_credentials) # check if the credential has changed, save the original credential for name, value in credentials.items(): if name in masked_credentials and value == masked_credentials[name]: credentials[name] = original_credentials[name] - credentials = tool_configuration.encrypt(credentials) provider.credentials_str = json.dumps(credentials) - db.session.add(provider) db.session.commit() - # delete cache tool_configuration.delete_tool_credentials_cache() - # update labels ToolLabelManager.update_tool_labels(provider_controller, labels) - return {"result": "success"} @staticmethod @@ -338,13 +301,10 @@ class ApiToolManageService: ) .first() ) - if provider is None: raise ValueError(f"you have not added provider {provider_name}") - db.session.delete(provider) db.session.commit() - return {"result": "success"} @staticmethod @@ -369,17 +329,14 @@ class ApiToolManageService: """ if schema_type not in [member.value for member in ApiProviderSchemaType]: raise ValueError(f"invalid schema type {schema_type}") - try: tool_bundles, _ = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema) except Exception: raise ValueError("invalid schema") - # get tool bundle tool_bundle = next(filter(lambda tb: tb.operation_id == tool_name, tool_bundles), None) if tool_bundle is None: raise ValueError(f"invalid tool name {tool_name}") - db_provider = ( db.session.query(ApiToolProvider) .filter( @@ -388,7 +345,6 @@ class ApiToolManageService: ) .first() ) - if not db_provider: # create a fake db provider db_provider = ApiToolProvider( @@ -402,18 +358,14 @@ class ApiToolManageService: 
tools_str=json.dumps(jsonable_encoder(tool_bundles)), credentials_str=json.dumps(credentials), ) - if "auth_type" not in credentials: raise ValueError("auth_type is required") - # get auth type, none or api key auth_type = ApiProviderAuthType.value_of(credentials["auth_type"]) - # create provider entity provider_controller = ApiToolProviderController.from_db(db_provider, auth_type) # load tools into provider entity provider_controller.load_bundled_tools(tool_bundles) - # decrypt credentials if db_provider.id: tool_configuration = ProviderConfigEncrypter( @@ -428,7 +380,6 @@ class ApiToolManageService: for name, value in credentials.items(): if name in masked_credentials and value == masked_credentials[name]: credentials[name] = decrypted_credentials[name] - try: provider_controller.validate_credentials_format(credentials) # get tool @@ -442,7 +393,6 @@ class ApiToolManageService: result = tool.validate_credentials(credentials, parameters) except Exception as e: return {"error": str(e)} - return {"result": result or "empty response"} @staticmethod @@ -454,9 +404,7 @@ class ApiToolManageService: db_providers: list[ApiToolProvider] = ( db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all() or [] ) - result: list[ToolProviderApiEntity] = [] - for provider in db_providers: # convert provider controller to user provider provider_controller = ToolTransformService.api_provider_to_controller(db_provider=provider) @@ -465,19 +413,14 @@ class ApiToolManageService: provider_controller, db_provider=provider, decrypt_credentials=True ) user_provider.labels = labels - # add icon ToolTransformService.repack_provider(tenant_id=tenant_id, provider=user_provider) - tools = provider_controller.get_tools(tenant_id=tenant_id) - for tool in tools or []: user_provider.tools.append( ToolTransformService.convert_tool_entity_to_api_entity( tenant_id=tenant_id, tool=tool, credentials=user_provider.original_credentials, labels=labels ) ) - 
result.append(user_provider) - return result diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index 58a4b2f179..edcfba0c4a 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -27,15 +27,12 @@ class BuiltinToolManageService: def list_builtin_tool_provider_tools(tenant_id: str, provider: str) -> list[ToolApiEntity]: """ list builtin tool provider tools - :param tenant_id: the id of the tenant :param provider: the name of the provider - :return: the list of tools """ provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) tools = provider_controller.get_tools() - tool_provider_configurations = ProviderConfigEncrypter( tenant_id=tenant_id, config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], @@ -44,13 +41,11 @@ class BuiltinToolManageService: ) # check if user has added the provider builtin_provider = BuiltinToolManageService._fetch_builtin_provider(provider, tenant_id) - credentials = {} if builtin_provider is not None: # get credentials credentials = builtin_provider.credentials credentials = tool_provider_configurations.decrypt(credentials) - result: list[ToolApiEntity] = [] for tool in tools or []: result.append( @@ -61,7 +56,6 @@ class BuiltinToolManageService: labels=ToolLabelManager.get_tool_labels(provider_controller), ) ) - return result @staticmethod @@ -78,28 +72,23 @@ class BuiltinToolManageService: ) # check if user has added the provider builtin_provider = BuiltinToolManageService._fetch_builtin_provider(provider, tenant_id) - credentials = {} if builtin_provider is not None: # get credentials credentials = builtin_provider.credentials credentials = tool_provider_configurations.decrypt(credentials) - entity = ToolTransformService.builtin_provider_to_user_provider( provider_controller=provider_controller, db_provider=builtin_provider, decrypt_credentials=True, 
) - entity.original_credentials = {} - return entity @staticmethod def list_builtin_provider_credentials_schema(provider_name: str, tenant_id: str): """ list builtin provider credentials schema - :param provider_name: the name of the provider :param tenant_id: the id of the tenant :return: the list of tool providers @@ -116,7 +105,6 @@ class BuiltinToolManageService: """ # get if the provider exists provider = BuiltinToolManageService._fetch_builtin_provider(provider_name, tenant_id) - try: # get provider provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id) @@ -128,7 +116,6 @@ class BuiltinToolManageService: provider_type=provider_controller.provider_type.value, provider_identity=provider_controller.entity.identity.name, ) - # get original credentials if exists if provider is not None: original_credentials = tool_configuration.decrypt(provider.credentials) @@ -148,7 +135,6 @@ class BuiltinToolManageService: ToolProviderCredentialValidationError, ) as e: raise ValueError(str(e)) - if provider is None: # create provider provider = BuiltinToolProvider( @@ -157,14 +143,11 @@ class BuiltinToolManageService: provider=provider_name, encrypted_credentials=json.dumps(credentials), ) - db.session.add(provider) else: provider.encrypted_credentials = json.dumps(credentials) - # delete cache tool_configuration.delete_tool_credentials_cache() - db.session.commit() return {"result": "success"} @@ -174,10 +157,8 @@ class BuiltinToolManageService: get builtin tool provider credentials """ provider_obj = BuiltinToolManageService._fetch_builtin_provider(provider_name, tenant_id) - if provider_obj is None: return {} - provider_controller = ToolManager.get_builtin_provider(provider_obj.provider, tenant_id) tool_configuration = ProviderConfigEncrypter( tenant_id=tenant_id, @@ -195,13 +176,10 @@ class BuiltinToolManageService: delete tool provider """ provider_obj = BuiltinToolManageService._fetch_builtin_provider(provider_name, tenant_id) - if provider_obj is 
None: raise ValueError(f"you have not added provider {provider_name}") - db.session.delete(provider_obj) db.session.commit() - # delete cache provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id) tool_configuration = ProviderConfigEncrypter( @@ -211,7 +189,6 @@ class BuiltinToolManageService: provider_identity=provider_controller.entity.identity.name, ) tool_configuration.delete_tool_credentials_cache() - return {"result": "success"} @staticmethod @@ -221,7 +198,6 @@ class BuiltinToolManageService: """ icon_path, mime_type = ToolManager.get_hardcoded_provider_icon(provider) icon_bytes = Path(icon_path).read_bytes() - return icon_bytes, mime_type @staticmethod @@ -231,13 +207,11 @@ class BuiltinToolManageService: """ # get all builtin providers provider_controllers = ToolManager.list_builtin_providers(tenant_id) - with db.session.no_autoflush: # get all user added providers db_providers: list[BuiltinToolProvider] = ( db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all() or [] ) - # rewrite db_providers for db_provider in db_providers: db_provider.provider = str(ToolProviderID(db_provider.provider)) @@ -247,7 +221,6 @@ class BuiltinToolManageService: return next(filter(lambda db_provider: db_provider.provider == provider, db_providers), None) result: list[ToolProviderApiEntity] = [] - for provider_controller in provider_controllers: try: # handle include, exclude @@ -258,17 +231,14 @@ class BuiltinToolManageService: name_func=lambda x: x.identity.name, ): continue - # convert provider controller to user provider user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider( provider_controller=provider_controller, db_provider=find_provider(provider_controller.entity.identity.name), decrypt_credentials=True, ) - # add icon ToolTransformService.repack_provider(tenant_id=tenant_id, provider=user_builtin_provider) - tools = provider_controller.get_tools() for tool in tools or []: 
user_builtin_provider.tools.append( @@ -279,11 +249,9 @@ class BuiltinToolManageService: labels=ToolLabelManager.get_tool_labels(provider_controller), ) ) - result.append(user_builtin_provider) except Exception as e: raise e - return BuiltinToolProviderSort.sort(result) @staticmethod @@ -311,10 +279,8 @@ class BuiltinToolManageService: ) .first() ) - if provider_obj is None: return None - provider_obj.provider = ToolProviderID(provider_obj.provider).to_string() return provider_obj except Exception: diff --git a/api/services/tools/tools_manage_service.py b/api/services/tools/tools_manage_service.py index 59d5b50e23..6c97a5b7eb 100644 --- a/api/services/tools/tools_manage_service.py +++ b/api/services/tools/tools_manage_service.py @@ -12,15 +12,11 @@ class ToolCommonService: def list_tool_providers(user_id: str, tenant_id: str, typ: ToolProviderTypeApiLiteral = None): """ list tool providers - :return: the list of tool providers """ providers = ToolManager.list_providers_from_api(user_id, tenant_id, typ) - # add icon for provider in providers: ToolTransformService.repack_provider(tenant_id=tenant_id, provider=provider) - result = [provider.to_dict() for provider in providers] - return result diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index 367121125b..b6b0c2796a 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -42,7 +42,6 @@ class ToolTransformService: url_prefix = ( URL(dify_config.CONSOLE_API_URL or "/") / "console" / "api" / "workspaces" / "current" / "tool-provider" ) - if provider_type == ToolProviderType.BUILT_IN.value: return str(url_prefix / "builtin" / provider_name / "icon") elif provider_type in {ToolProviderType.API.value, ToolProviderType.WORKFLOW.value}: @@ -52,14 +51,12 @@ class ToolTransformService: return icon except Exception: return {"background": "#252525", "content": "\ud83d\ude01"} - return "" @staticmethod def 
repack_provider(tenant_id: str, provider: Union[dict, ToolProviderApiEntity]): """ repack provider - :param tenant_id: the tenant id :param provider: the provider dict """ @@ -102,28 +99,22 @@ class ToolTransformService: tools=[], labels=provider_controller.tool_labels, ) - if isinstance(provider_controller, PluginToolProviderController): result.plugin_id = provider_controller.plugin_id result.plugin_unique_identifier = provider_controller.plugin_unique_identifier - # get credentials schema schema = {x.to_basic_provider_config().name: x for x in provider_controller.get_credentials_schema()} - for name, value in schema.items(): if result.masked_credentials: result.masked_credentials[name] = "" - # check if the provider need credentials if not provider_controller.need_credentials: result.is_team_authorization = True result.allow_delete = False elif db_provider: result.is_team_authorization = True - if decrypt_credentials: credentials = db_provider.credentials - # init tool configuration tool_configuration = ProviderConfigEncrypter( tenant_id=db_provider.tenant_id, @@ -134,10 +125,8 @@ class ToolTransformService: # decrypt the credentials and mask the credentials decrypted_credentials = tool_configuration.decrypt(data=credentials) masked_credentials = tool_configuration.mask_tool_credentials(data=decrypted_credentials) - result.masked_credentials = masked_credentials result.original_credentials = decrypted_credentials - return result @staticmethod @@ -154,7 +143,6 @@ class ToolTransformService: if db_provider.credentials["auth_type"] == "api_key" else ApiProviderAuthType.NONE, ) - return controller @staticmethod @@ -205,7 +193,6 @@ class ToolTransformService: user = db_provider.user if not user: raise ValueError("user not found") - username = user.name except Exception: logger.exception(f"failed to get user name for api provider {db_provider.id}") @@ -232,7 +219,6 @@ class ToolTransformService: tools=[], labels=labels or [], ) - if decrypt_credentials: # init tool 
configuration tool_configuration = ProviderConfigEncrypter( @@ -241,13 +227,10 @@ class ToolTransformService: provider_type=provider_controller.provider_type.value, provider_identity=provider_controller.entity.identity.name, ) - # decrypt the credentials and mask the credentials decrypted_credentials = tool_configuration.decrypt(data=credentials) masked_credentials = tool_configuration.mask_tool_credentials(data=decrypted_credentials) - result.masked_credentials = masked_credentials - return result @staticmethod @@ -268,7 +251,6 @@ class ToolTransformService: tenant_id=tenant_id, ) ) - # get tool parameters parameters = tool.entity.parameters or [] # get tool runtime parameters @@ -282,10 +264,8 @@ class ToolTransformService: current_parameters[index] = runtime_parameter found = True break - if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM: current_parameters.append(runtime_parameter) - return ToolApiEntity( author=tool.entity.identity.author, name=tool.entity.identity.name, diff --git a/api/services/tools/workflow_tools_manage_service.py b/api/services/tools/workflow_tools_manage_service.py index c6b205557a..00aca08a9c 100644 --- a/api/services/tools/workflow_tools_manage_service.py +++ b/api/services/tools/workflow_tools_manage_service.py @@ -39,7 +39,6 @@ class WorkflowToolManageService: labels: list[str] | None = None, ) -> dict: WorkflowToolConfigurationUtils.check_parameter_configurations(parameters) - # check if the name is unique existing_workflow_tool_provider = ( db.session.query(WorkflowToolProvider) @@ -50,19 +49,14 @@ class WorkflowToolManageService: ) .first() ) - if existing_workflow_tool_provider is not None: raise ValueError(f"Tool with name {name} or app_id {workflow_app_id} already exists") - app: App | None = db.session.query(App).filter(App.id == workflow_app_id, App.tenant_id == tenant_id).first() - if app is None: raise ValueError(f"App {workflow_app_id} not found") - workflow: Workflow | None = app.workflow if 
workflow is None: raise ValueError(f"Workflow not found for app {workflow_app_id}") - workflow_tool_provider = WorkflowToolProvider( tenant_id=tenant_id, user_id=user_id, @@ -75,15 +69,12 @@ class WorkflowToolManageService: privacy_policy=privacy_policy, version=workflow.version, ) - try: WorkflowToolProviderController.from_db(workflow_tool_provider) except Exception as e: raise ValueError(str(e)) - db.session.add(workflow_tool_provider) db.session.commit() - if labels is not None: ToolLabelManager.update_tool_labels( ToolTransformService.workflow_provider_to_controller(workflow_tool_provider), labels @@ -119,7 +110,6 @@ class WorkflowToolManageService: :return: the updated tool """ WorkflowToolConfigurationUtils.check_parameter_configurations(parameters) - # check if the name is unique existing_workflow_tool_provider = ( db.session.query(WorkflowToolProvider) @@ -130,30 +120,23 @@ class WorkflowToolManageService: ) .first() ) - if existing_workflow_tool_provider is not None: raise ValueError(f"Tool with name {name} already exists") - workflow_tool_provider: WorkflowToolProvider | None = ( db.session.query(WorkflowToolProvider) .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id) .first() ) - if workflow_tool_provider is None: raise ValueError(f"Tool {workflow_tool_id} not found") - app: App | None = ( db.session.query(App).filter(App.id == workflow_tool_provider.app_id, App.tenant_id == tenant_id).first() ) - if app is None: raise ValueError(f"App {workflow_tool_provider.app_id} not found") - workflow: Workflow | None = app.workflow if workflow is None: raise ValueError(f"Workflow not found for app {workflow_tool_provider.app_id}") - workflow_tool_provider.name = name workflow_tool_provider.label = label workflow_tool_provider.icon = json.dumps(icon) @@ -162,20 +145,16 @@ class WorkflowToolManageService: workflow_tool_provider.privacy_policy = privacy_policy workflow_tool_provider.version = workflow.version 
workflow_tool_provider.updated_at = datetime.now() - try: WorkflowToolProviderController.from_db(workflow_tool_provider) except Exception as e: raise ValueError(str(e)) - db.session.add(workflow_tool_provider) db.session.commit() - if labels is not None: ToolLabelManager.update_tool_labels( ToolTransformService.workflow_provider_to_controller(workflow_tool_provider), labels ) - return {"result": "success"} @classmethod @@ -187,7 +166,6 @@ class WorkflowToolManageService: :return: the list of tools """ db_tools = db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all() - tools: list[WorkflowToolProviderController] = [] for provider in db_tools: try: @@ -195,11 +173,8 @@ class WorkflowToolManageService: except Exception: # skip deleted tools pass - labels = ToolLabelManager.get_tools_labels([t for t in tools if isinstance(t, ToolProviderController)]) - result = [] - for tool in tools: user_tool_provider = ToolTransformService.workflow_provider_to_user_provider( provider_controller=tool, labels=labels.get(tool.provider_id, []) @@ -213,7 +188,6 @@ class WorkflowToolManageService: ) ] result.append(user_tool_provider) - return result @classmethod @@ -227,9 +201,7 @@ class WorkflowToolManageService: db.session.query(WorkflowToolProvider).filter( WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id ).delete() - db.session.commit() - return {"result": "success"} @classmethod @@ -273,23 +245,18 @@ class WorkflowToolManageService: """ if db_tool is None: raise ValueError("Tool not found") - workflow_app: App | None = ( db.session.query(App).filter(App.id == db_tool.app_id, App.tenant_id == db_tool.tenant_id).first() ) - if workflow_app is None: raise ValueError(f"App {db_tool.app_id} not found") - workflow = workflow_app.workflow if not workflow: raise ValueError("Workflow not found") - tool = ToolTransformService.workflow_provider_to_controller(db_tool) workflow_tools: list[WorkflowTool] = 
tool.get_tools(tenant_id) if len(workflow_tools) == 0: raise ValueError(f"Tool {db_tool.id} not found") - return { "name": db_tool.name, "label": db_tool.label, @@ -321,15 +288,12 @@ class WorkflowToolManageService: .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id) .first() ) - if db_tool is None: raise ValueError(f"Tool {workflow_tool_id} not found") - tool = ToolTransformService.workflow_provider_to_controller(db_tool) workflow_tools: list[WorkflowTool] = tool.get_tools(tenant_id) if len(workflow_tools) == 0: raise ValueError(f"Tool {workflow_tool_id} not found") - return [ ToolTransformService.convert_tool_entity_to_api_entity( tool=tool.get_tools(db_tool.tenant_id)[0], diff --git a/api/services/vector_service.py b/api/services/vector_service.py index 9165139193..da5a9f4d62 100644 --- a/api/services/vector_service.py +++ b/api/services/vector_service.py @@ -22,7 +22,6 @@ class VectorService: cls, keywords_list: Optional[list[list[str]]], segments: list[DocumentSegment], dataset: Dataset, doc_form: str ): documents: list[Document] = [] - for segment in segments: if doc_form == IndexType.PARENT_CHILD_INDEX: dataset_document = db.session.query(DatasetDocument).filter_by(id=segment.document_id).first() @@ -45,7 +44,6 @@ class VectorService: if dataset.indexing_technique == "high_quality": # check embedding model setting model_manager = ModelManager() - if dataset.embedding_model_provider: embedding_model_instance = model_manager.get_model_instance( tenant_id=dataset.tenant_id, @@ -81,7 +79,6 @@ class VectorService: @classmethod def update_segment_vector(cls, keywords: Optional[list[str]], segment: DocumentSegment, dataset: Dataset): # update segment index task - # format new index document = Document( page_content=segment.content, @@ -101,7 +98,6 @@ class VectorService: # update keyword index keyword = Keyword(dataset) keyword.delete_by_ids([segment.index_node_id]) - # save keyword index if keywords and len(keywords) > 
0: keyword.add_texts([document], keywords_list=[keywords]) @@ -122,7 +118,6 @@ class VectorService: if regenerate: # delete child chunks index_processor.clean(dataset, [segment.index_node_id], with_keywords=True, delete_child_chunks=True) - # generate child chunks document = Document( page_content=segment.content, @@ -146,7 +141,6 @@ class VectorService: # save child chunks if documents and documents[0].children: index_processor.load(dataset, documents) - for position, child_chunk in enumerate(documents[0].children, start=1): child_segment = ChildChunk( tenant_id=dataset.tenant_id, diff --git a/api/services/web_conversation_service.py b/api/services/web_conversation_service.py index f698ed3084..25ff658fef 100644 --- a/api/services/web_conversation_service.py +++ b/api/services/web_conversation_service.py @@ -41,12 +41,10 @@ class WebConversationService: .order_by(PinnedConversation.created_at.desc()) ) pinned_conversation_ids = session.scalars(stmt).all() - if pinned: include_ids = pinned_conversation_ids else: exclude_ids = pinned_conversation_ids - return ConversationService.pagination_by_last_id( session=session, app_model=app_model, @@ -73,21 +71,17 @@ class WebConversationService: ) .first() ) - if pinned_conversation: return - conversation = ConversationService.get_conversation( app_model=app_model, conversation_id=conversation_id, user=user ) - pinned_conversation = PinnedConversation( app_id=app_model.id, conversation_id=conversation.id, created_by_role="account" if isinstance(user, Account) else "end_user", created_by=user.id, ) - db.session.add(pinned_conversation) db.session.commit() @@ -105,9 +99,7 @@ class WebConversationService: ) .first() ) - if not pinned_conversation: return - db.session.delete(pinned_conversation) db.session.commit() diff --git a/api/services/webapp_auth_service.py b/api/services/webapp_auth_service.py index 8f92b3f070..32eb9b07ae 100644 --- a/api/services/webapp_auth_service.py +++ b/api/services/webapp_auth_service.py @@ -35,19 
+35,15 @@ class WebAppAuthService: account = db.session.query(Account).filter_by(email=email).first() if not account: raise AccountNotFoundError() - if account.status == AccountStatus.BANNED.value: raise AccountLoginError("Account is banned.") - if account.password is None or not compare_password(password, account.password, account.password_salt): raise AccountPasswordError("Invalid email or password.") - return cast(Account, account) @classmethod def login(cls, account: Account) -> str: access_token = cls._get_account_jwt_token(account=account) - return access_token @classmethod @@ -55,10 +51,8 @@ class WebAppAuthService: account = db.session.query(Account).filter(Account.email == email).first() if not account: return None - if account.status == AccountStatus.BANNED.value: raise Unauthorized("Account is banned.") - return account @classmethod @@ -68,7 +62,6 @@ class WebAppAuthService: email = account.email if account else email if email is None: raise ValueError("Email must be provided.") - code = "".join([str(secrets.randbelow(exclusive_upper_bound=10)) for _ in range(6)]) token = TokenManager.generate_token( account=account, email=email, token_type="email_code_login", additional_data={"code": code} @@ -78,7 +71,6 @@ class WebAppAuthService: to=account.email if account else email, code=code, ) - return token @classmethod @@ -108,14 +100,12 @@ class WebAppAuthService: ) db.session.add(end_user) db.session.commit() - return end_user @classmethod def _get_account_jwt_token(cls, account: Account) -> str: exp_dt = datetime.now(UTC) + timedelta(hours=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES * 24) exp = int(exp_dt.timestamp()) - payload = { "sub": "Web API Passport", "user_id": account.id, @@ -124,7 +114,6 @@ class WebAppAuthService: "auth_type": "internal", "exp": exp, } - token: str = PassportService().issue(payload) return token @@ -141,15 +130,12 @@ class WebAppAuthService: ] if access_mode: return access_mode in modes_requiring_permission_check - if not app_code and 
not app_id: raise ValueError("Either app_code or app_id must be provided.") - if app_code: app_id = AppService.get_app_id_by_code(app_code) if not app_id: raise ValueError("App ID could not be determined from the provided app_code.") - webapp_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id) if webapp_settings and webapp_settings.access_mode in modes_requiring_permission_check: return True @@ -162,7 +148,6 @@ class WebAppAuthService: """ if not app_code and not access_mode: raise ValueError("Either app_code or access_mode must be provided.") - if access_mode: if access_mode == "public": return WebAppAuthType.PUBLIC @@ -170,9 +155,7 @@ class WebAppAuthService: return WebAppAuthType.INTERNAL elif access_mode == "sso_verified": return WebAppAuthType.EXTERNAL - if app_code: webapp_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code) return cls.get_app_auth_type(access_mode=webapp_settings.access_mode) - raise ValueError("Could not determine app authentication type.") diff --git a/api/services/website_service.py b/api/services/website_service.py index 6720932a3a..683b546506 100644 --- a/api/services/website_service.py +++ b/api/services/website_service.py @@ -66,7 +66,6 @@ class WebsiteService: tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key") ) return WaterCrawlProvider(api_key, credentials.get("config").get("base_url", None)).crawl_url(url, options) - elif provider == "jinareader": api_key = encrypter.decrypt_token( tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key") @@ -150,7 +149,6 @@ class WebsiteService: "data": [], "time_consuming": data.get("duration", 0) / 1000, } - if crawl_status_data["status"] == "completed": response = requests.post( "https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app", @@ -177,7 +175,6 @@ class WebsiteService: credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, "website", provider) # decrypt api_key 
api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key")) - if provider == "firecrawl": crawl_data: list[dict[str, Any]] | None = None file_key = "website_files/" + job_id + ".txt" @@ -191,7 +188,6 @@ class WebsiteService: if result.get("status") != "completed": raise ValueError("Crawl job is not completed") crawl_data = result.get("data") - if crawl_data: for item in crawl_data: if item.get("source_url") == url: @@ -221,7 +217,6 @@ class WebsiteService: status_data = status_response.json().get("data", {}) if status_data.get("status") != "completed": raise ValueError("Crawl job is not completed") - # Get processed data data_response = requests.post( "https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app", diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index 2b0d57bdfd..4a1cfbe0d7 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -37,13 +37,9 @@ class WorkflowConverter: ): """ Convert app to workflow - - basic mode of chatbot app - - expert mode of chatbot app - - completion app - :param app_model: App instance :param account: Account :param name: new app name @@ -55,11 +51,9 @@ class WorkflowConverter: # convert app model config if not app_model.app_model_config: raise ValueError("App model config is required") - workflow = self.convert_app_model_config_to_workflow( app_model=app_model, app_model_config=app_model.app_model_config, account_id=account.id ) - # create new app new_app = App() new_app.tenant_id = app_model.tenant_id @@ -79,12 +73,9 @@ class WorkflowConverter: db.session.add(new_app) db.session.flush() db.session.commit() - workflow.app_id = new_app.id db.session.commit() - app_was_created.send(new_app, account=account) - return new_app def convert_app_model_config_to_workflow(self, app_model: App, app_model_config: AppModelConfig, account_id: str): @@ -96,13 +87,10 @@ class WorkflowConverter: """ 
# get new app mode new_app_mode = self._get_new_app_mode(app_model) - # convert app model config app_config = self._convert_to_app_config(app_model=app_model, app_model_config=app_model_config) - # init workflow graph graph: dict[str, Any] = {"nodes": [], "edges": []} - # Convert list: # - variables -> start # - model_config -> llm @@ -111,12 +99,9 @@ class WorkflowConverter: # - external_data_variables -> http-request # - dataset -> knowledge-retrieval # - show_retrieve_source -> knowledge-retrieval - # convert to start node start_node = self._convert_to_start_node(variables=app_config.variables) - graph["nodes"].append(start_node) - # convert to http request node external_data_variable_node_mapping: dict[str, str] = {} if app_config.external_data_variables: @@ -125,19 +110,15 @@ class WorkflowConverter: variables=app_config.variables, external_data_variables=app_config.external_data_variables, ) - for http_request_node in http_request_nodes: graph = self._append_node(graph, http_request_node) - # convert to knowledge retrieval node if app_config.dataset: knowledge_retrieval_node = self._convert_to_knowledge_retrieval_node( new_app_mode=new_app_mode, dataset_config=app_config.dataset, model_config=app_config.model ) - if knowledge_retrieval_node: graph = self._append_node(graph, knowledge_retrieval_node) - # convert to llm node llm_node = self._convert_to_llm_node( original_app_mode=AppMode.value_of(app_model.mode), @@ -148,9 +129,7 @@ class WorkflowConverter: file_upload=app_config.additional_features.file_upload, external_data_variable_node_mapping=external_data_variable_node_mapping, ) - graph = self._append_node(graph, llm_node) - if new_app_mode == AppMode.WORKFLOW: # convert to end node by app mode end_node = self._convert_to_end_node() @@ -158,9 +137,7 @@ class WorkflowConverter: else: answer_node = self._convert_to_answer_node() graph = self._append_node(graph, answer_node) - app_model_config_dict = app_config.app_model_config_dict - # features if 
new_app_mode == AppMode.ADVANCED_CHAT: features = { @@ -179,7 +156,6 @@ class WorkflowConverter: "file_upload": app_model_config_dict.get("file_upload"), "sensitive_word_avoidance": app_model_config_dict.get("sensitive_word_avoidance"), } - # create workflow record workflow = Workflow( tenant_id=app_model.tenant_id, @@ -192,10 +168,8 @@ class WorkflowConverter: environment_variables=[], conversation_variables=[], ) - db.session.add(workflow) db.session.commit() - return workflow def _convert_to_app_config(self, app_model: App, app_model_config: AppModelConfig) -> EasyUIBasedAppConfig: @@ -214,7 +188,6 @@ class WorkflowConverter: ) else: raise ValueError("Invalid app mode") - return app_config def _convert_to_start_node(self, variables: list[VariableEntity]) -> dict: @@ -251,27 +224,21 @@ class WorkflowConverter: tool_type = external_data_variable.type if tool_type != "api": continue - tool_variable = external_data_variable.variable tool_config = external_data_variable.config - # get params from config api_based_extension_id = tool_config.get("api_based_extension_id") if not api_based_extension_id: continue - # get api_based_extension api_based_extension = self._get_api_based_extension( tenant_id=tenant_id, api_based_extension_id=api_based_extension_id ) - # decrypt api_key api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=api_based_extension.api_key) - inputs = {} for v in variables: inputs[v.variable] = "{{#start." 
+ v.variable + "#}}" - request_body = { "point": APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY.value, "params": { @@ -281,10 +248,8 @@ class WorkflowConverter: "query": "{{#sys.query#}}" if app_model.mode == AppMode.CHAT.value else "", }, } - request_body_json = json.dumps(request_body) request_body_json = request_body_json.replace(r"\{\{", "{{").replace(r"\}\}", "}}") - http_request_node = { "id": f"http_request_{index}", "position": None, @@ -299,9 +264,7 @@ class WorkflowConverter: "body": {"type": "json", "data": request_body_json}, }, } - nodes.append(http_request_node) - # append code node for response body parsing code_node: dict[str, Any] = { "id": f"code_{index}", @@ -316,12 +279,9 @@ class WorkflowConverter: "outputs": {"result": {"type": "string"}}, }, } - nodes.append(code_node) - external_data_variable_node_mapping[external_data_variable.variable] = code_node["id"] index += 1 - return nodes, external_data_variable_node_mapping def _convert_to_knowledge_retrieval_node( @@ -342,7 +302,6 @@ class WorkflowConverter: query_variable_selector = ["start", retrieve_config.query_variable] else: return None - return { "id": "knowledge_retrieval", "position": None, @@ -400,10 +359,8 @@ class WorkflowConverter: knowledge_retrieval_node = next( filter(lambda n: n["data"]["type"] == NodeType.KNOWLEDGE_RETRIEVAL.value, graph["nodes"]), None ) - role_prefix = None prompts: Any = None - # Chat Model if model_config.mode == LLMMode.CHAT.value: if prompt_template.prompt_type == PromptTemplateEntity.PromptType.SIMPLE: @@ -419,7 +376,6 @@ class WorkflowConverter: has_context=knowledge_retrieval_node is not None, query_in_prompt=False, ) - template = prompt_template_config["prompt_template"].template if not template: prompts = [] @@ -427,11 +383,9 @@ class WorkflowConverter: template = self._replace_template_variables( template, start_node["data"]["variables"], external_data_variable_node_mapping ) - prompts = [{"role": "user", "text": template}] else: 
advanced_chat_prompt_template = prompt_template.advanced_chat_prompt_template - prompts = [] if advanced_chat_prompt_template: for m in advanced_chat_prompt_template.messages: @@ -439,7 +393,6 @@ class WorkflowConverter: text = self._replace_template_variables( text, start_node["data"]["variables"], external_data_variable_node_mapping ) - prompts.append({"role": m.role.value, "text": text}) # Completion Model else: @@ -456,16 +409,13 @@ class WorkflowConverter: has_context=knowledge_retrieval_node is not None, query_in_prompt=False, ) - template = prompt_template_config["prompt_template"].template template = self._replace_template_variables( template=template, variables=start_node["data"]["variables"], external_data_variable_node_mapping=external_data_variable_node_mapping, ) - prompts = {"text": template} - prompt_rules = prompt_template_config["prompt_rules"] role_prefix = { "user": prompt_rules.get("human_prefix", "Human"), @@ -482,23 +432,18 @@ class WorkflowConverter: ) else: text = "" - text = text.replace("{{#query#}}", "{{#sys.query#}}") - prompts = { "text": text, } - if advanced_completion_prompt_template and advanced_completion_prompt_template.role_prefix: role_prefix = { "user": advanced_completion_prompt_template.role_prefix.user, "assistant": advanced_completion_prompt_template.role_prefix.assistant, } - memory = None if new_app_mode == AppMode.ADVANCED_CHAT: memory = {"role_prefix": role_prefix, "window": {"enabled": False}} - completion_params = model_config.parameters completion_params.update({"stop": model_config.stop}) return { @@ -543,11 +488,9 @@ class WorkflowConverter: """ for v in variables: template = template.replace("{{" + v["variable"] + "}}", "{{#start." 
+ v["variable"] + "#}}") - if external_data_variable_node_mapping: for variable, code_node_id in external_data_variable_node_mapping.items(): template = template.replace("{{" + variable + "}}", "{{#" + code_node_id + ".result#}}") - return template def _convert_to_end_node(self) -> dict: @@ -590,7 +533,6 @@ class WorkflowConverter: def _append_node(self, graph: dict, node: dict) -> dict: """ Append Node to Graph - :param graph: Graph, include: nodes, edges :param node: Node to append :return: @@ -623,8 +565,6 @@ class WorkflowConverter: .filter(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id) .first() ) - if not api_based_extension: raise ValueError(f"API Based Extension not found, id: {api_based_extension_id}") - return api_based_extension diff --git a/api/services/workflow_app_service.py b/api/services/workflow_app_service.py index 6eabf03018..531e3be612 100644 --- a/api/services/workflow_app_service.py +++ b/api/services/workflow_app_service.py @@ -42,10 +42,8 @@ class WorkflowAppService: stmt = select(WorkflowAppLog).where( WorkflowAppLog.tenant_id == app_model.tenant_id, WorkflowAppLog.app_id == app_model.id ) - if keyword or status: stmt = stmt.join(WorkflowRun, WorkflowRun.id == WorkflowAppLog.workflow_run_id) - if keyword: keyword_like_val = f"%{keyword[:30].encode('unicode_escape').decode('utf-8')}%".replace(r"\u", r"\\u") keyword_conditions = [ @@ -54,27 +52,21 @@ class WorkflowAppService: # filter keyword by end user session id if created by end user role and_(WorkflowRun.created_by_role == "end_user", EndUser.session_id.ilike(keyword_like_val)), ] - # filter keyword by workflow run id keyword_uuid = self._safe_parse_uuid(keyword) if keyword_uuid: keyword_conditions.append(WorkflowRun.id == keyword_uuid) - stmt = stmt.outerjoin( EndUser, and_(WorkflowRun.created_by == EndUser.id, WorkflowRun.created_by_role == CreatorUserRole.END_USER), ).where(or_(*keyword_conditions)) - if status: stmt = 
stmt.where(WorkflowRun.status == status) - # Add time-based filtering if created_at_before: stmt = stmt.where(WorkflowAppLog.created_at <= created_at_before) - if created_at_after: stmt = stmt.where(WorkflowAppLog.created_at >= created_at_after) - # Filter by end user session id or account email if created_by_end_user_session_id: stmt = stmt.join( @@ -94,19 +86,14 @@ class WorkflowAppService: Account.email == created_by_account, ), ) - stmt = stmt.order_by(WorkflowAppLog.created_at.desc()) - # Get total count using the same filters count_stmt = select(func.count()).select_from(stmt.subquery()) total = session.scalar(count_stmt) or 0 - # Apply pagination limits offset_stmt = stmt.offset((page - 1) * limit).limit(limit) - # Execute query and get items items = list(session.scalars(offset_stmt).all()) - return { "page": page, "limit": limit, @@ -120,7 +107,6 @@ class WorkflowAppService: # fast check if len(value) < 32: return None - try: return uuid.UUID(value) except ValueError: diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py index 44fd72b5e4..9de94f96b8 100644 --- a/api/services/workflow_draft_variable_service.py +++ b/api/services/workflow_draft_variable_service.py @@ -52,7 +52,6 @@ class DraftVarLoader(VariableLoader): # This implements the VariableLoader interface for loading draft variables. # # ref: core.workflow.variable_loader.VariableLoader - # Database engine used for loading variables. _engine: Engine # Application ID for which variables are being loaded. @@ -78,14 +77,11 @@ class DraftVarLoader(VariableLoader): def load_variables(self, selectors: list[list[str]]) -> list[Variable]: if not selectors: return [] - # Map each selector (as a tuple via `_selector_to_tuple`) to its corresponding Variable instance. 
variable_by_selector: dict[tuple[str, str], Variable] = {} - with Session(bind=self._engine, expire_on_commit=False) as session: srv = WorkflowDraftVariableService(session) draft_vars = srv.get_draft_variables_by_selectors(self._app_id, selectors) - for draft_var in draft_vars: segment = draft_var.get_value() variable = segment_to_variable( @@ -97,7 +93,6 @@ class DraftVarLoader(VariableLoader): ) selector_tuple = self._selector_to_tuple(variable.selector) variable_by_selector[selector_tuple] = variable - # Important: files: list[File] = [] for draft_var in draft_vars: @@ -109,7 +104,6 @@ class DraftVarLoader(VariableLoader): with Session(bind=self._engine) as session: storage_key_loader = StorageKeyLoader(session, tenant_id=self._tenant_id) storage_key_loader.load_storage_keys(files) - return list(variable_by_selector.values()) @@ -132,7 +126,6 @@ class WorkflowDraftVariableService: assert len(selector) >= MIN_SELECTORS_LENGTH, f"Invalid selector to get: {selector}" node_id, name = selector[:2] ors.append(and_(WorkflowDraftVariable.node_id == node_id, WorkflowDraftVariable.name == name)) - # NOTE(QuantumGhost): Although the number of `or` expressions may be large, as long as # each expression includes conditions on both `node_id` and `name` (which are covered by the unique index), # PostgreSQL can efficiently retrieve the results using a bitmap index scan. 
@@ -159,7 +152,6 @@ class WorkflowDraftVariableService: .offset((page - 1) * limit) .all() ) - return WorkflowDraftVariableList(variables=variables, total=total) def _list_node_variables(self, app_id: str, node_id: str) -> WorkflowDraftVariableList: @@ -220,7 +212,6 @@ class WorkflowDraftVariableService: def _reset_conv_var(self, workflow: Workflow, variable: WorkflowDraftVariable) -> WorkflowDraftVariable | None: conv_var_by_name = {i.name: i for i in workflow.conversation_variables} conv_var = conv_var_by_name.get(variable.name) - if conv_var is None: self._session.delete(instance=variable) self._session.flush() @@ -228,7 +219,6 @@ class WorkflowDraftVariableService: "Conversation variable not found for draft variable, id=%s, name=%s", variable.id, variable.name ) return None - variable.set_value(conv_var) variable.last_edited_at = None self._session.add(variable) @@ -247,7 +237,6 @@ class WorkflowDraftVariableService: self._session.flush() _logger.warning("draft variable has no node_execution_id, id=%s, name=%s", variable.id, variable.name) return None - query = select(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id == variable.node_execution_id) node_exec = self._session.scalars(query).first() if node_exec is None: @@ -260,16 +249,13 @@ class WorkflowDraftVariableService: self._session.delete(instance=variable) self._session.flush() return None - outputs_dict = node_exec.outputs_dict or {} # a sentinel value used to check the absent of the output variable key. absent = object() - if variable.get_variable_type() == DraftVariableType.NODE: # Get node type for proper value extraction node_config = workflow.get_node_config_by_id(variable.node_id) node_type = workflow.get_node_type_from_node_config(node_config) - # Note: Based on the implementation in `_build_from_variable_assigner_mapping`, # VariableAssignerNode (both v1 and v2) can only create conversation draft variables. 
# For consistency, we should simply return when processing VARIABLE_ASSIGNER nodes. @@ -281,7 +267,6 @@ class WorkflowDraftVariableService: output_value = outputs_dict.get(variable.name, absent) else: output_value = outputs_dict.get(f"sys.{variable.name}", absent) - # We cannot use `is None` to check the existence of an output variable here as # the value of the output may be `None`. if output_value is absent: @@ -348,15 +333,12 @@ class WorkflowDraftVariableService: ) -> str: """ get_or_create_conversation creates and returns the ID of a conversation for debugging. - If a conversation already exists, as determined by the following criteria, its ID is returned: - The system variable `sys.conversation_id` exists in the draft variable table, and - A corresponding conversation record is found in the database. - If no such conversation exists, a new conversation is created and its ID is returned. """ conv_id = self._get_conversation_id_from_draft_variable(workflow.app_id) - if conv_id is not None: conversation = ( self._session.query(Conversation) @@ -387,7 +369,6 @@ class WorkflowDraftVariableService: from_end_user_id=None, from_account_id=account_id, ) - self._session.add(conversation) self._session.flush() return conversation.id @@ -494,11 +475,9 @@ def _build_segment_for_serialized_values(v: Any) -> Segment: """ Reconstructs Segment objects from serialized values, with special handling for FileSegment and ArrayFileSegment types. - This function should only be used when: 1. No explicit type information is available 2. The input value is in serialized form (dict or list) - It detects potential file objects in the serialized data and properly rebuilds the appropriate segment type. """ @@ -512,7 +491,6 @@ class DraftVariableSaver: # This is used to signal the execution of a workflow node when it has no other outputs. 
_DUMMY_OUTPUT_IDENTITY: ClassVar[str] = "__dummy__" _DUMMY_OUTPUT_VALUE: ClassVar[None] = None - # _EXCLUDE_VARIABLE_NAMES_MAPPING maps node types and versions to variable names that # should be excluded when saving draft variables. This prevents certain internal or # technical variables from being exposed in the draft environment, particularly those @@ -521,23 +499,17 @@ class DraftVariableSaver: NodeType.LLM: frozenset(["finish_reason"]), NodeType.LOOP: frozenset(["loop_round"]), } - # Database session used for persisting draft variables. _session: Session - # The application ID associated with the draft variables. # This should match the `Workflow.app_id` of the workflow to which the current node belongs. _app_id: str - # The ID of the node for which DraftVariableSaver is saving output variables. _node_id: str - # The type of the current node (see NodeType). _node_type: NodeType - # _node_execution_id: str - # _enclosing_node_id identifies the container node that the current node belongs to. 
# For example, if the current node is an LLM node inside an Iteration node # or Loop node, then `_enclosing_node_id` refers to the ID of @@ -586,7 +558,6 @@ class DraftVariableSaver: def _build_from_variable_assigner_mapping(self, process_data: Mapping[str, Any]) -> list[WorkflowDraftVariable]: draft_vars: list[WorkflowDraftVariable] = [] updated_variables = get_updated_variables(process_data) or [] - for item in updated_variables: selector = item.selector if len(selector) < MIN_SELECTORS_LENGTH: @@ -638,7 +609,6 @@ class DraftVariableSaver: value_seg = WorkflowDraftVariable.build_segment_with_type(SegmentType.ARRAY_FILE, files) else: value_seg = ArrayFileSegment(value=[]) - draft_vars.append( WorkflowDraftVariable.new_sys_variable( app_id=self._app_id, diff --git a/api/services/workflow_run_service.py b/api/services/workflow_run_service.py index 483c0d3086..bdbde43427 100644 --- a/api/services/workflow_run_service.py +++ b/api/services/workflow_run_service.py @@ -23,7 +23,6 @@ class WorkflowRunService: """ Get advanced chat app workflow run list Only return triggered_from == advanced_chat - :param app_model: app model :param args: request args """ @@ -39,7 +38,6 @@ class WorkflowRunService: return getattr(self._workflow_run, item) pagination = self.get_paginate_workflow_runs(app_model, args) - with_message_workflow_runs = [] for workflow_run in pagination.data: message = workflow_run.message @@ -47,9 +45,7 @@ class WorkflowRunService: if message: with_message_workflow_run.message_id = message.id with_message_workflow_run.conversation_id = message.conversation_id - with_message_workflow_runs.append(with_message_workflow_run) - pagination.data = with_message_workflow_runs return pagination @@ -57,26 +53,21 @@ class WorkflowRunService: """ Get debug workflow run list Only return triggered_from == debugging - :param app_model: app model :param args: request args """ limit = int(args.get("limit", 20)) - base_query = db.session.query(WorkflowRun).filter( 
WorkflowRun.tenant_id == app_model.tenant_id, WorkflowRun.app_id == app_model.id, WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.DEBUGGING.value, ) - if args.get("last_id"): last_workflow_run = base_query.filter( WorkflowRun.id == args.get("last_id"), ).first() - if not last_workflow_run: raise ValueError("Last workflow run not exists") - workflow_runs = ( base_query.filter( WorkflowRun.created_at < last_workflow_run.created_at, WorkflowRun.id != last_workflow_run.id @@ -87,7 +78,6 @@ class WorkflowRunService: ) else: workflow_runs = base_query.order_by(WorkflowRun.created_at.desc()).limit(limit).all() - has_more = False if len(workflow_runs) == limit: current_page_first_workflow_run = workflow_runs[-1] @@ -95,16 +85,13 @@ class WorkflowRunService: WorkflowRun.created_at < current_page_first_workflow_run.created_at, WorkflowRun.id != current_page_first_workflow_run.id, ).count() - if rest_count > 0: has_more = True - return InfiniteScrollPagination(data=workflow_runs, limit=limit, has_more=has_more) def get_workflow_run(self, app_model: App, run_id: str) -> Optional[WorkflowRun]: """ Get workflow run detail - :param app_model: app model :param run_id: workflow run id """ @@ -117,7 +104,6 @@ class WorkflowRunService: ) .first() ) - return workflow_run def get_workflow_run_node_executions( @@ -130,24 +116,19 @@ class WorkflowRunService: Get workflow run node execution list """ workflow_run = self.get_workflow_run(app_model, run_id) - contexts.plugin_tool_providers.set({}) contexts.plugin_tool_providers_lock.set(threading.Lock()) - if not workflow_run: return [] - repository = SQLAlchemyWorkflowNodeExecutionRepository( session_factory=db.engine, user=user, app_id=app_model.id, triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, ) - # Use the repository to get the database models directly order_config = OrderConfig(order_by=["index"], order_direction="desc") workflow_node_executions = repository.get_db_models_by_workflow_run( workflow_run_id=run_id, 
order_config=order_config ) - return workflow_node_executions diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 2be57fd51c..846c31589a 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -96,7 +96,6 @@ class WorkflowService: ) .first() ) - # return draft workflow return workflow @@ -121,10 +120,8 @@ class WorkflowService: """ Get published workflow """ - if not app_model.workflow_id: return None - # fetch published workflow by workflow_id workflow = ( db.session.query(Workflow) @@ -135,7 +132,6 @@ class WorkflowService: ) .first() ) - return workflow def get_all_published_workflow( @@ -153,7 +149,6 @@ class WorkflowService: """ if not app_model.workflow_id: return [], False - stmt = ( select(Workflow) .where(Workflow.app_id == app_model.id) @@ -161,19 +156,14 @@ class WorkflowService: .limit(limit + 1) .offset((page - 1) * limit) ) - if user_id: stmt = stmt.where(Workflow.created_by == user_id) - if named_only: stmt = stmt.where(Workflow.marked_name != "") - workflows = session.scalars(stmt).all() - has_more = len(workflows) > limit if has_more: workflows = workflows[:-1] - return workflows, has_more def sync_draft_workflow( @@ -193,13 +183,10 @@ class WorkflowService: """ # fetch draft workflow by app_model workflow = self.get_draft_workflow(app_model=app_model) - if workflow and workflow.unique_hash != unique_hash: raise WorkflowHashNotEqualError() - # validate features structure self.validate_features_structure(app_model=app_model, features=features) - # create draft workflow if not found if not workflow: workflow = Workflow( @@ -222,13 +209,10 @@ class WorkflowService: workflow.updated_at = datetime.now(UTC).replace(tzinfo=None) workflow.environment_variables = environment_variables workflow.conversation_variables = conversation_variables - # commit db session changes db.session.commit() - # trigger app workflow events app_draft_workflow_was_synced.send(app_model, 
synced_draft_workflow=workflow) - # return draft workflow return workflow @@ -249,7 +233,6 @@ class WorkflowService: draft_workflow = session.scalar(draft_workflow_stmt) if not draft_workflow: raise ValueError("No valid workflow found.") - # create new workflow workflow = Workflow.new( tenant_id=app_model.tenant_id, @@ -264,13 +247,10 @@ class WorkflowService: marked_name=marked_name, marked_comment=marked_comment, ) - # commit db session changes session.add(workflow) - # trigger app workflow events app_published_workflow_was_updated.send(app_model, published_workflow=workflow) - # return new workflow return workflow @@ -285,7 +265,6 @@ class WorkflowService: default_config = node_class.get_default_config() if default_config: default_block_configs.append(default_config) - return default_block_configs def get_default_block_config(self, node_type: str, filters: Optional[dict] = None) -> Optional[dict]: @@ -296,16 +275,13 @@ class WorkflowService: :return: """ node_type_enum = NodeType(node_type) - # return default block config if node_type_enum not in NODE_TYPE_CLASSES_MAPPING: return None - node_class = NODE_TYPE_CLASSES_MAPPING[node_type_enum][LATEST_VERSION] default_config = node_class.get_default_config(filters=filters) if not default_config: return None - return default_config def run_draft_workflow_node( @@ -322,11 +298,9 @@ class WorkflowService: Run draft workflow node """ files = files or [] - with Session(bind=db.engine, expire_on_commit=False) as session, session.begin(): draft_var_srv = WorkflowDraftVariableService(session) draft_var_srv.prefill_conversation_variable_default_values(draft_workflow) - node_config = draft_workflow.get_node_config_by_id(node_id) node_type = Workflow.get_node_type_from_node_config(node_config) node_data = node_config.get("data", {}) @@ -354,7 +328,6 @@ class WorkflowService: node_type=node_type, conversation_id=conversation_id, ) - else: variable_pool = VariablePool( system_variables={}, @@ -362,19 +335,16 @@ class 
WorkflowService: environment_variables=draft_workflow.environment_variables, conversation_variables=[], ) - variable_loader = DraftVarLoader( engine=db.engine, app_id=app_model.id, tenant_id=app_model.tenant_id, ) - eclosing_node_type_and_id = draft_workflow.get_enclosing_node_type_and_id(node_config) if eclosing_node_type_and_id: _, enclosing_node_id = eclosing_node_type_and_id else: enclosing_node_id = None - run = WorkflowEntry.single_step_run( workflow=draft_workflow, node_id=node_id, @@ -383,7 +353,6 @@ class WorkflowService: variable_pool=variable_pool, variable_loader=variable_loader, ) - # run draft workflow node start_at = time.perf_counter() node_execution = self._handle_node_run_result( @@ -391,10 +360,8 @@ class WorkflowService: start_at=start_at, node_id=node_id, ) - # Set workflow_id on the NodeExecution node_execution.workflow_id = draft_workflow.id - # Create repository and save the node execution repository = SQLAlchemyWorkflowNodeExecutionRepository( session_factory=db.engine, @@ -403,10 +370,8 @@ class WorkflowService: triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP, ) repository.save(node_execution) - # Convert node_execution to WorkflowNodeExecution after save workflow_node_execution = repository.to_db_model(node_execution) - with Session(bind=db.engine) as session, session.begin(): draft_var_saver = DraftVariableSaver( session=session, @@ -428,7 +393,6 @@ class WorkflowService: """ # run draft workflow node start_at = time.perf_counter() - workflow_node_execution = self._handle_node_run_result( invoke_node_fn=lambda: WorkflowEntry.run_free_node( node_id=node_id, @@ -440,7 +404,6 @@ class WorkflowService: start_at=start_at, node_id=node_id, ) - return workflow_node_execution def _handle_node_run_result( @@ -451,16 +414,13 @@ class WorkflowService: ) -> WorkflowNodeExecution: try: node_instance, generator = invoke_node_fn() - node_run_result: NodeRunResult | None = None for event in generator: if isinstance(event, 
RunCompletedEvent): node_run_result = event.run_result - # sign output files # node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs) break - if not node_run_result: raise ValueError("Node run failed with no run result") # single step debug mode error handling return @@ -498,7 +458,6 @@ class WorkflowService: run_succeeded = False node_run_result = None error = e.error - # Create a NodeExecution domain model node_execution = WorkflowNodeExecution( id=str(uuid4()), @@ -511,7 +470,6 @@ class WorkflowService: created_at=datetime.now(UTC).replace(tzinfo=None), finished_at=datetime.now(UTC).replace(tzinfo=None), ) - if run_succeeded and node_run_result: # Set inputs, process_data, and outputs as dictionaries (not JSON strings) inputs = WorkflowEntry.handle_special_values(node_run_result.inputs) if node_run_result.inputs else None @@ -521,12 +479,10 @@ class WorkflowService: else None ) outputs = node_run_result.outputs - node_execution.inputs = inputs node_execution.process_data = process_data node_execution.outputs = outputs node_execution.metadata = node_run_result.metadata - # Map status from WorkflowNodeExecutionStatus to NodeExecutionStatus if node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED: node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED @@ -537,14 +493,12 @@ class WorkflowService: # Set failed status and error node_execution.status = WorkflowNodeExecutionStatus.FAILED node_execution.error = error - return node_execution def convert_to_workflow(self, app_model: App, account: Account, args: dict) -> App: """ Basic mode of chatbot app(expert mode) to workflow Completion App to Workflow App - :param app_model: App instance :param account: Account instance :param args: dict @@ -552,10 +506,8 @@ class WorkflowService: """ # chatbot convert to workflow mode workflow_converter = WorkflowConverter() - if app_model.mode not in {AppMode.CHAT.value, AppMode.COMPLETION.value}: raise ValueError(f"Current App mode: 
{app_model.mode} is not supported convert to workflow.") - # convert to workflow new_app: App = workflow_converter.convert_to_workflow( app_model=app_model, @@ -565,7 +517,6 @@ class WorkflowService: icon=args.get("icon", "🤖"), icon_background=args.get("icon_background", "#FFEAD5"), ) - return new_app def validate_features_structure(self, app_model: App, features: dict) -> dict: @@ -585,7 +536,6 @@ class WorkflowService: ) -> Optional[Workflow]: """ Update workflow attributes - :param session: SQLAlchemy database session :param workflow_id: Workflow ID :param tenant_id: Tenant ID @@ -595,25 +545,19 @@ class WorkflowService: """ stmt = select(Workflow).where(Workflow.id == workflow_id, Workflow.tenant_id == tenant_id) workflow = session.scalar(stmt) - if not workflow: return None - allowed_fields = ["marked_name", "marked_comment"] - for field, value in data.items(): if field in allowed_fields: setattr(workflow, field, value) - workflow.updated_by = account_id workflow.updated_at = datetime.now(UTC).replace(tzinfo=None) - return workflow def delete_workflow(self, *, session: Session, workflow_id: str, tenant_id: str) -> bool: """ Delete a workflow - :param session: SQLAlchemy database session :param workflow_id: Workflow ID :param tenant_id: Tenant ID @@ -624,21 +568,17 @@ class WorkflowService: """ stmt = select(Workflow).where(Workflow.id == workflow_id, Workflow.tenant_id == tenant_id) workflow = session.scalar(stmt) - if not workflow: raise ValueError(f"Workflow with ID {workflow_id} not found") - # Check if workflow is a draft version if workflow.version == "draft": raise DraftWorkflowDeletionError("Cannot delete draft workflow versions") - # Check if this workflow is currently referenced by an app app_stmt = select(App).where(App.workflow_id == workflow_id) app = session.scalar(app_stmt) if app: # Cannot delete a workflow that's currently in use by an app raise WorkflowInUseError(f"Cannot delete workflow that is currently in use by app '{app.id}'") - # Don't 
use workflow.tool_published as it's not accurate for specific workflow versions # Check if there's a tool provider using this specific workflow version tool_provider = ( @@ -650,11 +590,9 @@ class WorkflowService: ) .first() ) - if tool_provider: # Cannot delete a workflow that's published as a tool raise WorkflowInUseError("Cannot delete workflow that is published as a tool") - session.delete(workflow) return True @@ -682,7 +620,6 @@ def _setup_variable_pool( # Randomly generated. SystemVariableKey.WORKFLOW_EXECUTION_ID: str(uuid.uuid4()), } - # Only add chatflow-specific variables for non-workflow types if workflow.type != WorkflowType.WORKFLOW.value: system_inputs.update( @@ -694,7 +631,6 @@ def _setup_variable_pool( ) else: system_inputs = {} - # init variable pool variable_pool = VariablePool( system_variables=system_inputs, @@ -702,7 +638,6 @@ def _setup_variable_pool( environment_variables=workflow.environment_variables, conversation_variables=conversation_variables, ) - return variable_pool @@ -710,7 +645,6 @@ def _rebuild_file_for_user_inputs_in_start_node( tenant_id: str, start_node_data: StartNodeData, user_inputs: Mapping[str, Any] ) -> Mapping[str, Any]: inputs_copy = dict(user_inputs) - for variable in start_node_data.variables: if variable.type not in (VariableEntityType.FILE, VariableEntityType.FILE_LIST): continue diff --git a/api/services/workspace_service.py b/api/services/workspace_service.py index 125e0c1b1e..32c7663ae9 100644 --- a/api/services/workspace_service.py +++ b/api/services/workspace_service.py @@ -21,7 +21,6 @@ class WorkspaceService: "trial_end_reason": None, "role": "normal", } - # Get role of user tenant_account_join = ( db.session.query(TenantAccountJoin) @@ -30,9 +29,7 @@ class WorkspaceService: ) assert tenant_account_join is not None, "TenantAccountJoin not found" tenant_info["role"] = tenant_account_join.role - can_replace_logo = FeatureService.get_features(tenant_info["id"]).can_replace_logo - if can_replace_logo and 
TenantService.has_roles(tenant, [TenantAccountRole.OWNER, TenantAccountRole.ADMIN]): base_url = dify_config.FILES_URL replace_webapp_logo = ( @@ -41,10 +38,8 @@ class WorkspaceService: else None ) remove_webapp_brand = tenant.custom_config_dict.get("remove_webapp_brand", False) - tenant_info["custom_config"] = { "remove_webapp_brand": remove_webapp_brand, "replace_webapp_logo": replace_webapp_logo, } - return tenant_info diff --git a/api/tests/unit_tests/core/workflow/nodes/tool/__init__.py b/api/tests/unit_tests/core/workflow/nodes/tool/__init__.py index 8b13789179..e69de29bb2 100644 --- a/api/tests/unit_tests/core/workflow/nodes/tool/__init__.py +++ b/api/tests/unit_tests/core/workflow/nodes/tool/__init__.py @@ -1 +0,0 @@ - diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/__init__.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/__init__.py index 8b13789179..e69de29bb2 100644 --- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/__init__.py +++ b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/__init__.py @@ -1 +0,0 @@ - diff --git a/core/prompt/entities/advanced_prompt_entities.py b/core/prompt/entities/advanced_prompt_entities.py index 2c56b82465..036a38c651 100644 --- a/core/prompt/entities/advanced_prompt_entities.py +++ b/core/prompt/entities/advanced_prompt_entities.py @@ -2,12 +2,10 @@ from typing import Optional, Any from pydantic import BaseModel, Field, model_validator from core.prompt.entities.role_prefix import RolePrefix from core.prompt.entities.window import Window - class MemoryConfig(BaseModel): role_prefix: RolePrefix = Field(default_factory=RolePrefix) window: Window = Field(default_factory=Window) memory_key: Optional[str] = Field(None) - # The `model_validate` method is used to create a `MemoryConfig` object from a dictionary. 
@model_validator(mode="before") @classmethod diff --git a/dev/pytest/pytest_config_tests.py b/dev/pytest/pytest_config_tests.py index 63d0cbaf3a..53e7e7d354 100644 --- a/dev/pytest/pytest_config_tests.py +++ b/dev/pytest/pytest_config_tests.py @@ -1,7 +1,6 @@ import yaml # type: ignore from dotenv import dotenv_values from pathlib import Path - BASE_API_AND_DOCKER_CONFIG_SET_DIFF = { "APP_MAX_EXECUTION_TIME", "BATCH_UPLOAD_LIMIT", @@ -39,7 +38,6 @@ BASE_API_AND_DOCKER_CONFIG_SET_DIFF = { "WEAVIATE_BATCH_SIZE", "WEAVIATE_GRPC_ENABLED", } - BASE_API_AND_DOCKER_COMPOSE_CONFIG_SET_DIFF = { "BATCH_UPLOAD_LIMIT", "CELERY_BEAT_SCHEDULER_TIME", @@ -87,15 +85,11 @@ BASE_API_AND_DOCKER_COMPOSE_CONFIG_SET_DIFF = { "WEAVIATE_BATCH_SIZE", "WEAVIATE_GRPC_ENABLED", } - API_CONFIG_SET = set(dotenv_values(Path("api") / Path(".env.example")).keys()) DOCKER_CONFIG_SET = set(dotenv_values(Path("docker") / Path(".env.example")).keys()) DOCKER_COMPOSE_CONFIG_SET = set() - with open(Path("docker") / Path("docker-compose.yaml")) as f: DOCKER_COMPOSE_CONFIG_SET = set(yaml.safe_load(f.read())["x-shared-env"].keys()) - - def test_yaml_config(): # python set == operator is used to compare two sets DIFF_API_WITH_DOCKER = ( @@ -117,7 +111,5 @@ def test_yaml_config(): ) raise Exception("API and Docker Compose config sets are different") print("All tests passed!") - - if __name__ == "__main__": - test_yaml_config() + test_yaml_config() \ No newline at end of file diff --git a/docker/docker-compose.middleware.yaml b/docker/docker-compose.middleware.yaml index 0b1885755b..e62f1bc660 100644 --- a/docker/docker-compose.middleware.yaml +++ b/docker/docker-compose.middleware.yaml @@ -89,7 +89,7 @@ services: SERVER_KEY: ${PLUGIN_DAEMON_KEY:-lYkiYYT6owG+71oLerGzA7GXCgOT++6ovaezWAjpCjf+Sjc3ZtU+qUEi} MAX_PLUGIN_PACKAGE_SIZE: ${PLUGIN_MAX_PACKAGE_SIZE:-52428800} PPROF_ENABLED: ${PLUGIN_PPROF_ENABLED:-false} - DIFY_INNER_API_URL: ${PLUGIN_DIFY_INNER_API_URL:-http://host.docker.internal:5001} + 
DIFY_INNER_API_URL: http://172.17.0.1:5001 DIFY_INNER_API_KEY: ${PLUGIN_DIFY_INNER_API_KEY:-QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1} PLUGIN_REMOTE_INSTALLING_HOST: ${PLUGIN_DEBUGGING_HOST:-0.0.0.0} PLUGIN_REMOTE_INSTALLING_PORT: ${PLUGIN_DEBUGGING_PORT:-5003} diff --git a/sdks/python-client/dify_client/__init__.py b/sdks/python-client/dify_client/__init__.py index 6fa9d190e5..6ef0017fee 100644 --- a/sdks/python-client/dify_client/__init__.py +++ b/sdks/python-client/dify_client/__init__.py @@ -1 +1 @@ -from dify_client.client import ChatClient, CompletionClient, DifyClient +from dify_client.client import ChatClient, CompletionClient, DifyClient \ No newline at end of file diff --git a/sdks/python-client/dify_client/client.py b/sdks/python-client/dify_client/client.py index d885dc6fb7..cbd918c9ff 100644 --- a/sdks/python-client/dify_client/client.py +++ b/sdks/python-client/dify_client/client.py @@ -1,59 +1,43 @@ import json - import requests - - class DifyClient: def __init__(self, api_key, base_url: str = "https://api.dify.ai/v1"): self.api_key = api_key self.base_url = base_url - def _send_request(self, method, endpoint, json=None, params=None, stream=False): headers = { "Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json", } - url = f"{self.base_url}{endpoint}" response = requests.request( method, url, json=json, params=params, headers=headers, stream=stream ) - return response - def _send_request_with_files(self, method, endpoint, data, files): headers = {"Authorization": f"Bearer {self.api_key}"} - url = f"{self.base_url}{endpoint}" response = requests.request( method, url, data=data, headers=headers, files=files ) - return response - def message_feedback(self, message_id, rating, user): data = {"rating": rating, "user": user} return self._send_request("POST", f"/messages/{message_id}/feedbacks", data) - def get_application_parameters(self, user): params = {"user": user} return self._send_request("GET", "/parameters", 
params=params) - def file_upload(self, user, files): data = {"user": user} return self._send_request_with_files( "POST", "/files/upload", data=data, files=files ) - def text_to_audio(self, text: str, user: str, streaming: bool = False): data = {"text": text, "user": user, "streaming": streaming} return self._send_request("POST", "/text-to-audio", json=data) - def get_meta(self, user): params = {"user": user} return self._send_request("GET", "/meta", params=params) - - class CompletionClient(DifyClient): def create_completion_message(self, inputs, response_mode, user, files=None): data = { @@ -68,8 +52,6 @@ class CompletionClient(DifyClient): data, stream=True if response_mode == "streaming" else False, ) - - class ChatClient(DifyClient): def create_chat_message( self, @@ -89,42 +71,34 @@ class ChatClient(DifyClient): } if conversation_id: data["conversation_id"] = conversation_id - return self._send_request( "POST", "/chat-messages", data, stream=True if response_mode == "streaming" else False, ) - def get_suggested(self, message_id, user: str): params = {"user": user} return self._send_request( "GET", f"/messages/{message_id}/suggested", params=params ) - def stop_message(self, task_id, user): data = {"user": user} return self._send_request("POST", f"/chat-messages/{task_id}/stop", data) - def get_conversations(self, user, last_id=None, limit=None, pinned=None): params = {"user": user, "last_id": last_id, "limit": limit, "pinned": pinned} return self._send_request("GET", "/conversations", params=params) - def get_conversation_messages( self, user, conversation_id=None, first_id=None, limit=None ): params = {"user": user} - if conversation_id: params["conversation_id"] = conversation_id if first_id: params["first_id"] = first_id if limit: params["limit"] = limit - return self._send_request("GET", "/messages", params=params) - def rename_conversation( self, conversation_id, name, auto_generate: bool, user: str ): @@ -132,32 +106,24 @@ class ChatClient(DifyClient): 
return self._send_request( "POST", f"/conversations/{conversation_id}/name", data ) - def delete_conversation(self, conversation_id, user): data = {"user": user} return self._send_request("DELETE", f"/conversations/{conversation_id}", data) - def audio_to_text(self, audio_file, user): data = {"user": user} files = {"audio_file": audio_file} return self._send_request_with_files("POST", "/audio-to-text", data, files) - - class WorkflowClient(DifyClient): def run( self, inputs: dict, response_mode: str = "streaming", user: str = "abc-123" ): data = {"inputs": inputs, "response_mode": response_mode, "user": user} return self._send_request("POST", "/workflows/run", data) - def stop(self, task_id, user): data = {"user": user} return self._send_request("POST", f"/workflows/tasks/{task_id}/stop", data) - def get_result(self, workflow_run_id): return self._send_request("GET", f"/workflows/run/{workflow_run_id}") - - class KnowledgeBaseClient(DifyClient): def __init__( self, @@ -167,7 +133,6 @@ class KnowledgeBaseClient(DifyClient): ): """ Construct a KnowledgeBaseClient object. - Args: api_key (str): API key of Dify. base_url (str, optional): Base URL of Dify API. Defaults to 'https://api.dify.ai/v1'. @@ -176,26 +141,21 @@ class KnowledgeBaseClient(DifyClient): """ super().__init__(api_key=api_key, base_url=base_url) self.dataset_id = dataset_id - def _get_dataset_id(self): if self.dataset_id is None: raise ValueError("dataset_id is not set") return self.dataset_id - def create_dataset(self, name: str, **kwargs): return self._send_request("POST", "/datasets", {"name": name}, **kwargs) - def list_datasets(self, page: int = 1, page_size: int = 20, **kwargs): return self._send_request( "GET", f"/datasets?page={page}&limit={page_size}", **kwargs ) - def create_document_by_text( self, name, text, extra_params: dict | None = None, **kwargs ): """ Create a document by text. 
- :param name: Name of the document :param text: Text content of the document :param extra_params: extra parameters pass to the API, such as indexing_technique, process_rule. (optional) @@ -228,13 +188,11 @@ class KnowledgeBaseClient(DifyClient): data.update(extra_params) url = f"/datasets/{self._get_dataset_id()}/document/create_by_text" return self._send_request("POST", url, json=data, **kwargs) - def update_document_by_text( self, document_id, name, text, extra_params: dict | None = None, **kwargs ): """ Update a document by text. - :param document_id: ID of the document :param name: Name of the document :param text: Text content of the document @@ -265,13 +223,11 @@ class KnowledgeBaseClient(DifyClient): f"/datasets/{self._get_dataset_id()}/documents/{document_id}/update_by_text" ) return self._send_request("POST", url, json=data, **kwargs) - def create_document_by_file( self, file_path, original_document_id=None, extra_params: dict | None = None ): """ Create a document by file. - :param file_path: Path to the file :param original_document_id: pass this ID if you want to replace the original document (optional) :param extra_params: extra parameters pass to the API, such as indexing_technique, process_rule. (optional) @@ -307,13 +263,11 @@ class KnowledgeBaseClient(DifyClient): return self._send_request_with_files( "POST", url, {"data": json.dumps(data)}, files ) - def update_document_by_file( self, document_id, file_path, extra_params: dict | None = None ): """ Update a document by file. - :param document_id: ID of the document :param file_path: Path to the file :param extra_params: extra parameters pass to the API, such as indexing_technique, process_rule. (optional) @@ -346,36 +300,29 @@ class KnowledgeBaseClient(DifyClient): return self._send_request_with_files( "POST", url, {"data": json.dumps(data)}, files ) - def batch_indexing_status(self, batch_id: str, **kwargs): """ Get the status of the batch indexing. 
- :param batch_id: ID of the batch uploading :return: Response from the API """ url = f"/datasets/{self._get_dataset_id()}/documents/{batch_id}/indexing-status" return self._send_request("GET", url, **kwargs) - def delete_dataset(self): """ Delete this dataset. - :return: Response from the API """ url = f"/datasets/{self._get_dataset_id()}" return self._send_request("DELETE", url) - def delete_document(self, document_id): """ Delete a document. - :param document_id: ID of the document :return: Response from the API """ url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}" return self._send_request("DELETE", url) - def list_documents( self, page: int | None = None, @@ -385,7 +332,6 @@ class KnowledgeBaseClient(DifyClient): ): """ Get a list of documents in this dataset. - :return: Response from the API """ params = {} @@ -397,11 +343,9 @@ class KnowledgeBaseClient(DifyClient): params["keyword"] = keyword url = f"/datasets/{self._get_dataset_id()}/documents" return self._send_request("GET", url, params=params, **kwargs) - def add_segments(self, document_id, segments, **kwargs): """ Add segments to a document. - :param document_id: ID of the document :param segments: List of segments to add, example: [{"content": "1", "answer": "1", "keyword": ["a"]}] :return: Response from the API @@ -409,7 +353,6 @@ class KnowledgeBaseClient(DifyClient): data = {"segments": segments} url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/segments" return self._send_request("POST", url, json=data, **kwargs) - def query_segments( self, document_id, @@ -419,7 +362,6 @@ class KnowledgeBaseClient(DifyClient): ): """ Query segments in this document. - :param document_id: ID of the document :param keyword: query keyword, optional :param status: status of the segment, optional, e.g. 
completed @@ -433,22 +375,18 @@ class KnowledgeBaseClient(DifyClient): if "params" in kwargs: params.update(kwargs["params"]) return self._send_request("GET", url, params=params, **kwargs) - def delete_document_segment(self, document_id, segment_id): """ Delete a segment from a document. - :param document_id: ID of the document :param segment_id: ID of the segment :return: Response from the API """ url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/segments/{segment_id}" return self._send_request("DELETE", url) - def update_document_segment(self, document_id, segment_id, segment_data, **kwargs): """ Update a segment in a document. - :param document_id: ID of the document :param segment_id: ID of the segment :param segment_data: Data of the segment, example: {"content": "1", "answer": "1", "keyword": ["a"], "enabled": True} @@ -456,4 +394,4 @@ class KnowledgeBaseClient(DifyClient): """ data = {"segment": segment_data} url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/segments/{segment_id}" - return self._send_request("POST", url, json=data, **kwargs) + return self._send_request("POST", url, json=data, **kwargs) \ No newline at end of file diff --git a/sdks/python-client/setup.py b/sdks/python-client/setup.py index 7340fffb4c..26b7d2496e 100644 --- a/sdks/python-client/setup.py +++ b/sdks/python-client/setup.py @@ -1,8 +1,6 @@ from setuptools import setup - with open("README.md", "r", encoding="utf-8") as fh: long_description = fh.read() - setup( name="dify-client", version="0.1.12", @@ -23,4 +21,4 @@ setup( install_requires=["requests"], keywords="dify nlp ai language-processing", include_package_data=True, -) +) \ No newline at end of file diff --git a/sdks/python-client/tests/test_client.py b/sdks/python-client/tests/test_client.py index 52032417c0..5b35ea4029 100644 --- a/sdks/python-client/tests/test_client.py +++ b/sdks/python-client/tests/test_client.py @@ -1,20 +1,16 @@ import os import time import unittest - from 
dify_client.client import ( ChatClient, CompletionClient, DifyClient, KnowledgeBaseClient, ) - API_KEY = os.environ.get("API_KEY") APP_ID = os.environ.get("APP_ID") API_BASE_URL = os.environ.get("API_BASE_URL", "https://api.dify.ai/v1") FILE_PATH_BASE = os.path.dirname(__file__) - - class TestKnowledgeBaseClient(unittest.TestCase): def setUp(self): self.knowledge_base_client = KnowledgeBaseClient(API_KEY, base_url=API_BASE_URL) @@ -25,20 +21,17 @@ class TestKnowledgeBaseClient(unittest.TestCase): self.document_id = None self.segment_id = None self.batch_id = None - def _get_dataset_kb_client(self): self.assertIsNotNone(self.dataset_id) return KnowledgeBaseClient( API_KEY, base_url=API_BASE_URL, dataset_id=self.dataset_id ) - def test_001_create_dataset(self): response = self.knowledge_base_client.create_dataset(name="test_dataset") data = response.json() self.assertIn("id", data) self.dataset_id = data["id"] self.assertEqual("test_dataset", data["name"]) - # the following tests require to be executed in order because they use # the dataset/document/segment ids from the previous test self._test_002_list_datasets() @@ -58,13 +51,11 @@ class TestKnowledgeBaseClient(unittest.TestCase): self._test_012_update_document_segment() self._test_013_delete_document_segment() self._test_014_delete_dataset() - def _test_002_list_datasets(self): response = self.knowledge_base_client.list_datasets() data = response.json() self.assertIn("data", data) self.assertIn("total", data) - def _test_003_create_document_by_text(self): client = self._get_dataset_kb_client() response = client.create_document_by_text("test_document", "test_text") @@ -72,7 +63,6 @@ class TestKnowledgeBaseClient(unittest.TestCase): self.assertIn("document", data) self.document_id = data["document"]["id"] self.batch_id = data["batch"] - def _test_004_update_document_by_text(self): client = self._get_dataset_kb_client() self.assertIsNotNone(self.document_id) @@ -83,13 +73,11 @@ class 
TestKnowledgeBaseClient(unittest.TestCase): self.assertIn("document", data) self.assertIn("batch", data) self.batch_id = data["batch"] - def _test_005_batch_indexing_status(self): client = self._get_dataset_kb_client() response = client.batch_indexing_status(self.batch_id) response.json() self.assertEqual(response.status_code, 200) - def _test_006_update_document_by_file(self): client = self._get_dataset_kb_client() self.assertIsNotNone(self.document_id) @@ -100,13 +88,11 @@ class TestKnowledgeBaseClient(unittest.TestCase): self.assertIn("document", data) self.assertIn("batch", data) self.batch_id = data["batch"] - def _test_007_list_documents(self): client = self._get_dataset_kb_client() response = client.list_documents() data = response.json() self.assertIn("data", data) - def _test_008_delete_document(self): client = self._get_dataset_kb_client() self.assertIsNotNone(self.document_id) @@ -114,7 +100,6 @@ class TestKnowledgeBaseClient(unittest.TestCase): data = response.json() self.assertIn("result", data) self.assertEqual("success", data["result"]) - def _test_009_create_document_by_file(self): client = self._get_dataset_kb_client() response = client.create_document_by_file(self.README_FILE_PATH) @@ -122,7 +107,6 @@ class TestKnowledgeBaseClient(unittest.TestCase): self.assertIn("document", data) self.document_id = data["document"]["id"] self.batch_id = data["batch"] - def _test_010_add_segments(self): client = self._get_dataset_kb_client() response = client.add_segments( @@ -133,14 +117,12 @@ class TestKnowledgeBaseClient(unittest.TestCase): self.assertGreater(len(data["data"]), 0) segment = data["data"][0] self.segment_id = segment["id"] - def _test_011_query_segments(self): client = self._get_dataset_kb_client() response = client.query_segments(self.document_id) data = response.json() self.assertIn("data", data) self.assertGreater(len(data["data"]), 0) - def _test_012_update_document_segment(self): client = self._get_dataset_kb_client() 
self.assertIsNotNone(self.segment_id) @@ -154,7 +136,6 @@ class TestKnowledgeBaseClient(unittest.TestCase): self.assertGreater(len(data["data"]), 0) segment = data["data"] self.assertEqual("test text segment 1 updated", segment["content"]) - def _test_013_delete_document_segment(self): client = self._get_dataset_kb_client() self.assertIsNotNone(self.segment_id) @@ -162,23 +143,18 @@ class TestKnowledgeBaseClient(unittest.TestCase): data = response.json() self.assertIn("result", data) self.assertEqual("success", data["result"]) - def _test_014_delete_dataset(self): client = self._get_dataset_kb_client() response = client.delete_dataset() self.assertEqual(204, response.status_code) - - class TestChatClient(unittest.TestCase): def setUp(self): self.chat_client = ChatClient(API_KEY) - def test_create_chat_message(self): response = self.chat_client.create_chat_message( {}, "Hello, World!", "test_user" ) self.assertIn("answer", response.text) - def test_create_chat_message_with_vision_model_by_remote_url(self): files = [ {"type": "image", "transfer_method": "remote_url", "url": "your_image_url"} @@ -187,7 +163,6 @@ class TestChatClient(unittest.TestCase): {}, "Describe the picture.", "test_user", files=files ) self.assertIn("answer", response.text) - def test_create_chat_message_with_vision_model_by_local_file(self): files = [ { @@ -200,28 +175,22 @@ class TestChatClient(unittest.TestCase): {}, "Describe the picture.", "test_user", files=files ) self.assertIn("answer", response.text) - def test_get_conversation_messages(self): response = self.chat_client.get_conversation_messages( "test_user", "your_conversation_id" ) self.assertIn("answer", response.text) - def test_get_conversations(self): response = self.chat_client.get_conversations("test_user") self.assertIn("data", response.text) - - class TestCompletionClient(unittest.TestCase): def setUp(self): self.completion_client = CompletionClient(API_KEY) - def test_create_completion_message(self): response = 
self.completion_client.create_completion_message( {"query": "What's the weather like today?"}, "blocking", "test_user" ) self.assertIn("answer", response.text) - def test_create_completion_message_with_vision_model_by_remote_url(self): files = [ {"type": "image", "transfer_method": "remote_url", "url": "your_image_url"} @@ -230,7 +199,6 @@ class TestCompletionClient(unittest.TestCase): {"query": "Describe the picture."}, "blocking", "test_user", files ) self.assertIn("answer", response.text) - def test_create_completion_message_with_vision_model_by_local_file(self): files = [ { @@ -243,32 +211,24 @@ class TestCompletionClient(unittest.TestCase): {"query": "Describe the picture."}, "blocking", "test_user", files ) self.assertIn("answer", response.text) - - class TestDifyClient(unittest.TestCase): def setUp(self): self.dify_client = DifyClient(API_KEY) - def test_message_feedback(self): response = self.dify_client.message_feedback( "your_message_id", "like", "test_user" ) self.assertIn("success", response.text) - def test_get_application_parameters(self): response = self.dify_client.get_application_parameters("test_user") self.assertIn("user_input_form", response.text) - def test_file_upload(self): file_path = "your_image_file_path" file_name = "panda.jpeg" mime_type = "image/jpeg" - with open(file_path, "rb") as file: files = {"file": (file_name, file, mime_type)} response = self.dify_client.file_upload("test_user", files) self.assertIn("name", response.text) - - if __name__ == "__main__": - unittest.main() + unittest.main() \ No newline at end of file diff --git a/tests/unit_tests/events/test_provider_update_deadlock_prevention.py b/tests/unit_tests/events/test_provider_update_deadlock_prevention.py index 47c175acd7..bc1ec2cffe 100644 --- a/tests/unit_tests/events/test_provider_update_deadlock_prevention.py +++ b/tests/unit_tests/events/test_provider_update_deadlock_prevention.py @@ -1,6 +1,5 @@ import threading from unittest.mock import Mock, patch - from 
core.app.entities.app_invoke_entities import ChatAppGenerateEntity from core.entities.provider_entities import QuotaUnit from events.event_handlers.update_provider_when_message_created import ( @@ -9,38 +8,27 @@ from events.event_handlers.update_provider_when_message_created import ( ) from models.provider import ProviderType from sqlalchemy.exc import OperationalError - - class TestProviderUpdateDeadlockPrevention: """Test suite for deadlock prevention in Provider updates.""" - def setup_method(self): """Setup test fixtures.""" self.mock_message = Mock() self.mock_message.answer_tokens = 100 - self.mock_app_config = Mock() self.mock_app_config.tenant_id = "test-tenant-123" - self.mock_model_conf = Mock() self.mock_model_conf.provider = "openai" - self.mock_system_config = Mock() self.mock_system_config.current_quota_type = QuotaUnit.TOKENS - self.mock_provider_config = Mock() self.mock_provider_config.using_provider_type = ProviderType.SYSTEM self.mock_provider_config.system_configuration = self.mock_system_config - self.mock_provider_bundle = Mock() self.mock_provider_bundle.configuration = self.mock_provider_config - self.mock_model_conf.provider_model_bundle = self.mock_provider_bundle - self.mock_generate_entity = Mock(spec=ChatAppGenerateEntity) self.mock_generate_entity.app_config = self.mock_app_config self.mock_generate_entity.model_conf = self.mock_model_conf - @patch("events.event_handlers.update_provider_when_message_created.db") def test_consolidated_handler_basic_functionality(self, mock_db): """Test that the consolidated handler performs both updates correctly.""" @@ -50,19 +38,14 @@ class TestProviderUpdateDeadlockPrevention: mock_query.filter.return_value = mock_query mock_query.order_by.return_value = mock_query mock_query.update.return_value = 1 # 1 row affected - # Call the handler handle(self.mock_message, application_generate_entity=self.mock_generate_entity) - # Verify db.session.query was called assert mock_db.session.query.called - # Verify 
commit was called mock_db.session.commit.assert_called_once() - # Verify no rollback was called assert not mock_db.session.rollback.called - @patch("events.event_handlers.update_provider_when_message_created.db") def test_deadlock_retry_mechanism(self, mock_db): """Test that deadlock errors trigger retry logic.""" @@ -72,22 +55,17 @@ class TestProviderUpdateDeadlockPrevention: mock_query.filter.return_value = mock_query mock_query.order_by.return_value = mock_query mock_query.update.return_value = 1 - # First call raises deadlock, second succeeds mock_db.session.commit.side_effect = [ OperationalError("deadlock detected", None, None), None, # Success on retry ] - # Call the handler handle(self.mock_message, application_generate_entity=self.mock_generate_entity) - # Verify commit was called twice (original + retry) assert mock_db.session.commit.call_count == 2 - # Verify rollback was called once (after first failure) mock_db.session.rollback.assert_called_once() - @patch("events.event_handlers.update_provider_when_message_created.db") @patch("events.event_handlers.update_provider_when_message_created.time.sleep") def test_exponential_backoff_timing(self, mock_sleep, mock_db): @@ -98,32 +76,25 @@ class TestProviderUpdateDeadlockPrevention: mock_query.filter.return_value = mock_query mock_query.order_by.return_value = mock_query mock_query.update.return_value = 1 - mock_db.session.commit.side_effect = [ OperationalError("deadlock detected", None, None), OperationalError("deadlock detected", None, None), None, # Success on third attempt ] - # Call the handler handle(self.mock_message, application_generate_entity=self.mock_generate_entity) - # Verify sleep was called twice with increasing delays assert mock_sleep.call_count == 2 - # First delay should be around 0.1s + jitter first_delay = mock_sleep.call_args_list[0][0][0] assert 0.1 <= first_delay <= 0.3 - # Second delay should be around 0.2s + jitter second_delay = mock_sleep.call_args_list[1][0][0] assert 0.2 <= 
second_delay <= 0.4 - def test_concurrent_handler_execution(self): """Test that multiple handlers can run concurrently without deadlock.""" results = [] errors = [] - def run_handler(): try: with patch( @@ -134,7 +105,6 @@ class TestProviderUpdateDeadlockPrevention: mock_query.filter.return_value = mock_query mock_query.order_by.return_value = mock_query mock_query.update.return_value = 1 - handle( self.mock_message, application_generate_entity=self.mock_generate_entity, @@ -142,28 +112,23 @@ class TestProviderUpdateDeadlockPrevention: results.append("success") except Exception as e: errors.append(str(e)) - # Run multiple handlers concurrently threads = [] for _ in range(5): thread = threading.Thread(target=run_handler) threads.append(thread) thread.start() - # Wait for all threads to complete for thread in threads: thread.join(timeout=5) - # Verify all handlers completed successfully assert len(results) == 5 assert len(errors) == 0 - def test_performance_stats_tracking(self): """Test that performance statistics are tracked correctly.""" # Reset stats stats = get_update_stats() initial_total = stats["total_updates"] - with patch( "events.event_handlers.update_provider_when_message_created.db" ) as mock_db: @@ -172,52 +137,42 @@ class TestProviderUpdateDeadlockPrevention: mock_query.filter.return_value = mock_query mock_query.order_by.return_value = mock_query mock_query.update.return_value = 1 - # Call handler handle( self.mock_message, application_generate_entity=self.mock_generate_entity ) - # Check that stats were updated updated_stats = get_update_stats() assert updated_stats["total_updates"] == initial_total + 1 assert updated_stats["successful_updates"] >= initial_total + 1 - def test_non_chat_entity_ignored(self): """Test that non-chat entities are ignored by the handler.""" # Create a non-chat entity mock_non_chat_entity = Mock() mock_non_chat_entity.__class__.__name__ = "NonChatEntity" - with patch( 
"events.event_handlers.update_provider_when_message_created.db" ) as mock_db: # Call handler with non-chat entity handle(self.mock_message, application_generate_entity=mock_non_chat_entity) - # Verify no database operations were performed assert not mock_db.session.query.called assert not mock_db.session.commit.called - @patch("events.event_handlers.update_provider_when_message_created.db") def test_quota_calculation_tokens(self, mock_db): """Test quota calculation for token-based quotas.""" # Setup token-based quota self.mock_system_config.current_quota_type = QuotaUnit.TOKENS self.mock_message.answer_tokens = 150 - mock_query = Mock() mock_db.session.query.return_value = mock_query mock_query.filter.return_value = mock_query mock_query.order_by.return_value = mock_query mock_query.update.return_value = 1 - # Call handler handle(self.mock_message, application_generate_entity=self.mock_generate_entity) - # Verify update was called with token count update_calls = mock_query.update.call_args_list - # Should have at least one call with quota_used update quota_update_found = False for call in update_calls: @@ -225,24 +180,19 @@ class TestProviderUpdateDeadlockPrevention: if "quota_used" in values: quota_update_found = True break - assert quota_update_found - @patch("events.event_handlers.update_provider_when_message_created.db") def test_quota_calculation_times(self, mock_db): """Test quota calculation for times-based quotas.""" # Setup times-based quota self.mock_system_config.current_quota_type = QuotaUnit.TIMES - mock_query = Mock() mock_db.session.query.return_value = mock_query mock_query.filter.return_value = mock_query mock_query.order_by.return_value = mock_query mock_query.update.return_value = 1 - # Call handler handle(self.mock_message, application_generate_entity=self.mock_generate_entity) - # Verify update was called assert mock_query.update.called - assert mock_db.session.commit.called + assert mock_db.session.commit.called \ No newline at end of file