Compare commits

...

14 Commits

Author SHA1 Message Date
Mykhailo Lohachov
0c5bd5d810
Merge 359e7b8597 into 4b1c0526a7 2024-03-20 23:20:24 +09:00
Hajime Hoshi
4b1c0526a7 exp/textinput: add Field
Closes #2827
2024-03-20 23:19:32 +09:00
aoyako
359e7b8597 remove comment 2024-03-02 15:41:30 +09:00
aoyako
d1b9216ee1 update tests for right shift 2024-03-02 15:40:38 +09:00
aoyako
f02e9fd4d0 add shift type checks 2024-03-02 15:32:31 +09:00
aoyako
f44640778d add basic checks 2024-02-28 20:27:26 +09:00
aoyako
66a4b20bda remove return type for deduced int 2024-02-27 19:39:14 +09:00
aoyako
7f9d997175 add return type for type resolving 2024-02-27 19:29:39 +09:00
aoyako
c90d02f8d4 add: float->int cast tests 2024-02-26 21:06:28 +09:00
aoyako
2b7d20e7da fix: remove unnecessary branch 2024-02-26 18:17:41 +09:00
aoyako
d69bb04a56 add support for shift + assign 2024-02-26 18:00:52 +09:00
aoyako
5f61cf00e5 extend tests with right-shift op 2024-02-26 17:03:19 +09:00
aoyako
7f01f98200 add tests for binop shift 2024-02-26 17:02:02 +09:00
aoyako
fe887e2565 add typechecks for bitshifts ops 2024-02-26 16:14:55 +09:00
10 changed files with 887 additions and 179 deletions

View File

@ -40,16 +40,9 @@ const (
) )
type TextField struct { type TextField struct {
bounds image.Rectangle bounds image.Rectangle
multilines bool multilines bool
text string field textinput.Field
selectionStart int
selectionEnd int
focused bool
ch chan textinput.State
end func()
state textinput.State
} }
func NewTextField(bounds image.Rectangle, multilines bool) *TextField { func NewTextField(bounds image.Rectangle, multilines bool) *TextField {
@ -63,17 +56,13 @@ func (t *TextField) Contains(x, y int) bool {
return image.Pt(x, y).In(t.bounds) return image.Pt(x, y).In(t.bounds)
} }
func (t *TextField) SetSelectionStartByCursorPosition(x, y int) (bool, error) { func (t *TextField) SetSelectionStartByCursorPosition(x, y int) bool {
if err := t.cleanUp(); err != nil {
return false, err
}
idx, ok := t.textIndexByCursorPosition(x, y) idx, ok := t.textIndexByCursorPosition(x, y)
if !ok { if !ok {
return false, nil return false
} }
t.selectionStart = idx t.field.SetSelection(idx, idx)
t.selectionEnd = idx return true
return true, nil
} }
func (t *TextField) textIndexByCursorPosition(x, y int) (int, bool) { func (t *TextField) textIndexByCursorPosition(x, y int) (int, bool) {
@ -97,20 +86,21 @@ func (t *TextField) textIndexByCursorPosition(x, y int) (int, bool) {
var nlCount int var nlCount int
var lineStart int var lineStart int
var prevAdvance float64 var prevAdvance float64
for i, r := range t.text { txt := t.field.Text()
for i, r := range txt {
var x0, x1 int var x0, x1 int
currentAdvance := text.Advance(t.text[lineStart:i], fontFace) currentAdvance := text.Advance(txt[lineStart:i], fontFace)
if lineStart < i { if lineStart < i {
x0 = int((prevAdvance + currentAdvance) / 2) x0 = int((prevAdvance + currentAdvance) / 2)
} }
if r == '\n' { if r == '\n' {
x1 = int(math.MaxInt32) x1 = int(math.MaxInt32)
} else if i < len(t.text) { } else if i < len(txt) {
nextI := i + 1 nextI := i + 1
for !utf8.ValidString(t.text[i:nextI]) { for !utf8.ValidString(txt[i:nextI]) {
nextI++ nextI++
} }
nextAdvance := text.Advance(t.text[lineStart:nextI], fontFace) nextAdvance := text.Advance(txt[lineStart:nextI], fontFace)
x1 = int((currentAdvance + nextAdvance) / 2) x1 = int((currentAdvance + nextAdvance) / 2)
} else { } else {
x1 = int(currentAdvance) x1 = int(currentAdvance)
@ -127,146 +117,86 @@ func (t *TextField) textIndexByCursorPosition(x, y int) (int, bool) {
} }
} }
return len(t.text), true return len(txt), true
} }
func (t *TextField) Focus() { func (t *TextField) Focus() {
t.focused = true t.field.Focus()
} }
func (t *TextField) Blur() { func (t *TextField) Blur() {
t.focused = false t.field.Blur()
}
func (t *TextField) cleanUp() error {
if t.ch != nil {
select {
case state, ok := <-t.ch:
if state.Error != nil {
return state.Error
}
if ok && state.Committed {
t.text = t.text[:t.selectionStart] + state.Text + t.text[t.selectionEnd:]
t.selectionStart += len(state.Text)
t.selectionEnd = t.selectionStart
t.state = textinput.State{}
}
t.state = state
default:
break
}
}
if t.end != nil {
t.end()
t.ch = nil
t.end = nil
t.state = textinput.State{}
}
return nil
} }
func (t *TextField) Update() error { func (t *TextField) Update() error {
if !t.focused { if !t.field.IsFocused() {
// If the text field still has a session, read the last state and process it just in case.
if err := t.cleanUp(); err != nil {
return err
}
return nil return nil
} }
var processed bool x, y := t.bounds.Min.X, t.bounds.Min.Y
cx, cy := t.cursorPos()
// Text inputting can happen multiple times in one tick (1/60[s] by default). px, py := textFieldPadding()
// Handle all of them. x += cx + px
for { y += cy + py + int(fontFace.Metrics().HAscent)
if t.ch == nil { handled, err := t.field.HandleInput(x, y)
x, y := t.bounds.Min.X, t.bounds.Min.Y if err != nil {
cx, cy := t.cursorPos() return err
px, py := textFieldPadding()
x += cx + px
y += cy + py + int(fontFace.Metrics().HAscent)
t.ch, t.end = textinput.Start(x, y)
// Start returns nil for non-supported envrionments.
if t.ch == nil {
return nil
}
}
readchar:
for {
select {
case state, ok := <-t.ch:
if state.Error != nil {
return state.Error
}
processed = true
if !ok {
t.ch = nil
t.end = nil
t.state = textinput.State{}
break readchar
}
if state.Committed {
t.text = t.text[:t.selectionStart] + state.Text + t.text[t.selectionEnd:]
t.selectionStart += len(state.Text)
t.selectionEnd = t.selectionStart
t.state = textinput.State{}
continue
}
t.state = state
default:
break readchar
}
}
if t.ch == nil {
continue
}
break
} }
if handled {
if processed {
return nil return nil
} }
switch { switch {
case inpututil.IsKeyJustPressed(ebiten.KeyEnter): case inpututil.IsKeyJustPressed(ebiten.KeyEnter):
if t.multilines { if t.multilines {
t.text = t.text[:t.selectionStart] + "\n" + t.text[t.selectionEnd:] text := t.field.Text()
t.selectionStart += 1 selectionStart, selectionEnd := t.field.Selection()
t.selectionEnd = t.selectionStart text = text[:selectionStart] + "\n" + text[selectionEnd:]
selectionStart += len("\n")
selectionEnd = selectionStart
t.field.SetTextAndSelection(text, selectionStart, selectionEnd)
} }
case inpututil.IsKeyJustPressed(ebiten.KeyBackspace): case inpututil.IsKeyJustPressed(ebiten.KeyBackspace):
if t.selectionStart > 0 { text := t.field.Text()
selectionStart, selectionEnd := t.field.Selection()
if selectionStart != selectionEnd {
text = text[:selectionStart] + text[selectionEnd:]
} else if selectionStart > 0 {
// TODO: Remove a grapheme instead of a code point. // TODO: Remove a grapheme instead of a code point.
_, l := utf8.DecodeLastRuneInString(t.text[:t.selectionStart]) _, l := utf8.DecodeLastRuneInString(text[:selectionStart])
t.text = t.text[:t.selectionStart-l] + t.text[t.selectionEnd:] text = text[:selectionStart-l] + text[selectionEnd:]
t.selectionStart -= l selectionStart -= l
} }
t.selectionEnd = t.selectionStart selectionEnd = selectionStart
t.field.SetTextAndSelection(text, selectionStart, selectionEnd)
case inpututil.IsKeyJustPressed(ebiten.KeyLeft): case inpututil.IsKeyJustPressed(ebiten.KeyLeft):
if t.selectionStart > 0 { text := t.field.Text()
selectionStart, _ := t.field.Selection()
if selectionStart > 0 {
// TODO: Remove a grapheme instead of a code point. // TODO: Remove a grapheme instead of a code point.
_, l := utf8.DecodeLastRuneInString(t.text[:t.selectionStart]) _, l := utf8.DecodeLastRuneInString(text[:selectionStart])
t.selectionStart -= l selectionStart -= l
} }
t.selectionEnd = t.selectionStart t.field.SetTextAndSelection(text, selectionStart, selectionStart)
case inpututil.IsKeyJustPressed(ebiten.KeyRight): case inpututil.IsKeyJustPressed(ebiten.KeyRight):
if t.selectionEnd < len(t.text) { text := t.field.Text()
_, selectionEnd := t.field.Selection()
if selectionEnd < len(text) {
// TODO: Remove a grapheme instead of a code point. // TODO: Remove a grapheme instead of a code point.
_, l := utf8.DecodeRuneInString(t.text[t.selectionEnd:]) _, l := utf8.DecodeRuneInString(text[selectionEnd:])
t.selectionEnd += l selectionEnd += l
} }
t.selectionStart = t.selectionEnd t.field.SetTextAndSelection(text, selectionEnd, selectionEnd)
} }
if !t.multilines { if !t.multilines {
orig := t.text orig := t.field.Text()
new := strings.ReplaceAll(orig, "\n", "") new := strings.ReplaceAll(orig, "\n", "")
if new != orig { if new != orig {
t.selectionStart -= strings.Count(orig[:t.selectionStart], "\n") selectionStart, selectionEnd := t.field.Selection()
t.selectionEnd -= strings.Count(orig[:t.selectionEnd], "\n") selectionStart -= strings.Count(orig[:selectionStart], "\n")
selectionEnd -= strings.Count(orig[:selectionEnd], "\n")
t.field.SetSelection(selectionStart, selectionEnd)
} }
} }
@ -276,17 +206,20 @@ func (t *TextField) Update() error {
func (t *TextField) cursorPos() (int, int) { func (t *TextField) cursorPos() (int, int) {
var nlCount int var nlCount int
lastNLPos := -1 lastNLPos := -1
for i, r := range t.text[:t.selectionStart] { txt := t.field.TextForRendering()
selectionStart, _ := t.field.Selection()
if s, _, ok := t.field.CompositionSelection(); ok {
selectionStart += s
}
txt = txt[:selectionStart]
for i, r := range txt {
if r == '\n' { if r == '\n' {
nlCount++ nlCount++
lastNLPos = i lastNLPos = i
} }
} }
txt := t.text[lastNLPos+1 : t.selectionStart] txt = txt[lastNLPos+1:]
if t.state.Text != "" {
txt += t.state.Text[:t.state.CompositionSelectionStartInBytes]
}
x := int(text.Advance(txt, fontFace)) x := int(text.Advance(txt, fontFace))
y := nlCount * int(fontFace.Metrics().HLineGap+fontFace.Metrics().HAscent+fontFace.Metrics().HDescent) y := nlCount * int(fontFace.Metrics().HLineGap+fontFace.Metrics().HAscent+fontFace.Metrics().HDescent)
return x, y return x, y
@ -295,13 +228,14 @@ func (t *TextField) cursorPos() (int, int) {
func (t *TextField) Draw(screen *ebiten.Image) { func (t *TextField) Draw(screen *ebiten.Image) {
vector.DrawFilledRect(screen, float32(t.bounds.Min.X), float32(t.bounds.Min.Y), float32(t.bounds.Dx()), float32(t.bounds.Dy()), color.White, false) vector.DrawFilledRect(screen, float32(t.bounds.Min.X), float32(t.bounds.Min.Y), float32(t.bounds.Dx()), float32(t.bounds.Dy()), color.White, false)
var clr color.Color = color.Black var clr color.Color = color.Black
if t.focused { if t.field.IsFocused() {
clr = color.RGBA{0, 0, 0xff, 0xff} clr = color.RGBA{0, 0, 0xff, 0xff}
} }
vector.StrokeRect(screen, float32(t.bounds.Min.X), float32(t.bounds.Min.Y), float32(t.bounds.Dx()), float32(t.bounds.Dy()), 1, clr, false) vector.StrokeRect(screen, float32(t.bounds.Min.X), float32(t.bounds.Min.Y), float32(t.bounds.Dx()), float32(t.bounds.Dy()), 1, clr, false)
px, py := textFieldPadding() px, py := textFieldPadding()
if t.focused && t.selectionStart >= 0 { selectionStart, _ := t.field.Selection()
if t.field.IsFocused() && selectionStart >= 0 {
x, y := t.bounds.Min.X, t.bounds.Min.Y x, y := t.bounds.Min.X, t.bounds.Min.Y
cx, cy := t.cursorPos() cx, cy := t.cursorPos()
x += px + cx x += px + cx
@ -310,18 +244,13 @@ func (t *TextField) Draw(screen *ebiten.Image) {
vector.StrokeLine(screen, float32(x), float32(y), float32(x), float32(y+h), 1, color.Black, false) vector.StrokeLine(screen, float32(x), float32(y), float32(x), float32(y+h), 1, color.Black, false)
} }
shownText := t.text
if t.focused && t.state.Text != "" {
shownText = t.text[:t.selectionStart] + t.state.Text + t.text[t.selectionEnd:]
}
tx := t.bounds.Min.X + px tx := t.bounds.Min.X + px
ty := t.bounds.Min.Y + py ty := t.bounds.Min.Y + py
op := &text.DrawOptions{} op := &text.DrawOptions{}
op.GeoM.Translate(float64(tx), float64(ty)) op.GeoM.Translate(float64(tx), float64(ty))
op.ColorScale.ScaleWithColor(color.Black) op.ColorScale.ScaleWithColor(color.Black)
op.LineSpacing = fontFace.Metrics().HLineGap + fontFace.Metrics().HAscent + fontFace.Metrics().HDescent op.LineSpacing = fontFace.Metrics().HLineGap + fontFace.Metrics().HAscent + fontFace.Metrics().HDescent
text.Draw(screen, shownText, fontFace, op) text.Draw(screen, t.field.TextForRendering(), fontFace, op)
} }
const textFieldHeight = 24 const textFieldHeight = 24
@ -353,9 +282,7 @@ func (g *Game) Update() error {
for _, tf := range g.textFields { for _, tf := range g.textFields {
if tf.Contains(x, y) { if tf.Contains(x, y) {
tf.Focus() tf.Focus()
if _, err := tf.SetSelectionStartByCursorPosition(x, y); err != nil { tf.SetSelectionStartByCursorPosition(x, y)
return err
}
} else { } else {
tf.Blur() tf.Blur()
} }

246
exp/textinput/field.go Normal file
View File

@ -0,0 +1,246 @@
// Copyright 2024 The Ebitengine Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package textinput
import (
"sync"
)
var (
theFocusedField *Field
theFocusedFieldM sync.Mutex
)
func focusField(f *Field) {
var origField *Field
defer func() {
if origField != nil {
origField.cleanUp()
}
}()
theFocusedFieldM.Lock()
defer theFocusedFieldM.Unlock()
if theFocusedField == f {
return
}
origField = theFocusedField
theFocusedField = f
}
func blurField(f *Field) {
var origField *Field
defer func() {
if origField != nil {
origField.cleanUp()
}
}()
theFocusedFieldM.Lock()
defer theFocusedFieldM.Unlock()
if theFocusedField != f {
return
}
origField = theFocusedField
theFocusedField = nil
}
func isFieldFocused(f *Field) bool {
theFocusedFieldM.Lock()
defer theFocusedFieldM.Unlock()
return theFocusedField == f
}
// Field is a region accepting text inputting with IME.
//
// Field is not focused by default. You have to call Focus when you start text inputting.
//
// Field is a wrapper of the low-level API like Start.
//
// For an actual usage, see the examples "textinput".
type Field struct {
text string
selectionStart int
selectionEnd int
ch chan State
end func()
state State
err error
}
// HandleInput updates the field state.
// HandleInput must be called every tick, i.e., every HandleInput, when Field is focused.
// HandleInput takes a position where an IME window is shown if needed.
//
// HandleInput returns whether the text inputting is handled or not.
// If HandleInput returns true, a Field user should not handle further input events.
//
// HandleInput returns an error when handling input causes an error.
func (f *Field) HandleInput(x, y int) (handled bool, err error) {
if f.err != nil {
return false, f.err
}
if !f.IsFocused() {
return false, nil
}
// Text inputting can happen multiple times in one tick (1/60[s] by default).
// Handle all of them.
for {
if f.ch == nil {
// TODO: On iOS Safari, Start doesn't work as expected (#2898).
// Handle a click event and focus the textarea there.
f.ch, f.end = Start(x, y)
// Start returns nil for non-supported envrionments.
if f.ch == nil {
return true, nil
}
}
readchar:
for {
select {
case state, ok := <-f.ch:
if state.Error != nil {
f.err = state.Error
return false, f.err
}
handled = true
if !ok {
f.ch = nil
f.end = nil
f.state = State{}
break readchar
}
if state.Committed {
f.text = f.text[:f.selectionStart] + state.Text + f.text[f.selectionEnd:]
f.selectionStart += len(state.Text)
f.selectionEnd = f.selectionStart
f.state = State{}
continue
}
f.state = state
default:
break readchar
}
}
if f.ch == nil {
continue
}
break
}
return
}
// Focus focuses the field.
// A Field has to be focused to start text inputting.
//
// There can be only one Field that is focused at the same time.
// When Focus is called and there is already a focused field, Focus removes the focus of that.
func (f *Field) Focus() {
focusField(f)
}
// Blur removes the focus from the field.
func (f *Field) Blur() {
blurField(f)
}
// IsFocused reports whether the field is focused or not.
func (f *Field) IsFocused() bool {
return isFieldFocused(f)
}
func (f *Field) cleanUp() {
if f.err != nil {
return
}
// If the text field still has a session, read the last state and process it just in case.
if f.ch != nil {
select {
case state, ok := <-f.ch:
if state.Error != nil {
f.err = state.Error
return
}
if ok && state.Committed {
f.text = f.text[:f.selectionStart] + state.Text + f.text[f.selectionEnd:]
f.selectionStart += len(state.Text)
f.selectionEnd = f.selectionStart
f.state = State{}
}
f.state = state
default:
break
}
}
if f.end != nil {
f.end()
f.ch = nil
f.end = nil
f.state = State{}
}
}
// Selection returns the current selection range in bytes.
func (f *Field) Selection() (start, end int) {
return f.selectionStart, f.selectionEnd
}
// CompositionSelection returns the current composition selection in bytes if a text is composited.
// If a text is not composited, this returns 0s and false.
// The returned values indicate relative positions in bytes where the current composition text's start is 0.
func (f *Field) CompositionSelection() (start, end int, ok bool) {
if f.IsFocused() && f.state.Text != "" {
return f.state.CompositionSelectionStartInBytes, f.state.CompositionSelectionEndInBytes, true
}
return 0, 0, false
}
// SetSelection sets the selection range.
func (f *Field) SetSelection(start, end int) {
f.cleanUp()
f.selectionStart = start
f.selectionEnd = end
}
// Text returns the current text.
// The returned value doesn't include compositing texts.
func (f *Field) Text() string {
return f.text
}
// TextForRendering returns the text for rendering.
// The returned value includes compositing texts.
func (f *Field) TextForRendering() string {
if f.IsFocused() && f.state.Text != "" {
return f.text[:f.selectionStart] + f.state.Text + f.text[f.selectionEnd:]
}
return f.text
}
// SetTextAndSelection sets the text and the selection range.
func (f *Field) SetTextAndSelection(text string, selectionStart, selectionEnd int) {
f.cleanUp()
f.text = text
f.selectionStart = selectionStart
f.selectionEnd = selectionEnd
}

View File

@ -25,6 +25,8 @@ import (
) )
// State represents the current state of text inputting. // State represents the current state of text inputting.
//
// State is the low-level API. For most use cases, Field is easier to use.
type State struct { type State struct {
// Text represents the current inputting text. // Text represents the current inputting text.
Text string Text string
@ -45,6 +47,8 @@ type State struct {
// Start starts text inputting. // Start starts text inputting.
// Start returns a channel to send the state repeatedly, and a function to end the text inputting. // Start returns a channel to send the state repeatedly, and a function to end the text inputting.
// //
// Start is the low-leve API. For most use cases, Field is easier to use.
//
// Start returns nil and nil if the current environment doesn't support this package. // Start returns nil and nil if the current environment doesn't support this package.
func Start(x, y int) (states chan State, close func()) { func Start(x, y int) (states chan State, close func()) {
cx, cy := ui.Get().LogicalPositionToClientPositionInNativePixels(float64(x), float64(y)) cx, cy := ui.Get().LogicalPositionToClientPositionInNativePixels(float64(x), float64(y))

165
internal/shader/delayed.go Normal file
View File

@ -0,0 +1,165 @@
// Copyright 2024 The Ebiten Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package shader
import (
"fmt"
gconstant "go/constant"
"github.com/hajimehoshi/ebiten/v2/internal/shaderir"
)
type delayedTypeValidator interface {
Validate(t shaderir.Type) (shaderir.Type, bool)
IsValidated() (shaderir.Type, bool)
Error() string
}
func isArgDefaultTypeInt(f shaderir.BuiltinFunc) bool {
return f == shaderir.IntF || f == shaderir.IVec2F || f == shaderir.IVec3F || f == shaderir.IVec4F
}
func isIntType(t shaderir.Type) bool {
return t.Main == shaderir.Int || t.IsIntVector()
}
func (cs *compileState) ValidateDefaultTypesForExpr(block *block, expr shaderir.Expr, t shaderir.Type) shaderir.Type {
if check, ok := cs.delayedTypeCheks[expr.Ast]; ok {
if resT, ok := check.IsValidated(); ok {
return resT
}
resT, ok := check.Validate(t)
if !ok {
return shaderir.Type{Main: shaderir.None}
}
return resT
}
switch expr.Type {
case shaderir.LocalVariable:
return block.vars[expr.Index].typ
case shaderir.Binary:
left := expr.Exprs[0]
right := expr.Exprs[1]
leftType := cs.ValidateDefaultTypesForExpr(block, left, t)
rightType := cs.ValidateDefaultTypesForExpr(block, right, t)
// Usure about top-level type, try to validate by neighbour type
// The same work is done twice. Can it be optimized?
if t.Main == shaderir.None {
cs.ValidateDefaultTypesForExpr(block, left, rightType)
cs.ValidateDefaultTypesForExpr(block, right, leftType)
}
case shaderir.Call:
fun := expr.Exprs[0]
if fun.Type == shaderir.BuiltinFuncExpr {
if isArgDefaultTypeInt(fun.BuiltinFunc) {
for _, e := range expr.Exprs[1:] {
cs.ValidateDefaultTypesForExpr(block, e, shaderir.Type{Main: shaderir.Int})
}
return shaderir.Type{Main: shaderir.Int}
}
for _, e := range expr.Exprs[1:] {
cs.ValidateDefaultTypesForExpr(block, e, shaderir.Type{Main: shaderir.Float})
}
return shaderir.Type{Main: shaderir.Float}
}
if fun.Type == shaderir.FunctionExpr {
args := cs.funcs[fun.Index].ir.InParams
for i, e := range expr.Exprs[1:] {
cs.ValidateDefaultTypesForExpr(block, e, args[i])
}
retT := cs.funcs[fun.Index].ir.Return
return retT
}
}
return shaderir.Type{Main: shaderir.None}
}
func (cs *compileState) ValidateDefaultTypes(block *block, stmt shaderir.Stmt) {
switch stmt.Type {
case shaderir.Assign:
left := stmt.Exprs[0]
right := stmt.Exprs[1]
if left.Type == shaderir.LocalVariable {
varType := block.vars[left.Index].typ
// Type is not explicitly specified
if stmt.IsTypeGuessed {
varType = shaderir.Type{Main: shaderir.None}
}
cs.ValidateDefaultTypesForExpr(block, right, varType)
}
case shaderir.ExprStmt:
for _, e := range stmt.Exprs {
cs.ValidateDefaultTypesForExpr(block, e, shaderir.Type{Main: shaderir.None})
}
}
}
type delayedShiftValidator struct {
shiftType shaderir.Op
value gconstant.Value
validated bool
closestUnknown bool
failed bool
}
func (d *delayedShiftValidator) IsValidated() (shaderir.Type, bool) {
if d.failed {
return shaderir.Type{}, false
}
if d.validated {
return shaderir.Type{Main: shaderir.Int}, true
}
// If only matched with None
if d.closestUnknown {
// Was it originally represented by an int constant?
if d.value.Kind() == gconstant.Int {
return shaderir.Type{Main: shaderir.Int}, true
}
}
return shaderir.Type{}, false
}
func (d *delayedShiftValidator) Validate(t shaderir.Type) (shaderir.Type, bool) {
if d.validated {
return shaderir.Type{Main: shaderir.Int}, true
}
if isIntType(t) {
d.validated = true
return shaderir.Type{Main: shaderir.Int}, true
}
if t.Main == shaderir.None {
d.closestUnknown = true
return t, true
}
return shaderir.Type{Main: shaderir.None}, false
}
func (d *delayedShiftValidator) Error() string {
st := "left shift"
if d.shiftType == shaderir.RightShift {
st = "right shift"
}
return fmt.Sprintf("left operand for %s should be int", st)
}

View File

@ -36,7 +36,7 @@ func canTruncateToFloat(v gconstant.Value) bool {
var textureVariableRe = regexp.MustCompile(`\A__t(\d+)\z`) var textureVariableRe = regexp.MustCompile(`\A__t(\d+)\z`)
func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, markLocalVariableUsed bool) ([]shaderir.Expr, []shaderir.Type, []shaderir.Stmt, bool) { func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, markLocalVariableUsed bool) (rexpr []shaderir.Expr, rtype []shaderir.Type, rstmt []shaderir.Stmt, ok bool) {
switch e := expr.(type) { switch e := expr.(type) {
case *ast.BasicLit: case *ast.BasicLit:
switch e.Kind { switch e.Kind {
@ -105,7 +105,14 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar
} }
// Resolve untyped constants. // Resolve untyped constants.
l, r, ok := shaderir.ResolveUntypedConstsForBinaryOp(lhs[0].Const, rhs[0].Const, lhst, rhst) var l gconstant.Value
var r gconstant.Value
origLvalue := lhs[0].Const
if op2 == shaderir.LeftShift || op2 == shaderir.RightShift {
l, r, ok = shaderir.ResolveUntypedConstsForBitShiftOp(lhs[0].Const, rhs[0].Const, lhst, rhst)
} else {
l, r, ok = shaderir.ResolveUntypedConstsForBinaryOp(lhs[0].Const, rhs[0].Const, lhst, rhst)
}
if !ok { if !ok {
// TODO: Show a better type name for untyped constants. // TODO: Show a better type name for untyped constants.
cs.addError(e.Pos(), fmt.Sprintf("types don't match: %s %s %s", lhst.String(), op, rhst.String())) cs.addError(e.Pos(), fmt.Sprintf("types don't match: %s %s %s", lhst.String(), op, rhst.String()))
@ -113,27 +120,49 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar
} }
lhs[0].Const, rhs[0].Const = l, r lhs[0].Const, rhs[0].Const = l, r
// If either is typed, resolve the other type. if op2 == shaderir.LeftShift || op2 == shaderir.RightShift {
// If both are untyped, keep them untyped. if !(lhst.Main == shaderir.None && rhst.Main == shaderir.None) {
if lhst.Main != shaderir.None || rhst.Main != shaderir.None { // If both are const
if lhs[0].Const != nil { if rhs[0].Const != nil && (rhst.Main == shaderir.None || lhs[0].Const != nil) {
switch lhs[0].Const.Kind() { rhst = shaderir.Type{Main: shaderir.Int}
case gconstant.Float: }
lhst = shaderir.Type{Main: shaderir.Float}
case gconstant.Int: // If left is untyped const
if lhst.Main == shaderir.None && lhs[0].Const != nil {
lhst = shaderir.Type{Main: shaderir.Int} lhst = shaderir.Type{Main: shaderir.Int}
case gconstant.Bool: // Left should be implicitly converted to the type it would assume if the shift expression were replaced by its left operand alone.
lhst = shaderir.Type{Main: shaderir.Bool} if rhs[0].Const == nil {
defer func() {
if ok {
cs.addDelayedTypeCheck(expr, &delayedShiftValidator{value: origLvalue})
}
}()
}
} }
} }
if rhs[0].Const != nil { } else {
switch rhs[0].Const.Kind() { // If either is typed, resolve the other type.
case gconstant.Float: // If both are untyped, keep them untyped.
rhst = shaderir.Type{Main: shaderir.Float} if lhst.Main != shaderir.None || rhst.Main != shaderir.None {
case gconstant.Int: if lhs[0].Const != nil {
rhst = shaderir.Type{Main: shaderir.Int} switch lhs[0].Const.Kind() {
case gconstant.Bool: case gconstant.Float:
rhst = shaderir.Type{Main: shaderir.Bool} lhst = shaderir.Type{Main: shaderir.Float}
case gconstant.Int:
lhst = shaderir.Type{Main: shaderir.Int}
case gconstant.Bool:
lhst = shaderir.Type{Main: shaderir.Bool}
}
}
if rhs[0].Const != nil {
switch rhs[0].Const.Kind() {
case gconstant.Float:
rhst = shaderir.Type{Main: shaderir.Float}
case gconstant.Int:
rhst = shaderir.Type{Main: shaderir.Int}
case gconstant.Bool:
rhst = shaderir.Type{Main: shaderir.Bool}
}
} }
} }
} }
@ -153,6 +182,13 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar
v = gconstant.MakeBool(b) v = gconstant.MakeBool(b)
case token.EQL, token.NEQ, token.LSS, token.LEQ, token.GTR, token.GEQ: case token.EQL, token.NEQ, token.LSS, token.LEQ, token.GTR, token.GEQ:
v = gconstant.MakeBool(gconstant.Compare(lhs[0].Const, op, rhs[0].Const)) v = gconstant.MakeBool(gconstant.Compare(lhs[0].Const, op, rhs[0].Const))
case token.SHL, token.SHR:
shift, ok := gconstant.Int64Val(rhs[0].Const)
if !ok {
cs.addError(e.Pos(), fmt.Sprintf("unexpected %s type for: %s", rhs[0].Const.String(), e.Op))
} else {
v = gconstant.Shift(lhs[0].Const, op, uint(shift))
}
default: default:
v = gconstant.BinaryOp(lhs[0].Const, op, rhs[0].Const) v = gconstant.BinaryOp(lhs[0].Const, op, rhs[0].Const)
} }
@ -169,6 +205,7 @@ func (cs *compileState) parseExpr(block *block, fname string, expr ast.Expr, mar
{ {
Type: shaderir.Binary, Type: shaderir.Binary,
Op: op2, Op: op2,
Ast: expr,
Exprs: []shaderir.Expr{lhs[0], rhs[0]}, Exprs: []shaderir.Expr{lhs[0], rhs[0]},
}, },
}, []shaderir.Type{t}, stmts, true }, []shaderir.Type{t}, stmts, true

View File

@ -61,6 +61,8 @@ type compileState struct {
varyingParsed bool varyingParsed bool
delayedTypeCheks map[ast.Expr]delayedTypeValidator
errs []string errs []string
} }
@ -82,6 +84,13 @@ func (cs *compileState) findUniformVariable(name string) (int, bool) {
return 0, false return 0, false
} }
func (cs *compileState) addDelayedTypeCheck(at ast.Expr, check delayedTypeValidator) {
if cs.delayedTypeCheks == nil {
cs.delayedTypeCheks = make(map[ast.Expr]delayedTypeValidator, 1)
}
cs.delayedTypeCheks[at] = check
}
type typ struct { type typ struct {
name string name string
ir shaderir.Type ir shaderir.Type

View File

@ -49,6 +49,10 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP
if !ok { if !ok {
return nil, false return nil, false
} }
for i := range ss {
ss[i].IsTypeGuessed = true
}
stmts = append(stmts, ss...) stmts = append(stmts, ss...)
case token.ASSIGN: case token.ASSIGN:
if len(stmt.Lhs) != len(stmt.Rhs) && len(stmt.Rhs) != 1 { if len(stmt.Lhs) != len(stmt.Rhs) && len(stmt.Rhs) != 1 {
@ -60,7 +64,7 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP
return nil, false return nil, false
} }
stmts = append(stmts, ss...) stmts = append(stmts, ss...)
case token.ADD_ASSIGN, token.SUB_ASSIGN, token.MUL_ASSIGN, token.QUO_ASSIGN, token.REM_ASSIGN, token.AND_ASSIGN, token.OR_ASSIGN, token.XOR_ASSIGN, token.AND_NOT_ASSIGN: case token.ADD_ASSIGN, token.SUB_ASSIGN, token.MUL_ASSIGN, token.QUO_ASSIGN, token.REM_ASSIGN, token.AND_ASSIGN, token.OR_ASSIGN, token.XOR_ASSIGN, token.AND_NOT_ASSIGN, token.SHL_ASSIGN, token.SHR_ASSIGN:
rhs, rts, ss, ok := cs.parseExpr(block, fname, stmt.Rhs[0], true) rhs, rts, ss, ok := cs.parseExpr(block, fname, stmt.Rhs[0], true)
if !ok { if !ok {
return nil, false return nil, false
@ -100,6 +104,10 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP
op = shaderir.Or op = shaderir.Or
case token.XOR_ASSIGN: case token.XOR_ASSIGN:
op = shaderir.Xor op = shaderir.Xor
case token.SHL_ASSIGN:
op = shaderir.LeftShift
case token.SHR_ASSIGN:
op = shaderir.RightShift
default: default:
cs.addError(stmt.Pos(), fmt.Sprintf("unexpected token: %s", stmt.Tok)) cs.addError(stmt.Pos(), fmt.Sprintf("unexpected token: %s", stmt.Tok))
return nil, false return nil, false
@ -110,7 +118,7 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP
cs.addError(stmt.Pos(), fmt.Sprintf("invalid operation: operator / not defined on %s", rts[0].String())) cs.addError(stmt.Pos(), fmt.Sprintf("invalid operation: operator / not defined on %s", rts[0].String()))
return nil, false return nil, false
} }
if op == shaderir.And || op == shaderir.Or || op == shaderir.Xor { if op == shaderir.And || op == shaderir.Or || op == shaderir.Xor || op == shaderir.LeftShift || op == shaderir.RightShift {
if lts[0].Main != shaderir.Int && !lts[0].IsIntVector() { if lts[0].Main != shaderir.Int && !lts[0].IsIntVector() {
cs.addError(stmt.Pos(), fmt.Sprintf("invalid operation: operator %s not defined on %s", stmt.Tok, lts[0].String())) cs.addError(stmt.Pos(), fmt.Sprintf("invalid operation: operator %s not defined on %s", stmt.Tok, lts[0].String()))
} }
@ -137,7 +145,7 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP
} }
} }
case shaderir.Float: case shaderir.Float:
if op == shaderir.And || op == shaderir.Or || op == shaderir.Xor { if op == shaderir.And || op == shaderir.Or || op == shaderir.Xor || op == shaderir.LeftShift || op == shaderir.RightShift {
cs.addError(stmt.Pos(), fmt.Sprintf("invalid operation: operator %s not defined on %s", stmt.Tok, lts[0].String())) cs.addError(stmt.Pos(), fmt.Sprintf("invalid operation: operator %s not defined on %s", stmt.Tok, lts[0].String()))
} else if rhs[0].Const != nil && } else if rhs[0].Const != nil &&
(rts[0].Main == shaderir.None || rts[0].Main == shaderir.Float) && (rts[0].Main == shaderir.None || rts[0].Main == shaderir.Float) &&
@ -148,7 +156,7 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP
return nil, false return nil, false
} }
case shaderir.Vec2, shaderir.Vec3, shaderir.Vec4, shaderir.Mat2, shaderir.Mat3, shaderir.Mat4: case shaderir.Vec2, shaderir.Vec3, shaderir.Vec4, shaderir.Mat2, shaderir.Mat3, shaderir.Mat4:
if op == shaderir.And || op == shaderir.Or || op == shaderir.Xor { if op == shaderir.And || op == shaderir.Or || op == shaderir.Xor || op == shaderir.LeftShift || op == shaderir.RightShift {
cs.addError(stmt.Pos(), fmt.Sprintf("invalid operation: operator %s not defined on %s", stmt.Tok, lts[0].String())) cs.addError(stmt.Pos(), fmt.Sprintf("invalid operation: operator %s not defined on %s", stmt.Tok, lts[0].String()))
} else if (op == shaderir.MatrixMul || op == shaderir.Div) && } else if (op == shaderir.MatrixMul || op == shaderir.Div) &&
(rts[0].Main == shaderir.Float || (rts[0].Main == shaderir.Float ||
@ -469,6 +477,25 @@ func (cs *compileState) parseStmt(block *block, fname string, stmt ast.Stmt, inP
cs.addError(stmt.Pos(), fmt.Sprintf("unexpected statement: %#v", stmt)) cs.addError(stmt.Pos(), fmt.Sprintf("unexpected statement: %#v", stmt))
return nil, false return nil, false
} }
// Need to run delayed checks
if len(cs.delayedTypeCheks) != 0 {
for _, st := range stmts {
cs.ValidateDefaultTypes(block, st)
}
// Collect all errors first
foundErr := false
for s, v := range cs.delayedTypeCheks {
if _, ok := v.IsValidated(); !ok {
foundErr = true
cs.addError(s.Pos(), v.Error())
}
}
if foundErr {
return nil, false
}
}
return stmts, true return stmts, true
} }

View File

@ -1314,6 +1314,263 @@ func Fragment(dstPos vec4, srcPos vec2, color vec4) vec4 {
} }
} }
// Issue: #2755
func TestSyntaxOperatorShift(t *testing.T) {
cases := []struct {
stmt string
err bool
}{
{stmt: "b := 2.0; a := 1.0 << 2.0 == b; _ = a", err: false},
{stmt: "s := 1; b := 2.0; a := 1.0<<s == b; _ = a", err: true},
{stmt: "s := 1; b := 2; a := 1.0<<s == b; _ = a", err: false},
{stmt: "s := 1; a := 2.0<<s + ivec2(3.0<<s); _ = a", err: false},
{stmt: "s := 1; a := 2.0<<s + vec2(3); _ = a", err: true},
{stmt: "s := 1; a := 2.0<<s + ivec2(3); _ = a", err: false},
{stmt: "s := 1; a := 2.0<<s + foo_int_int(3.0<<s); _ = a", err: false},
{stmt: "s := 1; a := 2.0<<s + 3.0<<s; _ = a", err: true},
{stmt: "s := 1; a := 2<<s + 3.0<<s; _ = a", err: true},
{stmt: "s := 1; a := 2.0<<s + 3<<s; _ = a", err: true},
{stmt: "s := 1; a := 2<<s + 3<<s; _ = a", err: false},
{stmt: "s := 1; foo_multivar(0, 0, 2<<s)", err: false},
{stmt: "s := 1; foo_multivar(0, 2.0<<s, 0)", err: true},
{stmt: "s := 1; foo_multivar(2.0<<s, 0, 0)", err: false},
{stmt: "s := 1; a := foo_multivar(2.0<<s, 0, 0); _ = a", err: false},
{stmt: "s := 1; a := foo_multivar(0, 2.0<<s, 0); _ = a", err: true},
{stmt: "s := 1; a := foo_multivar(0, 0, 2.0<<s); _ = a", err: false},
{stmt: "a := foo_multivar(0, 0, 1.0<<2.0); _ = a", err: false},
{stmt: "a := foo_multivar(0, 1.0<<2.0, 0); _ = a", err: false},
{stmt: "s := 1; a := int(1) + 1.0<<s + int(float(1<<s)); _ = a", err: true},
{stmt: "s := 1; var a int = 1.0 << 2.0 << 3.0 << 4.0 << s; _ = a", err: false},
{stmt: "s := 1; var a float = 1 << 1 << 1 << 1 << s; _ = a", err: true},
{stmt: "s := 1; var a float = 1 << s + 1.2; _ = a", err: true},
{stmt: "s := 1; a := 1.0 << s + 1.2; _ = a", err: true},
{stmt: "s := 1; a := 1.0 << s + foo_float_float(2); _ = a", err: true},
{stmt: "s := 1; a := 1.0 << s + foo_float_int(2); _ = a", err: false},
{stmt: "s := 1; a := foo_float_int(1.0<<s) + foo_float_int(2); _ = a", err: true},
{stmt: "s := 1; a := foo_int_float(1<<s) + foo_int_float(2); _ = a", err: false},
{stmt: "s := 1; a := foo_int_int(1<<s) + foo_int_int(2); _ = a", err: false},
{stmt: "s := 1; t := 2.0; a := t + 1.0 << s; _ = a", err: true},
{stmt: "s := 1; t := 2; a := t + 1.0 << s; _ = a", err: false},
{stmt: "s := 1; b := 1 << s; _ = b", err: false},
{stmt: "var a = 1; b := a << 2.0; _ = b", err: false},
{stmt: "s := 1; var a float; a = 1 << s; _ = a", err: true},
{stmt: "s := 1; var a float = 1 << s; _ = a", err: true},
{stmt: "s := 1; var a float = float(1.0 << s); _ = a", err: true},
{stmt: "s := 1; var a int = int(1 << s); _ = a", err: false},
{stmt: "s := 1; var a int = int(1.0 << s); _ = a", err: false},
{stmt: "s := 1; a := 1 << s; _ = a", err: false},
{stmt: "s := 1; a := 1.0 << s; _ = a", err: true},
{stmt: "s := 1; a := int(1.0 << s); _ = a", err: false},
{stmt: "s := 1; var a float = float(1.0 << s); _ = a", err: true},
{stmt: "s := 1; var a float = 1.0 << s; _ = a", err: true},
{stmt: "s := 1; var a float = 1 << s; _ = a", err: true},
{stmt: "s := 1; var a int = 1.0 << s; _ = a", err: false},
{stmt: "s := 1; var a int = 1 << s; _ = a", err: false},
{stmt: "var a float = 1.0 << 2.0; _ = a", err: false},
{stmt: "var a int = 1.0 << 2; _ = a", err: false},
{stmt: "var a float = 1.0 << 2; _ = a", err: false},
{stmt: "a := 1 << 2.0; _ = a", err: false},
{stmt: "a := 1.0 << 2; _ = a", err: false},
{stmt: "a := 1.0 << 2.0; _ = a", err: false},
{stmt: "a := 1 << 2; _ = a", err: false},
{stmt: "a := float(1.0) << 2; _ = a", err: true},
{stmt: "a := 1 << float(2.0); _ = a", err: false},
{stmt: "a := 1.0 << float(2.0); _ = a", err: false},
{stmt: "a := ivec2(1) << 2; _ = a", err: false},
{stmt: "a := 1 << ivec2(2); _ = a", err: true},
{stmt: "a := ivec2(1) << float(2.0); _ = a", err: true},
{stmt: "a := float(1.0) << ivec2(2); _ = a", err: true},
{stmt: "a := ivec2(1) << ivec2(2); _ = a", err: false},
{stmt: "a := ivec3(1) << ivec2(2); _ = a", err: true},
{stmt: "a := ivec2(1) << ivec3(2); _ = a", err: true},
{stmt: "a := 1 << vec2(2); _ = a", err: true},
{stmt: "a := vec2(1) << 2; _ = a", err: true},
{stmt: "a := float(1.0) << vec2(2); _ = a", err: true},
{stmt: "a := vec2(1) << float(2.0); _ = a", err: true},
{stmt: "a := vec2(1) << vec2(2); _ = a", err: true},
{stmt: "a := vec2(1) << vec3(2); _ = a", err: true},
{stmt: "a := vec3(1) << vec2(2); _ = a", err: true},
{stmt: "a := vec2(1) << ivec2(2); _ = a", err: true},
{stmt: "a := ivec2(1) << vec2(2); _ = a", err: true},
{stmt: "a := vec3(1) << ivec2(2); _ = a", err: true},
{stmt: "a := ivec2(1) << vec3(2); _ = a", err: true},
{stmt: "b := 2.0; a := 1.0 >> 2.0 == b; _ = a", err: false},
{stmt: "s := 1; b := 2.0; a := 1.0>>s == b; _ = a", err: true},
{stmt: "s := 1; b := 2; a := 1.0>>s == b; _ = a", err: false},
{stmt: "s := 1; a := 2.0>>s + ivec2(3.0>>s); _ = a", err: false},
{stmt: "s := 1; a := 2.0>>s + vec2(3); _ = a", err: true},
{stmt: "s := 1; a := 2.0>>s + ivec2(3); _ = a", err: false},
{stmt: "s := 1; a := 2.0>>s + foo_int_int(3.0>>s); _ = a", err: false},
{stmt: "s := 1; a := 2.0>>s + 3.0>>s; _ = a", err: true},
{stmt: "s := 1; a := 2>>s + 3.0>>s; _ = a", err: true},
{stmt: "s := 1; a := 2.0>>s + 3>>s; _ = a", err: true},
{stmt: "s := 1; a := 2>>s + 3>>s; _ = a", err: false},
{stmt: "s := 1; foo_multivar(0, 0, 2>>s)", err: false},
{stmt: "s := 1; foo_multivar(0, 2.0>>s, 0)", err: true},
{stmt: "s := 1; foo_multivar(2.0>>s, 0, 0)", err: false},
{stmt: "s := 1; a := foo_multivar(2.0>>s, 0, 0); _ = a", err: false},
{stmt: "s := 1; a := foo_multivar(0, 2.0>>s, 0); _ = a", err: true},
{stmt: "s := 1; a := foo_multivar(0, 0, 2.0>>s); _ = a", err: false},
{stmt: "a := foo_multivar(0, 0, 1.0>>2.0); _ = a", err: false},
{stmt: "a := foo_multivar(0, 1.0>>2.0, 0); _ = a", err: false},
{stmt: "s := 1; a := int(1) + 1.0>>s + int(float(1>>s)); _ = a", err: true},
{stmt: "s := 1; var a int = 1.0 >> 2.0 >> 3.0 >> 4.0 >> s; _ = a", err: false},
{stmt: "s := 1; var a float = 1 >> 1 >> 1 >> 1 >> s; _ = a", err: true},
{stmt: "s := 1; var a float = 1 >> s + 1.2; _ = a", err: true},
{stmt: "s := 1; a := 1.0 >> s + 1.2; _ = a", err: true},
{stmt: "s := 1; a := 1.0 >> s + foo_float_float(2); _ = a", err: true},
{stmt: "s := 1; a := 1.0 >> s + foo_float_int(2); _ = a", err: false},
{stmt: "s := 1; a := foo_float_int(1.0>>s) + foo_float_int(2); _ = a", err: true},
{stmt: "s := 1; a := foo_int_float(1>>s) + foo_int_float(2); _ = a", err: false},
{stmt: "s := 1; a := foo_int_int(1>>s) + foo_int_int(2); _ = a", err: false},
{stmt: "s := 1; t := 2.0; a := t + 1.0 >> s; _ = a", err: true},
{stmt: "s := 1; t := 2; a := t + 1.0 >> s; _ = a", err: false},
{stmt: "s := 1; b := 1 >> s; _ = b", err: false},
{stmt: "var a = 1; b := a >> 2.0; _ = b", err: false},
{stmt: "s := 1; var a float; a = 1 >> s; _ = a", err: true},
{stmt: "s := 1; var a float = 1 >> s; _ = a", err: true},
{stmt: "s := 1; var a float = float(1.0 >> s); _ = a", err: true},
{stmt: "s := 1; var a int = int(1 >> s); _ = a", err: false},
{stmt: "s := 1; var a int = int(1.0 >> s); _ = a", err: false},
{stmt: "s := 1; a := 1 >> s; _ = a", err: false},
{stmt: "s := 1; a := 1.0 >> s; _ = a", err: true},
{stmt: "s := 1; a := int(1.0 >> s); _ = a", err: false},
{stmt: "s := 1; var a float = float(1.0 >> s); _ = a", err: true},
{stmt: "s := 1; var a float = 1.0 >> s; _ = a", err: true},
{stmt: "s := 1; var a float = 1 >> s; _ = a", err: true},
{stmt: "s := 1; var a int = 1.0 >> s; _ = a", err: false},
{stmt: "s := 1; var a int = 1 >> s; _ = a", err: false},
{stmt: "var a float = 1.0 >> 2.0; _ = a", err: false},
{stmt: "var a int = 1.0 >> 2; _ = a", err: false},
{stmt: "var a float = 1.0 >> 2; _ = a", err: false},
{stmt: "a := 1 >> 2.0; _ = a", err: false},
{stmt: "a := 1.0 >> 2; _ = a", err: false},
{stmt: "a := 1.0 >> 2.0; _ = a", err: false},
{stmt: "a := 1 >> 2; _ = a", err: false},
{stmt: "a := float(1.0) >> 2; _ = a", err: true},
{stmt: "a := 1 >> float(2.0); _ = a", err: false},
{stmt: "a := 1.0 >> float(2.0); _ = a", err: false},
{stmt: "a := ivec2(1) >> 2; _ = a", err: false},
{stmt: "a := 1 >> ivec2(2); _ = a", err: true},
{stmt: "a := ivec2(1) >> float(2.0); _ = a", err: true},
{stmt: "a := float(1.0) >> ivec2(2); _ = a", err: true},
{stmt: "a := ivec2(1) >> ivec2(2); _ = a", err: false},
{stmt: "a := ivec3(1) >> ivec2(2); _ = a", err: true},
{stmt: "a := ivec2(1) >> ivec3(2); _ = a", err: true},
{stmt: "a := 1 >> vec2(2); _ = a", err: true},
{stmt: "a := vec2(1) >> 2; _ = a", err: true},
{stmt: "a := float(1.0) >> vec2(2); _ = a", err: true},
{stmt: "a := vec2(1) >> float(2.0); _ = a", err: true},
{stmt: "a := vec2(1) >> vec2(2); _ = a", err: true},
{stmt: "a := vec2(1) >> vec3(2); _ = a", err: true},
{stmt: "a := vec3(1) >> vec2(2); _ = a", err: true},
{stmt: "a := vec2(1) >> ivec2(2); _ = a", err: true},
{stmt: "a := ivec2(1) >> vec2(2); _ = a", err: true},
{stmt: "a := vec3(1) >> ivec2(2); _ = a", err: true},
{stmt: "a := ivec2(1) >> vec3(2); _ = a", err: true},
}
for _, c := range cases {
_, err := compileToIR([]byte(fmt.Sprintf(`package main
func foo_multivar(x int, y float, z int) int {return x}
func foo_int_int(x int) int {return x}
func foo_float_int(x float) int {return int(x)}
func foo_int_float(x int) float {return float(x)}
func foo_float_float(x float) float {return x}
func Fragment(dstPos vec4, srcPos vec2, color vec4) vec4 {
%s
return dstPos
}`, c.stmt)))
if err == nil && c.err {
t.Errorf("%s must return an error but does not", c.stmt)
} else if err != nil && !c.err {
t.Errorf("%s must not return nil but returned %v", c.stmt, err)
}
}
}
func TestSyntaxOperatorShiftAssign(t *testing.T) {
cases := []struct {
stmt string
err bool
}{
{stmt: "a := 1; a <<= 2; _ = a", err: false},
{stmt: "a := 1; a <<= 2.0; _ = a", err: false},
{stmt: "a := float(1.0); a <<= 2; _ = a", err: true},
{stmt: "a := 1; a <<= float(2.0); _ = a", err: true},
{stmt: "a := ivec2(1); a <<= 2; _ = a", err: false},
{stmt: "a := 1; a <<= ivec2(2); _ = a", err: true},
{stmt: "a := ivec2(1); a <<= float(2.0); _ = a", err: true},
{stmt: "a := float(1.0); a <<= ivec2(2); _ = a", err: true},
{stmt: "a := ivec2(1); a <<= ivec2(2); _ = a", err: false},
{stmt: "a := ivec3(1); a <<= ivec2(2); _ = a", err: true},
{stmt: "a := ivec2(1); a <<= ivec3(2); _ = a", err: true},
{stmt: "a := 1; a <<= vec2(2); _ = a", err: true},
{stmt: "a := vec2(1); a <<= 2; _ = a", err: true},
{stmt: "a := float(1.0); a <<= vec2(2); _ = a", err: true},
{stmt: "a := vec2(1); a <<= float(2.0); _ = a", err: true},
{stmt: "a := vec2(1); a <<= vec2(2); _ = a", err: true},
{stmt: "a := vec2(1); a <<= vec3(2); _ = a", err: true},
{stmt: "a := vec3(1); a <<= vec2(2); _ = a", err: true},
{stmt: "a := vec2(1); a <<= ivec2(2); _ = a", err: true},
{stmt: "a := ivec2(1); a <<= vec2(2); _ = a", err: true},
{stmt: "a := vec3(1); a <<= ivec2(2); _ = a", err: true},
{stmt: "a := ivec2(1); a <<= vec3(2); _ = a", err: true},
{stmt: "const c = 2; a := 1; a <<= c; _ = a", err: false},
{stmt: "const c = 2.0; a := 1; a <<= c; _ = a", err: false},
{stmt: "const c = 2; a := float(1.0); a <<= c; _ = a", err: true},
{stmt: "const c float = 2; a := 1; a <<= c; _ = a", err: true},
{stmt: "const c float = 2.0; a := 1; a <<= c; _ = a", err: true},
{stmt: "const c int = 2; a := ivec2(1); a <<= c; _ = a", err: false},
{stmt: "const c int = 2; a := vec2(1); a <<= c; _ = a", err: true},
{stmt: "a := 1; a >>= 2; _ = a", err: false},
{stmt: "a := 1; a >>= 2.0; _ = a", err: false},
{stmt: "a := float(1.0); a >>= 2; _ = a", err: true},
{stmt: "a := 1; a >>= float(2.0); _ = a", err: true},
{stmt: "a := ivec2(1); a >>= 2; _ = a", err: false},
{stmt: "a := 1; a >>= ivec2(2); _ = a", err: true},
{stmt: "a := ivec2(1); a >>= float(2.0); _ = a", err: true},
{stmt: "a := float(1.0); a >>= ivec2(2); _ = a", err: true},
{stmt: "a := ivec2(1); a >>= ivec2(2); _ = a", err: false},
{stmt: "a := ivec3(1); a >>= ivec2(2); _ = a", err: true},
{stmt: "a := ivec2(1); a >>= ivec3(2); _ = a", err: true},
{stmt: "a := 1; a >>= vec2(2); _ = a", err: true},
{stmt: "a := vec2(1); a >>= 2; _ = a", err: true},
{stmt: "a := float(1.0); a >>= vec2(2); _ = a", err: true},
{stmt: "a := vec2(1); a >>= float(2.0); _ = a", err: true},
{stmt: "a := vec2(1); a >>= vec2(2); _ = a", err: true},
{stmt: "a := vec2(1); a >>= vec3(2); _ = a", err: true},
{stmt: "a := vec3(1); a >>= vec2(2); _ = a", err: true},
{stmt: "a := vec2(1); a >>= ivec2(2); _ = a", err: true},
{stmt: "a := ivec2(1); a >>= vec2(2); _ = a", err: true},
{stmt: "a := vec3(1); a >>= ivec2(2); _ = a", err: true},
{stmt: "a := ivec2(1); a >>= vec3(2); _ = a", err: true},
{stmt: "const c = 2; a := 1; a >>= c; _ = a", err: false},
{stmt: "const c = 2.0; a := 1; a >>= c; _ = a", err: false},
{stmt: "const c = 2; a := float(1.0); a >>= c; _ = a", err: true},
{stmt: "const c float = 2; a := 1; a >>= c; _ = a", err: true},
{stmt: "const c float = 2.0; a := 1; a >>= c; _ = a", err: true},
{stmt: "const c int = 2; a := ivec2(1); a >>= c; _ = a", err: false},
{stmt: "const c int = 2; a := vec2(1); a >>= c; _ = a", err: true},
}
for _, c := range cases {
_, err := compileToIR([]byte(fmt.Sprintf(`package main
func Fragment(dstPos vec4, srcPos vec2, color vec4) vec4 {
%s
return dstPos
}`, c.stmt)))
if err == nil && c.err {
t.Errorf("%s must return an error but does not", c.stmt)
} else if err != nil && !c.err {
t.Errorf("%s must not return nil but returned %v", c.stmt, err)
}
}
}
// Issue #1971 // Issue #1971
func TestSyntaxOperatorMultiplyAssign(t *testing.T) { func TestSyntaxOperatorMultiplyAssign(t *testing.T) {
cases := []struct { cases := []struct {

View File

@ -18,6 +18,29 @@ import (
"go/constant" "go/constant"
) )
func ResolveUntypedConstsForBitShiftOp(lhs, rhs constant.Value, lhst, rhst Type) (newLhs, newRhs constant.Value, ok bool) {
cLhs := lhs
cRhs := rhs
// Right is const -> int
if rhs != nil {
cRhs = constant.ToInt(rhs)
if cRhs.Kind() == constant.Unknown {
return nil, nil, false
}
}
// Left if untyped const -> int
if lhs != nil && lhst.Main == None {
cLhs = constant.ToInt(lhs)
if cLhs.Kind() == constant.Unknown {
return nil, nil, false
}
}
return cLhs, cRhs, true
}
func ResolveUntypedConstsForBinaryOp(lhs, rhs constant.Value, lhst, rhst Type) (newLhs, newRhs constant.Value, ok bool) { func ResolveUntypedConstsForBinaryOp(lhs, rhs constant.Value, lhst, rhst Type) (newLhs, newRhs constant.Value, ok bool) {
if lhst.Main == None && rhst.Main == None { if lhst.Main == None && rhst.Main == None {
if lhs.Kind() == rhs.Kind() { if lhs.Kind() == rhs.Kind() {
@ -121,6 +144,16 @@ func TypeFromBinaryOp(op Op, lhst, rhst Type, lhsConst, rhsConst constant.Value)
panic("shaderir: cannot resolve untyped values") panic("shaderir: cannot resolve untyped values")
} }
if op == LeftShift || op == RightShift {
if (lhst.Main == Int || lhst.IsIntVector()) && rhst.Main == Int {
return lhst, true
}
if lhst.IsIntVector() && rhst.IsIntVector() && lhst.VectorElementCount() == rhst.VectorElementCount() {
return lhst, true
}
return Type{}, false
}
if op == AndAnd || op == OrOr { if op == AndAnd || op == OrOr {
if lhst.Main == Bool && rhst.Main == Bool { if lhst.Main == Bool && rhst.Main == Bool {
return Type{Main: Bool}, true return Type{Main: Bool}, true

View File

@ -16,6 +16,7 @@
package shaderir package shaderir
import ( import (
"go/ast"
"go/constant" "go/constant"
"go/token" "go/token"
"sort" "sort"
@ -74,16 +75,17 @@ type Block struct {
} }
type Stmt struct { type Stmt struct {
Type StmtType Type StmtType
Exprs []Expr Exprs []Expr
Blocks []*Block Blocks []*Block
ForVarType Type ForVarType Type
ForVarIndex int ForVarIndex int
ForInit constant.Value ForInit constant.Value
ForEnd constant.Value ForEnd constant.Value
ForOp Op ForOp Op
ForDelta constant.Value ForDelta constant.Value
InitIndex int InitIndex int
IsTypeGuessed bool
} }
type StmtType int type StmtType int
@ -109,6 +111,7 @@ type Expr struct {
Swizzling string Swizzling string
Index int Index int
Op Op Op Op
Ast ast.Expr
} }
type ExprType int type ExprType int