View Javadoc
1   /*
2    * Copyright 2021 TiKV Project Authors.
3    *
4    * Licensed under the Apache License, Version 2.0 (the "License");
5    * you may not use this file except in compliance with the License.
6    * You may obtain a copy of the License at
7    *
8    * http://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS,
12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13   * See the License for the specific language governing permissions and
14   * limitations under the License.
15   *
16   */
17  
18  package org.tikv.common.util;
19  
20  import com.google.common.annotations.VisibleForTesting;
21  import com.google.common.collect.ImmutableList;
22  import io.grpc.ManagedChannel;
23  import io.grpc.netty.GrpcSslContexts;
24  import io.grpc.netty.NettyChannelBuilder;
25  import io.netty.handler.ssl.SslContext;
26  import io.netty.handler.ssl.SslContextBuilder;
27  import java.io.File;
28  import java.io.FileInputStream;
29  import java.net.URI;
30  import java.security.KeyStore;
31  import java.util.ArrayList;
32  import java.util.List;
33  import java.util.concurrent.ConcurrentHashMap;
34  import java.util.concurrent.Executors;
35  import java.util.concurrent.ScheduledExecutorService;
36  import java.util.concurrent.TimeUnit;
37  import java.util.concurrent.atomic.AtomicReference;
38  import java.util.concurrent.locks.ReadWriteLock;
39  import java.util.concurrent.locks.ReentrantReadWriteLock;
40  import javax.net.ssl.KeyManagerFactory;
41  import javax.net.ssl.SSLException;
42  import javax.net.ssl.TrustManagerFactory;
43  import org.slf4j.Logger;
44  import org.slf4j.LoggerFactory;
45  import org.tikv.common.HostMapping;
46  import org.tikv.common.pd.PDUtils;
47  
48  public class ChannelFactory implements AutoCloseable {
49    private static final Logger logger = LoggerFactory.getLogger(ChannelFactory.class);
50    private static final String PUB_KEY_INFRA = "PKIX";
51  
52    // After `connRecycleTime` seconds have elapsed, the old channels are forcibly shut down so
53    // that they stop using the stale SSL context and cannot linger as leaked channels.
54    private final long connRecycleTime;
55    private final int maxFrameSize;
56    private final int keepaliveTime;
57    private final int keepaliveTimeout;
58    private final int idleTimeout;
59    private final CertContext certContext;
60    private final CertWatcher certWatcher;
61  
62    @VisibleForTesting
63    public final ConcurrentHashMap<String, ManagedChannel> connPool = new ConcurrentHashMap<>();
64  
65    private final AtomicReference<SslContextBuilder> sslContextBuilder = new AtomicReference<>();
66  
67    private final ScheduledExecutorService recycler;
68  
69    private final ReadWriteLock lock = new ReentrantReadWriteLock();
70  
71    @VisibleForTesting
72    public static class CertWatcher implements AutoCloseable {
73      private static final Logger logger = LoggerFactory.getLogger(CertWatcher.class);
74      private final List<File> targets;
75      private final List<Long> lastReload = new ArrayList<>();
76      private final ScheduledExecutorService executorService =
77          Executors.newSingleThreadScheduledExecutor();
78      private final Runnable onChange;
79  
80      public CertWatcher(long pollInterval, List<File> targets, Runnable onChange) {
81        this.targets = targets;
82        this.onChange = onChange;
83  
84        for (File ignored : targets) {
85          lastReload.add(0L);
86        }
87  
88        executorService.scheduleAtFixedRate(
89            this::tryReload, pollInterval, pollInterval, TimeUnit.SECONDS);
90      }
91  
92      // If any execution of the task encounters an exception, subsequent executions are suppressed.
93      private void tryReload() {
94        // Add exception handling to avoid schedule stop.
95        try {
96          if (needReload()) {
97            onChange.run();
98          }
99        } catch (Exception e) {
100         logger.error("Failed to reload cert!", e);
101       }
102     }
103 
104     private boolean needReload() {
105       boolean needReload = false;
106       // Check all the modification of the `targets`.
107       // If one of them changed, means to need reload.
108       for (int i = 0; i < targets.size(); i++) {
109         try {
110           long lastModified = targets.get(i).lastModified();
111           if (lastModified != lastReload.get(i)) {
112             lastReload.set(i, lastModified);
113             logger.warn("detected ssl context changes: {}", targets.get(i));
114             needReload = true;
115           }
116         } catch (Exception e) {
117           logger.error("fail to check the status of ssl context files", e);
118         }
119       }
120       return needReload;
121     }
122 
123     @Override
124     public void close() {
125       executorService.shutdown();
126     }
127   }
128 
129   @VisibleForTesting
130   public abstract static class CertContext {
131     public abstract SslContextBuilder createSslContextBuilder();
132   }
133 
134   public static class JksContext extends CertContext {
135     private final String keyPath;
136     private final String keyPassword;
137     private final String trustPath;
138     private final String trustPassword;
139 
140     public JksContext(String keyPath, String keyPassword, String trustPath, String trustPassword) {
141       this.keyPath = keyPath;
142       this.keyPassword = keyPassword;
143       this.trustPath = trustPath;
144       this.trustPassword = trustPassword;
145     }
146 
147     @Override
148     public SslContextBuilder createSslContextBuilder() {
149       SslContextBuilder builder = GrpcSslContexts.forClient();
150       try {
151         if (keyPath != null && keyPassword != null) {
152           KeyStore keyStore = KeyStore.getInstance("JKS");
153           keyStore.load(new FileInputStream(keyPath), keyPassword.toCharArray());
154           KeyManagerFactory keyManagerFactory =
155               KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm());
156           keyManagerFactory.init(keyStore, keyPassword.toCharArray());
157           builder.keyManager(keyManagerFactory);
158         }
159         if (trustPath != null && trustPassword != null) {
160           KeyStore trustStore = KeyStore.getInstance("JKS");
161           trustStore.load(new FileInputStream(trustPath), trustPassword.toCharArray());
162           TrustManagerFactory trustManagerFactory = TrustManagerFactory.getInstance(PUB_KEY_INFRA);
163           trustManagerFactory.init(trustStore);
164           builder.trustManager(trustManagerFactory);
165         }
166       } catch (Exception e) {
167         logger.error("JKS SSL context builder failed!", e);
168         throw new IllegalArgumentException(e);
169       }
170       return builder;
171     }
172   }
173 
174   @VisibleForTesting
175   public static class OpenSslContext extends CertContext {
176     private final String trustPath;
177     private final String chainPath;
178     private final String keyPath;
179 
180     public OpenSslContext(String trustPath, String chainPath, String keyPath) {
181       this.trustPath = trustPath;
182       this.chainPath = chainPath;
183       this.keyPath = keyPath;
184     }
185 
186     @Override
187     public SslContextBuilder createSslContextBuilder() {
188       SslContextBuilder builder = GrpcSslContexts.forClient();
189       try {
190         if (trustPath != null) {
191           builder.trustManager(new File(trustPath));
192         }
193         if (chainPath != null && keyPath != null) {
194           builder.keyManager(new File(chainPath), new File(keyPath));
195         }
196       } catch (Exception e) {
197         logger.error("Failed to create ssl context builder", e);
198         throw new IllegalArgumentException(e);
199       }
200       return builder;
201     }
202   }
203 
204   public ChannelFactory(
205       int maxFrameSize, int keepaliveTime, int keepaliveTimeout, int idleTimeout) {
206     this.maxFrameSize = maxFrameSize;
207     this.keepaliveTime = keepaliveTime;
208     this.keepaliveTimeout = keepaliveTimeout;
209     this.idleTimeout = idleTimeout;
210     this.certWatcher = null;
211     this.certContext = null;
212     this.recycler = null;
213     this.connRecycleTime = 0;
214   }
215 
216   public ChannelFactory(
217       int maxFrameSize,
218       int keepaliveTime,
219       int keepaliveTimeout,
220       int idleTimeout,
221       long connRecycleTime,
222       long certReloadInterval,
223       String trustCertCollectionFilePath,
224       String keyCertChainFilePath,
225       String keyFilePath) {
226     this.maxFrameSize = maxFrameSize;
227     this.keepaliveTime = keepaliveTime;
228     this.keepaliveTimeout = keepaliveTimeout;
229     this.idleTimeout = idleTimeout;
230     this.connRecycleTime = connRecycleTime;
231     this.certContext =
232         new OpenSslContext(trustCertCollectionFilePath, keyCertChainFilePath, keyFilePath);
233     this.recycler = Executors.newSingleThreadScheduledExecutor();
234 
235     File trustCert = new File(trustCertCollectionFilePath);
236     File keyCert = new File(keyCertChainFilePath);
237     File key = new File(keyFilePath);
238 
239     if (certReloadInterval > 0) {
240       onCertChange();
241       this.certWatcher =
242           new CertWatcher(
243               certReloadInterval, ImmutableList.of(trustCert, keyCert, key), this::onCertChange);
244     } else {
245       this.certWatcher = null;
246     }
247   }
248 
249   public ChannelFactory(
250       int maxFrameSize,
251       int keepaliveTime,
252       int keepaliveTimeout,
253       int idleTimeout,
254       long connRecycleTime,
255       long certReloadInterval,
256       String jksKeyPath,
257       String jksKeyPassword,
258       String jksTrustPath,
259       String jksTrustPassword) {
260     this.maxFrameSize = maxFrameSize;
261     this.keepaliveTime = keepaliveTime;
262     this.keepaliveTimeout = keepaliveTimeout;
263     this.idleTimeout = idleTimeout;
264     this.connRecycleTime = connRecycleTime;
265     this.certContext = new JksContext(jksKeyPath, jksKeyPassword, jksTrustPath, jksTrustPassword);
266     this.recycler = Executors.newSingleThreadScheduledExecutor();
267 
268     File jksKey = new File(jksKeyPath);
269     File jksTrust = new File(jksTrustPath);
270     if (certReloadInterval > 0) {
271       onCertChange();
272       this.certWatcher =
273           new CertWatcher(
274               certReloadInterval, ImmutableList.of(jksKey, jksTrust), this::onCertChange);
275     } else {
276       this.certWatcher = null;
277     }
278   }
279 
280   private void onCertChange() {
281     try {
282       SslContextBuilder newBuilder = certContext.createSslContextBuilder();
283       lock.writeLock().lock();
284       sslContextBuilder.set(newBuilder);
285 
286       List<ManagedChannel> pending = new ArrayList<>(connPool.values());
287       recycler.schedule(() -> cleanExpiredConn(pending), connRecycleTime, TimeUnit.SECONDS);
288 
289       connPool.clear();
290     } finally {
291       lock.writeLock().unlock();
292     }
293   }
294 
295   public ManagedChannel getChannel(String address, HostMapping mapping) {
296     if (certContext != null) {
297       try {
298         lock.readLock().lock();
299         return connPool.computeIfAbsent(
300             address, key -> createChannel(sslContextBuilder.get(), address, mapping));
301       } finally {
302         lock.readLock().unlock();
303       }
304     }
305     return connPool.computeIfAbsent(address, key -> createChannel(null, address, mapping));
306   }
307 
308   private ManagedChannel createChannel(
309       SslContextBuilder sslContextBuilder, String address, HostMapping mapping) {
310     URI uri, mapped;
311     try {
312       uri = PDUtils.addrToUri(address);
313     } catch (Exception e) {
314       throw new IllegalArgumentException("failed to form address " + address, e);
315     }
316     try {
317       mapped = mapping.getMappedURI(uri);
318     } catch (Exception e) {
319       throw new IllegalArgumentException("failed to get mapped address " + uri, e);
320     }
321 
322     // Channel should be lazy without actual connection until first call
323     // So a coarse grain lock is ok here
324     NettyChannelBuilder builder =
325         NettyChannelBuilder.forAddress(mapped.getHost(), mapped.getPort())
326             .maxInboundMessageSize(maxFrameSize)
327             .keepAliveTime(keepaliveTime, TimeUnit.SECONDS)
328             .keepAliveTimeout(keepaliveTimeout, TimeUnit.SECONDS)
329             .keepAliveWithoutCalls(true)
330             .idleTimeout(idleTimeout, TimeUnit.SECONDS);
331 
332     if (sslContextBuilder == null) {
333       return builder.usePlaintext().build();
334     } else {
335       SslContext sslContext;
336       try {
337         sslContext = sslContextBuilder.build();
338       } catch (SSLException e) {
339         logger.error("create ssl context failed!", e);
340         throw new IllegalArgumentException(e);
341       }
342       return builder.sslContext(sslContext).build();
343     }
344   }
345 
346   private void cleanExpiredConn(List<ManagedChannel> pending) {
347     for (ManagedChannel channel : pending) {
348       logger.info("cleaning expire channels");
349       channel.shutdownNow();
350       while (!channel.isShutdown()) {
351         try {
352           channel.awaitTermination(5, TimeUnit.SECONDS);
353         } catch (Exception e) {
354           logger.warn("recycle channels timeout:", e);
355         }
356       }
357     }
358   }
359 
360   public void close() {
361     for (ManagedChannel ch : connPool.values()) {
362       ch.shutdown();
363     }
364     connPool.clear();
365 
366     if (recycler != null) {
367       recycler.shutdown();
368     }
369 
370     if (certWatcher != null) {
371       certWatcher.close();
372     }
373   }
374 }