Skip to content

Commit

Permalink
Merge pull request #11 from jacksgt/url-encoding
Browse files Browse the repository at this point in the history
Fix URL encoding issue for filenames with special characters
  • Loading branch information
rhnvrm authored Feb 3, 2021
2 parents 0a01d1c + bf02e3b commit c344c52
Show file tree
Hide file tree
Showing 3 changed files with 178 additions and 9 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ err := s3.FileDelete(simples3.DeleteInput{

// You can also download the file.
file, _ := s3.FileDownload(simples3.DownloadInput{
Bucket: os.Getenv("AWS_S3_BUCKET"),
Bucket: AWSBucket,
ObjectKey: "test.txt",
})
data, _ := ioutil.ReadAll(file)
Expand All @@ -63,7 +63,7 @@ file.Close()
var time, _ = time.Parse(time.RFC1123, "Fri, 24 May 2013 00:00:00 GMT")

url := s.GeneratePresignedURL(PresignedInput{
Bucket: "examplebucket",
Bucket: AWSBucket,
ObjectKey: "test.txt",
Method: "GET",
Timestamp: time,
Expand Down
70 changes: 63 additions & 7 deletions simples3.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"bytes"
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"encoding/xml"
"errors"
Expand All @@ -15,8 +16,10 @@ import (
"io/ioutil"
"mime/multipart"
"net/http"
"regexp"
"strings"
"time"
"unicode/utf8"
)

const (
Expand Down Expand Up @@ -159,17 +162,23 @@ func (s3 *S3) getClient() *http.Client {
return s3.Client
}

func (s3 *S3) getURL(bucket string, args ...string) (uri string) {
// getURL constructs a URL for a given path, with multiple optional
// arguments as individual subfolders, based on the endpoint
// specified in s3 struct.
func (s3 *S3) getURL(path string, args ...string) (uri string) {
if len(args) > 0 {
path += "/" + strings.Join(args, "/")
}
// need to encode special characters in the path part of the URL
encodedPath := encodePath(path)

if len(s3.Endpoint) > 0 {
uri = s3.Endpoint + "/" + bucket
uri = s3.Endpoint + "/" + encodedPath
} else {
uri = fmt.Sprintf(s3.URIFormat, s3.Region, bucket)
uri = fmt.Sprintf(s3.URIFormat, s3.Region, encodedPath)
}

if len(args) > 0 {
uri = uri + "/" + strings.Join(args, "/")
}
return
return uri
}

// SetEndpoint can be used to the set a custom endpoint for
Expand All @@ -180,6 +189,11 @@ func (s3 *S3) SetEndpoint(uri string) *S3 {
if !strings.HasPrefix(uri, "http") {
uri = "https://" + uri
}

// make sure there is no trailing slash
if uri[len(uri)-1] == '/' {
uri = uri[:len(uri)-1]
}
s3.Endpoint = uri
}
return s3
Expand Down Expand Up @@ -391,3 +405,45 @@ func (s3 *S3) FileDelete(u DeleteInput) error {

return nil
}

// if object matches reserved string, no need to encode them
var reservedObjectNames = regexp.MustCompile("^[a-zA-Z0-9-_.~/]+$")

// encodePath encode the strings from UTF-8 byte representations to HTML hex escape sequences
//
// This is necessary since regular url.Parse() and url.Encode() functions do not support UTF-8
// non english characters cannot be parsed due to the nature in which url.Encode() is written
//
// This function on the other hand is a direct replacement for url.Encode() technique to support
// pretty much every UTF-8 character.
// adapted from https://github.com/minio/minio-go/blob/fe1f3855b146c1b6ce4199740d317e44cf9e85c2/pkg/s3utils/utils.go#L285
func encodePath(pathName string) string {
if reservedObjectNames.MatchString(pathName) {
return pathName
}
var encodedPathname strings.Builder
for _, s := range pathName {
if 'A' <= s && s <= 'Z' || 'a' <= s && s <= 'z' || '0' <= s && s <= '9' { // §2.3 Unreserved characters (mark)
encodedPathname.WriteRune(s)
continue
}
switch s {
case '-', '_', '.', '~', '/': // §2.3 Unreserved characters (mark)
encodedPathname.WriteRune(s)
continue
default:
len := utf8.RuneLen(s)
if len < 0 {
// if utf8 cannot convert, return the same string as is
return pathName
}
u := make([]byte, len)
utf8.EncodeRune(u, s)
for _, r := range u {
hex := hex.EncodeToString([]byte{r})
encodedPathname.WriteString("%" + strings.ToUpper(hex))
}
}
}
return encodedPathname.String()
}
113 changes: 113 additions & 0 deletions simples3_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@ func TestS3_FileUpload(t *testing.T) {
return
}
defer testTxt.Close()
// Note: cannot re-use the same file descriptor due to seeking!
testTxtSpecialChars, err := os.Open("testdata/test.txt")
if err != nil {
return
}
defer testTxtSpecialChars.Close()
testPng, err := os.Open("testdata/avatar.png")
if err != nil {
return
Expand Down Expand Up @@ -76,6 +82,25 @@ func TestS3_FileUpload(t *testing.T) {
},
wantErr: false,
},
{
name: "Upload special filename txt",
fields: tConfig{
AccessKey: os.Getenv("AWS_S3_ACCESS_KEY"),
SecretKey: os.Getenv("AWS_S3_SECRET_KEY"),
Endpoint: os.Getenv("AWS_S3_ENDPOINT"),
Region: os.Getenv("AWS_S3_REGION"),
},
args: args{
UploadInput{
Bucket: os.Getenv("AWS_S3_BUCKET"),
ObjectKey: "xyz/example file%with$special&chars(1)?.txt",
ContentType: "text/plain",
FileName: "example file%with$special&chars(1)?.txt",
Body: testTxtSpecialChars,
},
},
wantErr: false,
},
}
for _, testcase := range tests {
tt := testcase
Expand Down Expand Up @@ -159,6 +184,23 @@ func TestS3_FileDownload(t *testing.T) {
wantErr: false,
wantResponse: testPngData,
},
{
name: "txt-special-filename",
fields: tConfig{
AccessKey: os.Getenv("AWS_S3_ACCESS_KEY"),
SecretKey: os.Getenv("AWS_S3_SECRET_KEY"),
Endpoint: os.Getenv("AWS_S3_ENDPOINT"),
Region: os.Getenv("AWS_S3_REGION"),
},
args: args{
u: DownloadInput{
Bucket: os.Getenv("AWS_S3_BUCKET"),
ObjectKey: "xyz/example file%with$special&chars(1)?.txt",
},
},
wantErr: false,
wantResponse: testTxtData,
},
}

for _, testcase := range tests {
Expand Down Expand Up @@ -227,6 +269,22 @@ func TestS3_FileDelete(t *testing.T) {
},
wantErr: false,
},
{
name: "Delete special filename txt",
fields: tConfig{
AccessKey: os.Getenv("AWS_S3_ACCESS_KEY"),
SecretKey: os.Getenv("AWS_S3_SECRET_KEY"),
Endpoint: os.Getenv("AWS_S3_ENDPOINT"),
Region: os.Getenv("AWS_S3_REGION"),
},
args: args{
DeleteInput{
Bucket: os.Getenv("AWS_S3_BUCKET"),
ObjectKey: "xyz/example file%with$special&chars(1)?.txt",
},
},
wantErr: false,
},
}
for _, testcase := range tests {
tt := testcase
Expand Down Expand Up @@ -293,4 +351,59 @@ func TestCustomEndpoint(t *testing.T) {
if s3.getURL("bucket3") != "https://example.com/bucket3" {
t.Errorf("S3.SetEndpoint() got = %v", s3.Endpoint)
}

// try with trailing slash
s3.SetEndpoint("https://example.com/foobar/")
if s3.getURL("bucket4") != "https://example.com/foobar/bucket4" {
t.Errorf("S3.SetEndpoint() got = %v", s3.Endpoint)
}
}

func TestGetURL(t *testing.T) {
s3 := New("us-east-1", "AccessKey", "SuperSecretKey")

type args struct {
bucket string
params []string
}

tests := []struct {
name string
args args
want string
}{
{
name: "getURL: basic test",
args: args{
bucket: "xyz",
},
want: "https://s3.us-east-1.amazonaws.com/xyz",
},
{
name: "getURL: multiple parameters",
args: args{
bucket: "xyz",
params: []string{"hello", "world"},
},
want: "https://s3.us-east-1.amazonaws.com/xyz/hello/world",
},
{
name: "getURL: special characters",
args: args{
bucket: "xyz",
params: []string{"hello, world!", "#!@$%^&*(1).txt"},
},
want: "https://s3.us-east-1.amazonaws.com/xyz/hello%2C%20world%21/%23%21%40%24%25%5E%26%2A%281%29.txt",
},
}

for _, testcase := range tests {
tt := testcase
t.Run(tt.name, func(t *testing.T) {
url := s3.getURL(tt.args.bucket, tt.args.params...)
if url != tt.want {
t.Errorf("S3.getURL() got = %v, want %v", url, tt.want)
}
})
}
}

0 comments on commit c344c52

Please sign in to comment.