diff --git a/cbsignal b/cbsignal index 1907cbe..45397c8 100755 Binary files a/cbsignal and b/cbsignal differ diff --git a/config.yaml b/config.yaml index a211c7e..8cb894a 100644 --- a/config.yaml +++ b/config.yaml @@ -1,4 +1,6 @@ +version: 1.1.0 + log: writers: file # 输出位置,有两个可选项 —— file 和 stdout。选择 file 会将日志记录到 logger_file 指定的日志文件中,选择 stdout 会将日志输出到标准输出,当然也可以两者同时选择 logger_level: WARN # 日志级别,DEBUG、INFO、WARN、ERROR、FATAL @@ -21,7 +23,19 @@ compression: level: 2 # -1: DefaultCompression 1: BestSpeed <---> 9: BestCompression activationRatio: 100 # 0 <---> 100 (%) -version: 1.0.0 +# 允许连接的域名白名单(不包含端口) Whitelist of domain names allowed to connect(Port not included) +allow_list: +# - "localhost" + +# 不允许连接的域名黑名单(不包含端口) Domain name blacklist not allowed to connect(Port not included) +block_list: +# - "localhost" + + + + + + diff --git a/main.go b/main.go index f536b5b..26fc36f 100644 --- a/main.go +++ b/main.go @@ -12,6 +12,7 @@ import ( "github.com/spf13/pflag" "github.com/spf13/viper" "net/http" + "net/url" "os" "os/signal" "strings" @@ -22,6 +23,11 @@ var ( cfg = pflag.StringP("config", "c", "", "Config file path.") newline = []byte{'\n'} space = []byte{' '} + + allowMap = make(map[string]bool) // allow list of domain + useAllowList = false + blockMap = make(map[string]bool) // block list of domain + useBlockList = false ) func init() { @@ -56,11 +62,32 @@ func init() { if err := log.InitWithConfig(&passLagerCfg); err != nil { fmt.Errorf("Initialize logger %s", err) } + + // Initialize allow list and block list + allowList := viper.GetStringSlice("allow_list") + if len(allowList) > 0 { + useAllowList = true + for _, v := range allowList { + allowMap[v] = true + } + } + blockList := viper.GetStringSlice("block_list") + if len(blockList) > 0 { + useBlockList = true + for _, v := range blockList{ + blockMap[v] = true + } + } + if useBlockList && useAllowList { + panic("Do not use allowList and blockList at the same time") + } + } //var EventsCollector *eventsCollector func wsHandler(w http.ResponseWriter, r *http.Request) { + // Upgrade connection //log.Printf("UpgradeHTTP") conn, _, _, err := ws.UpgradeHTTP(r, w) @@ -68,6 +95,23 @@ func wsHandler(w http.ResponseWriter, r *http.Request) { return } + origin := r.Header.Get("Origin") + if origin != "" { + domain := GetDomain(origin) + log.Debugf("domain: %s", domain) + if useAllowList && !allowMap[domain] { + log.Infof("domian %s is out of allowList", domain) + //wsutil.WriteServerMessage(conn, ws.OpClose, nil) + conn.Close() + return + } else if useBlockList && blockMap[domain] { + log.Infof("domian %s is in blockList", domain) + //wsutil.WriteServerMessage(conn, ws.OpClose, nil) + conn.Close() + return + } + } + r.ParseForm() id := r.Form.Get("id") //log.Printf("id %s", id) @@ -209,4 +253,14 @@ func Exists(path string) bool { return false } return true +} + +// 获取域名(不包含端口) +func GetDomain(uri string) string { + parsed, err := url.Parse(uri) + if err != nil { + return "" + } + a := strings.Split(parsed.Host, ":") + return a[0] } \ No newline at end of file