package com.dy.common.multiDataSource;

import com.alibaba.druid.pool.DruidDataSource;
import com.alibaba.druid.pool.DruidDataSourceFactory;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.MutablePropertyValues;
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
import org.springframework.beans.factory.support.GenericBeanDefinition;
import org.springframework.context.EnvironmentAware;
import org.springframework.context.annotation.ImportBeanDefinitionRegistrar;
import org.springframework.core.env.ConfigurableEnvironment;
import org.springframework.core.env.Environment;
import org.springframework.core.env.MapPropertySource;
import org.springframework.core.env.PropertySource;
import org.springframework.core.env.StandardEnvironment;
import org.springframework.core.type.AnnotationMetadata;

import javax.sql.DataSource;
import java.util.HashMap;
import java.util.Map;
import java.util.Properties;
import java.util.Set;

/**
 * Registers a single routing {@code dataSource} bean backed by several Druid data sources.
 *
 * <p>The data source names are read from the comma-separated property
 * {@code spring.datasource.names}. For every name {@code x}, all properties under
 * {@code spring.datasource.x.*} are handed to {@link DruidDataSourceFactory} with the prefix
 * stripped. The first data source that is created successfully becomes the default one.
 */
@Slf4j
public class MultiDataSourceBeanDefinitionRegistrar implements ImportBeanDefinitionRegistrar, EnvironmentAware {

    /** Property holding the comma-separated list of data source names. */
    private static final String DATASOURCE_NAMES_KEY = "spring.datasource.names";

    /** Common prefix of every per-data-source property. */
    private static final String DATASOURCE_PREFIX = "spring.datasource.";

    /** Default data source: the first one that was created successfully. */
    private DataSource defaultDataSource;

    /** All successfully created data sources, keyed by their configured name. */
    private final Map<String, DataSource> dataSourcesMap = new HashMap<>();

    /**
     * Reads the data source configuration from the environment and creates the Druid data sources.
     *
     * <p>Works with any {@link ConfigurableEnvironment}, e.g. {@link StandardEnvironment} and its
     * subclasses. A failure while creating one data source is logged and does not prevent the
     * remaining ones from being created.
     *
     * @param environment the environment supplied by Spring
     */
    @Override
    public void setEnvironment(Environment environment) {
        String dsNames = environment.getProperty(DATASOURCE_NAMES_KEY);
        if (dsNames == null || dsNames.trim().isEmpty()) {
            // Without this guard a missing property would fail with an NPE on split().
            log.warn("Property '{}' is not set; no data sources will be created", DATASOURCE_NAMES_KEY);
            return;
        }
        if (!(environment instanceof ConfigurableEnvironment)) {
            // Property sources can only be enumerated through ConfigurableEnvironment.
            log.error("Environment {} does not expose its property sources; no data sources will be created",
                    environment.getClass().getName());
            return;
        }
        ConfigurableEnvironment configurableEnvironment = (ConfigurableEnvironment) environment;

        for (String rawName : dsNames.split(",")) {
            String dsName = rawName.trim();
            if (dsName.isEmpty()) {
                // An empty entry (e.g. trailing comma) would otherwise match every data source key.
                continue;
            }
            if (dataSourcesMap.containsKey(dsName)) {
                log.warn("Duplicate data source name '{}' ignored", dsName);
                continue;
            }
            try {
                // The trailing dot keeps "db1" from also capturing the keys of "db10".
                Properties properties = collectProperties(configurableEnvironment, DATASOURCE_PREFIX + dsName + ".");
                DruidDataSource dataSource = (DruidDataSource) DruidDataSourceFactory.createDataSource(properties);
                if (dataSourcesMap.isEmpty()) {
                    defaultDataSource = dataSource;
                }
                dataSourcesMap.put(dsName, dataSource);
            } catch (Exception e) {
                // Best effort: one broken data source must not prevent the others from being created.
                log.error("创建数据源" + dsName + "异常", e);
            }
        }
    }

    /**
     * Collects every property whose key starts with {@code prefix}, with the prefix removed.
     *
     * <p>Keys are discovered by enumerating the {@link MapPropertySource}s of the environment.
     * Values are resolved through {@link Environment#getProperty(String)} so that the environment's
     * normal lookup decides which source provides the value, instead of the last enumerated source
     * overwriting earlier ones.
     *
     * @param environment environment to read from
     * @param prefix key prefix including the trailing dot
     * @return the matching properties with {@code prefix} stripped from their keys
     */
    private static Properties collectProperties(ConfigurableEnvironment environment, String prefix) {
        Properties properties = new Properties();
        for (PropertySource<?> propertySource : environment.getPropertySources()) {
            if (!(propertySource instanceof MapPropertySource)) {
                continue;
            }
            Set<String> keys = ((MapPropertySource) propertySource).getSource().keySet();
            for (String key : keys) {
                if (!key.startsWith(prefix) || key.length() == prefix.length()) {
                    continue;
                }
                String value = environment.getProperty(key);
                if (value != null) {
                    properties.put(key.substring(prefix.length()), value);
                }
            }
        }
        return properties;
    }

    /**
     * Registers the {@code dataSource} bean definition of type {@link MultiDataSource}, wired with
     * the default data source and the map of all created data sources.
     *
     * @param importingClassMetadata annotation metadata of the importing class (unused)
     * @param registry registry that receives the bean definition
     */
    @Override
    public void registerBeanDefinitions(AnnotationMetadata importingClassMetadata, BeanDefinitionRegistry registry) {
        // Copy so the bean definition does not share this registrar's mutable map.
        Map<Object, Object> targetDataSources = new HashMap<>(dataSourcesMap);

        GenericBeanDefinition beanDefinition = new GenericBeanDefinition();
        beanDefinition.setBeanClass(MultiDataSource.class);
        beanDefinition.setSynthetic(true);

        // defaultTargetDataSource and targetDataSources are properties of AbstractRoutingDataSource
        // (presumably the superclass of MultiDataSource; confirm against that class).
        MutablePropertyValues mpv = beanDefinition.getPropertyValues();
        mpv.addPropertyValue("defaultTargetDataSource", defaultDataSource);
        mpv.addPropertyValue("targetDataSources", targetDataSources);

        registry.registerBeanDefinition("dataSource", beanDefinition);
    }
}