Skip to content

Commit

Permalink
add allow_list and block_list of domain
Browse files Browse the repository at this point in the history
  • Loading branch information
snowinszu committed Sep 6, 2020
1 parent d726c8f commit a9a03d0
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 1 deletion.
Binary file modified cbsignal
Binary file not shown.
16 changes: 15 additions & 1 deletion config.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@

version: 1.1.0

log:
writers: file # 输出位置,有两个可选项 —— file 和 stdout。选择 file 会将日志记录到 logger_file 指定的日志文件中,选择 stdout 会将日志输出到标准输出,当然也可以两者同时选择
logger_level: WARN # 日志级别,DEBUG、INFO、WARN、ERROR、FATAL
Expand All @@ -21,7 +23,19 @@ compression:
level: 2 # -1: DefaultCompression 1: BestSpeed <---> 9: BestCompression
activationRatio: 100 # 0 <---> 100 (%)

version: 1.0.0
# 允许连接的域名白名单(不包含端口) Whitelist of domain names allowed to connect(Port not included)
allow_list:
# - "localhost"

# 不允许连接的域名黑名单(不包含端口) Domain name blacklist not allowed to connect(Port not included)
block_list:
# - "localhost"









54 changes: 54 additions & 0 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"github.com/spf13/pflag"
"github.com/spf13/viper"
"net/http"
"net/url"
"os"
"os/signal"
"strings"
Expand All @@ -22,6 +23,11 @@ var (
cfg = pflag.StringP("config", "c", "", "Config file path.")
newline = []byte{'\n'}
space = []byte{' '}

allowMap = make(map[string]bool) // allow list of domain
useAllowList = false
blockMap = make(map[string]bool) // block list of domain
useBlockList = false
)

func init() {
Expand Down Expand Up @@ -56,18 +62,56 @@ func init() {
if err := log.InitWithConfig(&passLagerCfg); err != nil {
fmt.Errorf("Initialize logger %s", err)
}

// Initialize allow list and block list
allowList := viper.GetStringSlice("allow_list")
if len(allowList) > 0 {
useAllowList = true
for _, v := range allowList {
allowMap[v] = true
}
}
blockList := viper.GetStringSlice("block_list")
if len(blockList) > 0 {
useBlockList = true
for _, v := range blockList{
blockMap[v] = true
}
}
if useBlockList && useAllowList {
panic("Do not use allowList and blockList at the same time")
}

}

//var EventsCollector *eventsCollector

func wsHandler(w http.ResponseWriter, r *http.Request) {

// Upgrade connection
//log.Printf("UpgradeHTTP")
conn, _, _, err := ws.UpgradeHTTP(r, w)
if err != nil {
return
}

origin := r.Header.Get("Origin")
if origin != "" {
domain := GetDomain(origin)
log.Debugf("domain: %s", domain)
if useAllowList && !allowMap[domain] {
log.Infof("domian %s is out of allowList", domain)
//wsutil.WriteServerMessage(conn, ws.OpClose, nil)
conn.Close()
return
} else if useBlockList && blockMap[domain] {
log.Infof("domian %s is in blockList", domain)
//wsutil.WriteServerMessage(conn, ws.OpClose, nil)
conn.Close()
return
}
}

r.ParseForm()
id := r.Form.Get("id")
//log.Printf("id %s", id)
Expand Down Expand Up @@ -209,4 +253,14 @@ func Exists(path string) bool {
return false
}
return true
}

// 获取域名(不包含端口)
func GetDomain(uri string) string {
parsed, err := url.Parse(uri)
if err != nil {
return ""
}
a := strings.Split(parsed.Host, ":")
return a[0]
}

0 comments on commit a9a03d0

Please sign in to comment.