package com.dy.common.multiDataSource;
|
|
import com.alibaba.druid.pool.DruidDataSource;
|
import com.alibaba.druid.pool.DruidDataSourceFactory;
|
import lombok.extern.slf4j.Slf4j;
|
import org.springframework.beans.MutablePropertyValues;
|
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
|
import org.springframework.beans.factory.support.GenericBeanDefinition;
|
import org.springframework.context.EnvironmentAware;
|
import org.springframework.context.annotation.ImportBeanDefinitionRegistrar;
|
import org.springframework.core.env.Environment;
|
import org.springframework.core.env.MapPropertySource;
|
import org.springframework.core.env.StandardEnvironment;
|
import org.springframework.core.type.AnnotationMetadata;
|
|
import javax.sql.DataSource;
|
import java.util.HashMap;
|
import java.util.Map;
|
import java.util.Properties;
|
import java.util.Set;
|
|
/**
|
* SpringBoot容器启动时,针对数据源,第一步启动本类:
|
* 收集多数据源的配置,形成各数据源的定义,
|
* 把数据源的定义作为“dataSource”注册到Spring容器中
|
*/
|
@Slf4j
|
public class MultiDataSourceBeanDefinitionRegistrar implements ImportBeanDefinitionRegistrar, EnvironmentAware {
|
|
/**
|
* 默认dataSource
|
*/
|
private DataSource defaultDataSource;
|
|
/**
|
* 数据源map
|
*/
|
private Map<String, DataSource> dataSourcesMap = new HashMap<>();
|
|
|
@Override
|
public void setEnvironment(Environment environment) {
|
//读取配置文件获取更多数据源
|
String dsNames = environment.getProperty("spring.datasource.names");
|
for (String dsName : dsNames.split(",")) {
|
dsName = dsName.trim() ;
|
try{
|
final String keyNames = "spring.datasource." + dsName ;
|
Properties properties = new Properties() ;
|
((StandardEnvironment) environment)
|
.getPropertySources().stream()
|
.forEach((propertySource) -> {
|
if (propertySource instanceof MapPropertySource) {
|
MapPropertySource mps = (MapPropertySource) propertySource;
|
Set<String> keys = mps.getSource().keySet();
|
for (String key : keys) {
|
if (key.startsWith(keyNames)) {
|
properties.put(key.replace(keyNames + ".", ""), String.valueOf(mps.getProperty(key))) ;
|
//log.info(key.replace(keyNames + ".", "") + "=" + String.valueOf(mps.getProperty(key)));
|
}
|
}
|
}
|
});
|
|
DruidDataSource dataSource = (DruidDataSource)DruidDataSourceFactory.createDataSource(properties) ;
|
if (dataSourcesMap.size() == 0) {
|
defaultDataSource = dataSource;
|
}
|
dataSourcesMap.put(dsName, dataSource);
|
}catch (Exception e){
|
log.error("创建数据源" + dsName + "异常", e);
|
}
|
}
|
}
|
|
@Override
|
public void registerBeanDefinitions(AnnotationMetadata importingClassMetadata, BeanDefinitionRegistry registry) {
|
Map<Object, Object> targetDataSources = new HashMap<Object, Object>();
|
//添加其他数据源
|
targetDataSources.putAll(dataSourcesMap);
|
//创建DynamicDataSource
|
GenericBeanDefinition beanDefinition = new GenericBeanDefinition();
|
beanDefinition.setBeanClass(MultiDataSource.class);
|
beanDefinition.setSynthetic(true);
|
MutablePropertyValues mpv = beanDefinition.getPropertyValues();
|
//defaultTargetDataSource 和 targetDataSources属性是 AbstractRoutingDataSource的两个属性Map
|
mpv.addPropertyValue("defaultTargetDataSource", defaultDataSource);
|
mpv.addPropertyValue("targetDataSources", targetDataSources);
|
//注册到Spring容器中
|
registry.registerBeanDefinition("dataSource", beanDefinition);
|
}
|
|
}
|