From 78bc74c9283e5c915dc8e79b5637b57d4984e1b5 Mon Sep 17 00:00:00 2001 From: Piotr Biernat Date: Sat, 5 Sep 2020 00:11:57 +0200 Subject: [PATCH] Added cors and cache handlers as middleware --- main.go | 16 +++++++---- server.go | 84 ++++++++++++++++++++++++++++++++++++++++++++----------- 2 files changed, 78 insertions(+), 22 deletions(-) diff --git a/main.go b/main.go index dba5c1d..dd7f451 100644 --- a/main.go +++ b/main.go @@ -4,18 +4,22 @@ import ( "gopkg.in/alecthomas/kingpin.v2" ) -var ( - dir = kingpin.Flag("directory", "Path to dir which has to be served.").Required().Short('d').String() - port = kingpin.Flag("port", "Port to run at").Default("8080").String() -) - func main() { + var ( + dir = kingpin.Flag("directory", "Path to dir which has to be served.").Required().Short('d').String() + port = kingpin.Flag("port", "Port to run at").Default("8080").Short('p').String() + cors = kingpin.Flag("cors", "Add CORS headers").Short('c').StringMap() + cache = kingpin.Flag("cache", "Add Cache headers").StringMap() + ) + kingpin.Version("0.5") kingpin.Parse() s := Server{ port: ":" + *port, dirPath: *dir, + cors: *cors, + cache: *cache, } - s.Serve() + s.serve() } diff --git a/server.go b/server.go index fcde674..210f81b 100644 --- a/server.go +++ b/server.go @@ -12,43 +12,95 @@ import ( type Server struct { port string dirPath string + cors corsHeaders + cache cacheHeaders } -// Serve function -func (s *Server) Serve() { - s.initHandler() +type cacheHeaders map[string]string +type corsHeaders map[string]string + +var defCacheHdrs = cacheHeaders{ + "control": "no-cache", +} + +var defCorsHdrs = corsHeaders{ + "origin": "*", + "methods": "*", + "headers": "*", +} + +func cacheHandler(hdrs cacheHeaders, next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + for k, v := range hdrs { + w.Header().Add("Cache-"+strings.Title(k), v) + } + + next.ServeHTTP(w, r) + } + return http.HandlerFunc(fn) +} + +func corsHandler(cors corsHeaders, next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + for k, v := range cors { + w.Header().Add("Access-Control-Allow-"+strings.Title(k), v) + } + + next.ServeHTTP(w, r) + } + return http.HandlerFunc(fn) +} + +func (s *Server) serve() { + s.parseArgs() + + mux := http.NewServeMux() + mux.Handle("/", cacheHandler(s.cache, corsHandler(s.cors, s.handle()))) log.Print("Listening on", s.port) - log.Fatal(http.ListenAndServe(s.port, nil)) + log.Fatal(http.ListenAndServe(s.port, mux)) } -func (s *Server) initHandler() { - http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { +func (s *Server) handle() http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { path := s.dirPath + r.URL.Path file, err := os.Lstat(path) if err != nil { - log.Println("File " + path + " not exists") + log.Println("File asd " + path + " not exists") http.NotFound(w, r) return } switch mode := file.Mode(); { - case mode.IsRegular(): - serveFile(w, r, path) case mode.IsDir(): - serveDir(w, r, path) + handleDir(path, w, r) + case mode.IsRegular(): + handleFile(path, w, r) } - }) + } - log.Printf("Serving %v directory\n", s.dirPath) + return http.HandlerFunc(fn) } -func serveDir(w http.ResponseWriter, r *http.Request, path string) { +func (s *Server) parseArgs() { + for k, v := range defCorsHdrs { // parse cors into headers + if s.cors[k] == "" { // if empty + s.cors[k] = v + } + } + + for k, v := range defCacheHdrs { + if s.cache[k] == "" { // if empty + s.cache[k] = v + } + } +} + +func handleDir(path string, w http.ResponseWriter, r *http.Request) { log.Println("Serving " + path + " dir...") dirList, err := ioutil.ReadDir(path) if err != nil { log.Fatalln(err) - return } var outputDirList []string @@ -56,12 +108,12 @@ func serveDir(w http.ResponseWriter, r *http.Request, path string) { fileURL := strings.TrimRight(r.RequestURI, "/") + "/" + file.Name() outputDirList = append(outputDirList, ""+file.Name()+"") } - response := strings.Join(outputDirList, "
") w.Write([]byte(string(response))) + log.Println("Served dir") } -func serveFile(w http.ResponseWriter, r *http.Request, path string) { +func handleFile(path string, w http.ResponseWriter, r *http.Request) { log.Println("Serving " + path + " file...") http.ServeFile(w, r, path)