liurunyu
4 天以前 1cf88d43994ec7ec403319032a9d118b39fe3571
pipIrr-platform/pipIrr-common/src/main/java/com/dy/common/multiDataSource/MultiDataSourceBeanDefinitionRegistrar.java
New file
@@ -0,0 +1,93 @@
package com.dy.common.multiDataSource;
import com.alibaba.druid.pool.DruidDataSource;
import com.alibaba.druid.pool.DruidDataSourceFactory;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.MutablePropertyValues;
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
import org.springframework.beans.factory.support.GenericBeanDefinition;
import org.springframework.context.EnvironmentAware;
import org.springframework.context.annotation.ImportBeanDefinitionRegistrar;
import org.springframework.core.env.ConfigurableEnvironment;
import org.springframework.core.env.Environment;
import org.springframework.core.env.MapPropertySource;
import org.springframework.core.env.PropertySource;
import org.springframework.core.env.StandardEnvironment;
import org.springframework.core.type.AnnotationMetadata;
import javax.sql.DataSource;
import java.util.HashMap;
import java.util.Map;
import java.util.Properties;
import java.util.Set;
/**
 * SpringBoot容器启动时,针对数据源,第一步启动本类:
 * 收集多数据源的配置,形成各数据源的定义,
 * 把数据源的定义作为“dataSource”注册到Spring容器中
 */
@Slf4j
public class MultiDataSourceBeanDefinitionRegistrar implements ImportBeanDefinitionRegistrar,  EnvironmentAware {
    /**
     * 默认dataSource
     */
    private DataSource defaultDataSource;
    /**
     * 数据源map
     */
    private Map<String, DataSource> dataSourcesMap = new HashMap<>();
    @Override
    public void setEnvironment(Environment environment) {
        //读取配置文件获取更多数据源
        String dsNames = environment.getProperty("spring.datasource.names");
        for (String dsName : dsNames.split(",")) {
            dsName = dsName.trim() ;
            try{
                final String keyNames = "spring.datasource." + dsName ;
                Properties properties = new Properties() ;
                ((StandardEnvironment) environment)
                        .getPropertySources().stream()
                        .forEach((propertySource) -> {
                            if (propertySource instanceof MapPropertySource) {
                                MapPropertySource mps = (MapPropertySource) propertySource;
                                Set<String> keys = mps.getSource().keySet();
                                for (String key : keys) {
                                    if (key.startsWith(keyNames)) {
                                        properties.put(key.replace(keyNames + ".", ""), String.valueOf(mps.getProperty(key))) ;
                                        //log.info(key.replace(keyNames + ".", "") + "=" + String.valueOf(mps.getProperty(key)));
                                    }
                                }
                            }
                        });
                DruidDataSource dataSource = (DruidDataSource)DruidDataSourceFactory.createDataSource(properties) ;
                if (dataSourcesMap.size() == 0) {
                    defaultDataSource = dataSource;
                }
                dataSourcesMap.put(dsName, dataSource);
            }catch (Exception e){
                log.error("创建数据源" + dsName + "异常", e);
            }
        }
    }
    @Override
    public void registerBeanDefinitions(AnnotationMetadata importingClassMetadata, BeanDefinitionRegistry registry) {
        Map<Object, Object> targetDataSources = new HashMap<Object, Object>();
        //添加其他数据源
        targetDataSources.putAll(dataSourcesMap);
        //创建DynamicDataSource
        GenericBeanDefinition beanDefinition = new GenericBeanDefinition();
        beanDefinition.setBeanClass(MultiDataSource.class);
        beanDefinition.setSynthetic(true);
        MutablePropertyValues mpv = beanDefinition.getPropertyValues();
        //defaultTargetDataSource 和 targetDataSources属性是 AbstractRoutingDataSource的两个属性Map
        mpv.addPropertyValue("defaultTargetDataSource", defaultDataSource);
        mpv.addPropertyValue("targetDataSources", targetDataSources);
        //注册到Spring容器中
        registry.registerBeanDefinition("dataSource", beanDefinition);
    }
}