1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.tikv.common.util;
19
20 import com.google.common.annotations.VisibleForTesting;
21 import com.google.common.collect.ImmutableList;
22 import io.grpc.ManagedChannel;
23 import io.grpc.netty.GrpcSslContexts;
24 import io.grpc.netty.NettyChannelBuilder;
25 import io.netty.handler.ssl.SslContext;
26 import io.netty.handler.ssl.SslContextBuilder;
27 import java.io.File;
28 import java.io.FileInputStream;
29 import java.net.URI;
30 import java.security.KeyStore;
31 import java.util.ArrayList;
32 import java.util.List;
33 import java.util.concurrent.ConcurrentHashMap;
34 import java.util.concurrent.Executors;
35 import java.util.concurrent.ScheduledExecutorService;
36 import java.util.concurrent.TimeUnit;
37 import java.util.concurrent.atomic.AtomicReference;
38 import java.util.concurrent.locks.ReadWriteLock;
39 import java.util.concurrent.locks.ReentrantReadWriteLock;
40 import javax.net.ssl.KeyManagerFactory;
41 import javax.net.ssl.SSLException;
42 import javax.net.ssl.TrustManagerFactory;
43 import org.slf4j.Logger;
44 import org.slf4j.LoggerFactory;
45 import org.tikv.common.HostMapping;
46 import org.tikv.common.pd.PDUtils;
47
48 public class ChannelFactory implements AutoCloseable {
49 private static final Logger logger = LoggerFactory.getLogger(ChannelFactory.class);
50 private static final String PUB_KEY_INFRA = "PKIX";
51
52
53
54 private final long connRecycleTime;
55 private final int maxFrameSize;
56 private final int keepaliveTime;
57 private final int keepaliveTimeout;
58 private final int idleTimeout;
59 private final CertContext certContext;
60 private final CertWatcher certWatcher;
61
62 @VisibleForTesting
63 public final ConcurrentHashMap<String, ManagedChannel> connPool = new ConcurrentHashMap<>();
64
65 private final AtomicReference<SslContextBuilder> sslContextBuilder = new AtomicReference<>();
66
67 private final ScheduledExecutorService recycler;
68
69 private final ReadWriteLock lock = new ReentrantReadWriteLock();
70
71 @VisibleForTesting
72 public static class CertWatcher implements AutoCloseable {
73 private static final Logger logger = LoggerFactory.getLogger(CertWatcher.class);
74 private final List<File> targets;
75 private final List<Long> lastReload = new ArrayList<>();
76 private final ScheduledExecutorService executorService =
77 Executors.newSingleThreadScheduledExecutor();
78 private final Runnable onChange;
79
80 public CertWatcher(long pollInterval, List<File> targets, Runnable onChange) {
81 this.targets = targets;
82 this.onChange = onChange;
83
84 for (File ignored : targets) {
85 lastReload.add(0L);
86 }
87
88 executorService.scheduleAtFixedRate(
89 this::tryReload, pollInterval, pollInterval, TimeUnit.SECONDS);
90 }
91
92
93 private void tryReload() {
94
95 try {
96 if (needReload()) {
97 onChange.run();
98 }
99 } catch (Exception e) {
100 logger.error("Failed to reload cert!", e);
101 }
102 }
103
104 private boolean needReload() {
105 boolean needReload = false;
106
107
108 for (int i = 0; i < targets.size(); i++) {
109 try {
110 long lastModified = targets.get(i).lastModified();
111 if (lastModified != lastReload.get(i)) {
112 lastReload.set(i, lastModified);
113 logger.warn("detected ssl context changes: {}", targets.get(i));
114 needReload = true;
115 }
116 } catch (Exception e) {
117 logger.error("fail to check the status of ssl context files", e);
118 }
119 }
120 return needReload;
121 }
122
123 @Override
124 public void close() {
125 executorService.shutdown();
126 }
127 }
128
129 @VisibleForTesting
130 public abstract static class CertContext {
131 public abstract SslContextBuilder createSslContextBuilder();
132 }
133
134 public static class JksContext extends CertContext {
135 private final String keyPath;
136 private final String keyPassword;
137 private final String trustPath;
138 private final String trustPassword;
139
140 public JksContext(String keyPath, String keyPassword, String trustPath, String trustPassword) {
141 this.keyPath = keyPath;
142 this.keyPassword = keyPassword;
143 this.trustPath = trustPath;
144 this.trustPassword = trustPassword;
145 }
146
147 @Override
148 public SslContextBuilder createSslContextBuilder() {
149 SslContextBuilder builder = GrpcSslContexts.forClient();
150 try {
151 if (keyPath != null && keyPassword != null) {
152 KeyStore keyStore = KeyStore.getInstance("JKS");
153 keyStore.load(new FileInputStream(keyPath), keyPassword.toCharArray());
154 KeyManagerFactory keyManagerFactory =
155 KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm());
156 keyManagerFactory.init(keyStore, keyPassword.toCharArray());
157 builder.keyManager(keyManagerFactory);
158 }
159 if (trustPath != null && trustPassword != null) {
160 KeyStore trustStore = KeyStore.getInstance("JKS");
161 trustStore.load(new FileInputStream(trustPath), trustPassword.toCharArray());
162 TrustManagerFactory trustManagerFactory = TrustManagerFactory.getInstance(PUB_KEY_INFRA);
163 trustManagerFactory.init(trustStore);
164 builder.trustManager(trustManagerFactory);
165 }
166 } catch (Exception e) {
167 logger.error("JKS SSL context builder failed!", e);
168 throw new IllegalArgumentException(e);
169 }
170 return builder;
171 }
172 }
173
174 @VisibleForTesting
175 public static class OpenSslContext extends CertContext {
176 private final String trustPath;
177 private final String chainPath;
178 private final String keyPath;
179
180 public OpenSslContext(String trustPath, String chainPath, String keyPath) {
181 this.trustPath = trustPath;
182 this.chainPath = chainPath;
183 this.keyPath = keyPath;
184 }
185
186 @Override
187 public SslContextBuilder createSslContextBuilder() {
188 SslContextBuilder builder = GrpcSslContexts.forClient();
189 try {
190 if (trustPath != null) {
191 builder.trustManager(new File(trustPath));
192 }
193 if (chainPath != null && keyPath != null) {
194 builder.keyManager(new File(chainPath), new File(keyPath));
195 }
196 } catch (Exception e) {
197 logger.error("Failed to create ssl context builder", e);
198 throw new IllegalArgumentException(e);
199 }
200 return builder;
201 }
202 }
203
204 public ChannelFactory(
205 int maxFrameSize, int keepaliveTime, int keepaliveTimeout, int idleTimeout) {
206 this.maxFrameSize = maxFrameSize;
207 this.keepaliveTime = keepaliveTime;
208 this.keepaliveTimeout = keepaliveTimeout;
209 this.idleTimeout = idleTimeout;
210 this.certWatcher = null;
211 this.certContext = null;
212 this.recycler = null;
213 this.connRecycleTime = 0;
214 }
215
216 public ChannelFactory(
217 int maxFrameSize,
218 int keepaliveTime,
219 int keepaliveTimeout,
220 int idleTimeout,
221 long connRecycleTime,
222 long certReloadInterval,
223 String trustCertCollectionFilePath,
224 String keyCertChainFilePath,
225 String keyFilePath) {
226 this.maxFrameSize = maxFrameSize;
227 this.keepaliveTime = keepaliveTime;
228 this.keepaliveTimeout = keepaliveTimeout;
229 this.idleTimeout = idleTimeout;
230 this.connRecycleTime = connRecycleTime;
231 this.certContext =
232 new OpenSslContext(trustCertCollectionFilePath, keyCertChainFilePath, keyFilePath);
233 this.recycler = Executors.newSingleThreadScheduledExecutor();
234
235 File trustCert = new File(trustCertCollectionFilePath);
236 File keyCert = new File(keyCertChainFilePath);
237 File key = new File(keyFilePath);
238
239 if (certReloadInterval > 0) {
240 onCertChange();
241 this.certWatcher =
242 new CertWatcher(
243 certReloadInterval, ImmutableList.of(trustCert, keyCert, key), this::onCertChange);
244 } else {
245 this.certWatcher = null;
246 }
247 }
248
249 public ChannelFactory(
250 int maxFrameSize,
251 int keepaliveTime,
252 int keepaliveTimeout,
253 int idleTimeout,
254 long connRecycleTime,
255 long certReloadInterval,
256 String jksKeyPath,
257 String jksKeyPassword,
258 String jksTrustPath,
259 String jksTrustPassword) {
260 this.maxFrameSize = maxFrameSize;
261 this.keepaliveTime = keepaliveTime;
262 this.keepaliveTimeout = keepaliveTimeout;
263 this.idleTimeout = idleTimeout;
264 this.connRecycleTime = connRecycleTime;
265 this.certContext = new JksContext(jksKeyPath, jksKeyPassword, jksTrustPath, jksTrustPassword);
266 this.recycler = Executors.newSingleThreadScheduledExecutor();
267
268 File jksKey = new File(jksKeyPath);
269 File jksTrust = new File(jksTrustPath);
270 if (certReloadInterval > 0) {
271 onCertChange();
272 this.certWatcher =
273 new CertWatcher(
274 certReloadInterval, ImmutableList.of(jksKey, jksTrust), this::onCertChange);
275 } else {
276 this.certWatcher = null;
277 }
278 }
279
280 private void onCertChange() {
281 try {
282 SslContextBuilder newBuilder = certContext.createSslContextBuilder();
283 lock.writeLock().lock();
284 sslContextBuilder.set(newBuilder);
285
286 List<ManagedChannel> pending = new ArrayList<>(connPool.values());
287 recycler.schedule(() -> cleanExpiredConn(pending), connRecycleTime, TimeUnit.SECONDS);
288
289 connPool.clear();
290 } finally {
291 lock.writeLock().unlock();
292 }
293 }
294
295 public ManagedChannel getChannel(String address, HostMapping mapping) {
296 if (certContext != null) {
297 try {
298 lock.readLock().lock();
299 return connPool.computeIfAbsent(
300 address, key -> createChannel(sslContextBuilder.get(), address, mapping));
301 } finally {
302 lock.readLock().unlock();
303 }
304 }
305 return connPool.computeIfAbsent(address, key -> createChannel(null, address, mapping));
306 }
307
308 private ManagedChannel createChannel(
309 SslContextBuilder sslContextBuilder, String address, HostMapping mapping) {
310 URI uri, mapped;
311 try {
312 uri = PDUtils.addrToUri(address);
313 } catch (Exception e) {
314 throw new IllegalArgumentException("failed to form address " + address, e);
315 }
316 try {
317 mapped = mapping.getMappedURI(uri);
318 } catch (Exception e) {
319 throw new IllegalArgumentException("failed to get mapped address " + uri, e);
320 }
321
322
323
324 NettyChannelBuilder builder =
325 NettyChannelBuilder.forAddress(mapped.getHost(), mapped.getPort())
326 .maxInboundMessageSize(maxFrameSize)
327 .keepAliveTime(keepaliveTime, TimeUnit.SECONDS)
328 .keepAliveTimeout(keepaliveTimeout, TimeUnit.SECONDS)
329 .keepAliveWithoutCalls(true)
330 .idleTimeout(idleTimeout, TimeUnit.SECONDS);
331
332 if (sslContextBuilder == null) {
333 return builder.usePlaintext().build();
334 } else {
335 SslContext sslContext;
336 try {
337 sslContext = sslContextBuilder.build();
338 } catch (SSLException e) {
339 logger.error("create ssl context failed!", e);
340 throw new IllegalArgumentException(e);
341 }
342 return builder.sslContext(sslContext).build();
343 }
344 }
345
346 private void cleanExpiredConn(List<ManagedChannel> pending) {
347 for (ManagedChannel channel : pending) {
348 logger.info("cleaning expire channels");
349 channel.shutdownNow();
350 while (!channel.isShutdown()) {
351 try {
352 channel.awaitTermination(5, TimeUnit.SECONDS);
353 } catch (Exception e) {
354 logger.warn("recycle channels timeout:", e);
355 }
356 }
357 }
358 }
359
360 public void close() {
361 for (ManagedChannel ch : connPool.values()) {
362 ch.shutdown();
363 }
364 connPool.clear();
365
366 if (recycler != null) {
367 recycler.shutdown();
368 }
369
370 if (certWatcher != null) {
371 certWatcher.close();
372 }
373 }
374 }